"vscode:/vscode.git/clone" did not exist on "9d5d6afac7c6ed4054a619b18a3da8a77cdb338f"
Commit 747dd16c authored by coderfeli's avatar coderfeli
Browse files

result correct but strange

parent 14099622
...@@ -11,19 +11,24 @@ ...@@ -11,19 +11,24 @@
template <typename DataType> template <typename DataType>
auto get_elimit() auto get_elimit()
{ {
double rtol = 1e-2; double rtol = 2e-2;
double atol = 1e-2; double atol = 2e-2;
return ck_tile::make_tuple(rtol, atol); return ck_tile::make_tuple(rtol, atol);
} }
template <> template <>
auto get_elimit<ck_tile::bf16_t>() auto get_elimit<ck_tile::bf16_t>()
{ {
double rtol = 1e-2; double rtol = 1e-1;
double atol = 1e-2; double atol = 1e-1;
return ck_tile::make_tuple(rtol, atol); return ck_tile::make_tuple(rtol, atol);
} }
// Assigns `val` to each of the first `len` elements starting at `x`.
// A zero or negative `len` leaves the buffer untouched.
template <typename T>
void fill(T* x, int len, T val)
{
    T* cursor     = x;
    int remaining = len;
    while(remaining > 0)
    {
        *cursor = val;
        ++cursor;
        --remaining;
    }
}
// mfma_type, 0:32x32, 1:16x16 // mfma_type, 0:32x32, 1:16x16
// TODO: padding? // TODO: padding?
template <typename T> template <typename T>
...@@ -133,9 +138,9 @@ auto create_args(int argc, char* argv[]) ...@@ -133,9 +138,9 @@ auto create_args(int argc, char* argv[])
.insert("tp", "8", "tensor parallel size") .insert("tp", "8", "tensor parallel size")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec_i", "bf16", "input precision") .insert("prec_i", "fp16", "input precision")
.insert("prec_w", "bf16", "weight precision") .insert("prec_w", "fp16", "weight precision")
.insert("prec_o", "bf16", "output precision") .insert("prec_o", "fp16", "output precision")
.insert("prec_st", "auto", "token scale data type. auto will set to fp32") .insert("prec_st", "auto", "token scale data type. auto will set to fp32")
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32") .insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32") .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
...@@ -304,14 +309,46 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -304,14 +309,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
else if(init == 3) else if(init == 3)
{ {
ck_tile::FillConstant<ADataType>{}(a_host); // ck_tile::FillConstant<ADataType>{}(a_host);
ck_tile::FillConstant<GDataType>{}(g_host); // ck_tile::FillStepRange<ADataType>{0.f, 16384.f, 1.f}(a_host);
ck_tile::FillConstant<DDataType>{}(d_host); // for (int i = 0 ; i < tokens; i++){
ck_tile::FillConstant<AScaleDataType>{}(sa_host); // for (int j = 0; j < hidden_size; j++) {
ck_tile::FillConstant<GScaleDataType>{}(sg_host); // a_host.mData[i * hidden_size + j] = ck_tile::type_convert<ADataType>(float(i+1) * 0.1 + float(i * j % 116) * 0.0012);
ck_tile::FillConstant<DScaleDataType>{}(sd_host); // }
ck_tile::FillConstant<YSmoothScaleDataType>{}(sy_host); // }
ck_tile::FillConstant<TopkWeightDataType>{}(topk_weight_host); ck_tile::FillUniformDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
ck_tile::FillUniformDistribution<GDataType>{0.f, 1.f, seed, true}(g_host);
ck_tile::FillUniformDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed, true}(
topk_weight_host);
// a_host.savetxt("a.txt");
// fill((ADataType *)a_host.mData.data(), a_host.size(), ck_tile::type_convert<ADataType>(0.1f));
// fill((GDataType *)g_host.mData.data(), g_host.size(), ck_tile::type_convert<GDataType>(0.1f));
// fill((DDataType *)d_host.mData.data(), d_host.size(), ck_tile::type_convert<DDataType>(0.1f));
// fill((AScaleDataType *)sa_host.mData.data(), sa_host.size(), ck_tile::type_convert<AScaleDataType>(1.f));
// fill((GScaleDataType *)sg_host.mData.data(), sg_host.size(), ck_tile::type_convert<GScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((YSmoothScaleDataType *)sy_host.mData.data(), sy_host.size(), ck_tile::type_convert<YSmoothScaleDataType>(1.f));
// fill((TopkWeightDataType *)topk_weight_host.mData.data(), topk_weight_host.size(), ck_tile::type_convert<TopkWeightDataType>(1.f));
// ck_tile::FillNormalDistribution<ADataType>{.1f, .1f, seed, true}(a_host);
// ck_tile::FillNormalDistribution<GDataType>{.1f, .1f, seed, true}(g_host);
// ck_tile::FillNormalDistribution<DDataType>{.1f, .1f, seed, true}(d_host);
// ck_tile::FillNormalDistribution<AScaleDataType>{1.f, 1.f, seed, true}(sa_host);
// ck_tile::FillNormalDistribution<GScaleDataType>{1.f, 1.f, seed, true}(sg_host);
// ck_tile::FillNormalDistribution<DScaleDataType>{1.f, 1.f, seed, true}(sd_host);
// ck_tile::FillNormalDistribution<YSmoothScaleDataType>{1.f, 1.f, seed, true}(sy_host);
// ck_tile::FillNormalDistribution<TopkWeightDataType>{1.f, 1.f, seed, true}(topk_weight_host);
// ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
// ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
// ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
// ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
} }
// permute weight // permute weight
...@@ -498,6 +535,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -498,6 +535,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto o_dev = o_buf.ToHost<ODataType>(); auto o_dev = o_buf.ToHost<ODataType>();
o_dev.savetxt("gpu-out.txt", "float"); o_dev.savetxt("gpu-out.txt", "float");
o_host.savetxt("ref.txt", "float");
auto [rtol, atol] = get_elimit<ADataType>(); auto [rtol, atol] = get_elimit<ADataType>();
pass &= ck_tile::check_err( pass &= ck_tile::check_err(
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
...@@ -583,7 +621,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -583,7 +621,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>( ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Silu>(
a_host, a_host,
g_host, g_host,
d_host, d_host,
......
...@@ -339,7 +339,7 @@ struct FillStepRange ...@@ -339,7 +339,7 @@ struct FillStepRange
template <typename T> template <typename T>
struct FillConstant struct FillConstant
{ {
T value_{type_convert<T>(1.0f)}; T value_{type_convert<T>(1.f)};
template <typename ForwardIter> template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const void operator()(ForwardIter first, ForwardIter last) const
......
...@@ -157,7 +157,7 @@ void reference_fused_moe( ...@@ -157,7 +157,7 @@ void reference_fused_moe(
{ {
AccDataType tmp; AccDataType tmp;
Activation{}(tmp, acc_0(0, i_n)); Activation{}(tmp, acc_0(0, i_n));
y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul y(0, i_n) = tmp + acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
} }
} }
......
...@@ -201,7 +201,6 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_ ...@@ -201,7 +201,6 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
// [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), // [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
// [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), // [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
// [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), // [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes), [s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes), [s_tile_os_b]"s"(tile_stride_b_bytes),
[scale_0]"v"(s0), [scale_0]"v"(s0),
...@@ -217,7 +216,7 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_ ...@@ -217,7 +216,7 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
[s_execflag_6]"s"(o_flags[number<6>{}]), [s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}]) [s_execflag_7]"s"(o_flags[number<7>{}])
: :
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "memory", "exec","m0","vcc", "scc", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
...@@ -275,14 +274,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_ ...@@ -275,14 +274,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
); );
#pragma clang diagnostic pop #pragma clang diagnostic pop
// clang-format on // clang-format on
if(1) { // if(threadIdx.x==0) {
printf("\n%d %.1f, %.1f, %.1f, %.1f, %.1f, %.1f, %.1f, %.1f\n", // printf("%d\n", threadIdx.x);
threadIdx.x, // }
type_convert<float>(v_debug.x), type_convert<float>(v_debug.y), if(threadIdx.x == 0) {
type_convert<float>(v_debug1.x), type_convert<float>(v_debug1.y), printf("%d \n", threadIdx.x);
type_convert<float>(v_debug2.x), type_convert<float>(v_debug2.y),
type_convert<float>(v_debug3.x), type_convert<float>(v_debug3.y));
} }
// }
// __syncthreads();
} }
}; };
...@@ -356,6 +355,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_ ...@@ -356,6 +355,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_
register float v_c29 asm("v93"); register float v_c29 asm("v93");
register float v_c30 asm("v94"); register float v_c30 asm("v94");
register float v_c31 asm("v95"); register float v_c31 asm("v95");
register fp16x2_t v_debug asm("v160");
register fp16x2_t v_debug1 asm("v161");
register fp16x2_t v_debug2 asm("v162");
register fp16x2_t v_debug3 asm("v163");
register fp16x2_t v_debug4 asm("v164");
register fp16x2_t v_debug5 asm("v165");
register fp16x2_t v_debug6 asm("v166");
register fp16x2_t v_debug7 asm("v167");
int32_t nan_hi = 0x7fff0000; int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff; int32_t nan_lo = 0x00007fff;
...@@ -424,7 +431,15 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_ ...@@ -424,7 +431,15 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_
[c28]"+v"(v_c28), [c28]"+v"(v_c28),
[c29]"+v"(v_c29), [c29]"+v"(v_c29),
[c30]"+v"(v_c30), [c30]"+v"(v_c30),
[c31]"+v"(v_c31) [c31]"+v"(v_c31),
[debug0]"+v"(v_debug),
[debug1]"+v"(v_debug1),
[debug2]"+v"(v_debug2),
[debug3]"+v"(v_debug3),
[debug4]"+v"(v_debug4),
[debug5]"+v"(v_debug5),
[debug6]"+v"(v_debug6),
[debug7]"+v"(v_debug7)
: :
[sld_a_base]"n"(0), [sld_a_base]"n"(0),
[shfl_base]"n"(0), [shfl_base]"n"(0),
...@@ -471,7 +486,7 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_ ...@@ -471,7 +486,7 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_
[s_execflag_6]"s"(o_flags[number<6>{}]), [s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}]) [s_execflag_7]"s"(o_flags[number<7>{}])
: :
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "memory", "exec","m0","vcc", "scc", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
...@@ -529,6 +544,9 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_ ...@@ -529,6 +544,9 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_
); );
#pragma clang diagnostic pop #pragma clang diagnostic pop
// clang-format on // clang-format on
if(threadIdx.x == 0) {
printf("%d \n", threadIdx.x);
}
} }
}; };
......
...@@ -70,10 +70,10 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -70,10 +70,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize(); // constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize(); // constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge = // constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0; // BlockShape::Block_M0 * BlockShape::Block_N0;
return 32768;//max(smem_0, max(smem_1, smem_bridge)); return 32768;//max(smem_0, max(smem_1, smem_bridge));
} }
...@@ -329,7 +329,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -329,7 +329,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
BlockShape::Block_K0, // tile offset for B matrix each unroll BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll BlockShape::Block_W0); // tile offset for B matrix each unroll
// for(auto i = 0; i < 8; i++) // for(auto i = 0; i < 16; i++)
// { // {
// if(threadIdx.x==0) { // if(threadIdx.x==0) {
// printf("%d, %.1f, %.1f, %.1f, %.1f\n",i, acc_0_full.get_thread_buffer()[4 * (i) + 0], acc_0_full.get_thread_buffer()[4 * (i) + 1], acc_0_full.get_thread_buffer()[4 * (i) + 2], acc_0_full.get_thread_buffer()[4 * (i) + 3]); // printf("%d, %.1f, %.1f, %.1f, %.1f\n",i, acc_0_full.get_thread_buffer()[4 * (i) + 0], acc_0_full.get_thread_buffer()[4 * (i) + 1], acc_0_full.get_thread_buffer()[4 * (i) + 2], acc_0_full.get_thread_buffer()[4 * (i) + 3]);
...@@ -366,12 +366,13 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -366,12 +366,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
} }
if (!IsGateOnly) { if (!IsGateOnly) {
for(auto i = 0; i < BlockShape::Repeat_N0 * BlockShape::Repeat_M0; i++) constexpr auto REPEATS = BlockShape::Repeat_N0 * BlockShape::Repeat_M0;
for(auto i = 0; i < REPEATS; i++)
{ {
acc_0.get_thread_buffer()[4 * i + 0] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 0]; acc_0.get_thread_buffer()[4 * i + 0] += acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 0];
acc_0.get_thread_buffer()[4 * i + 1] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 1]; acc_0.get_thread_buffer()[4 * i + 1] += acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 1];
acc_0.get_thread_buffer()[4 * i + 2] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 2]; acc_0.get_thread_buffer()[4 * i + 2] += acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 2];
acc_0.get_thread_buffer()[4 * i + 3] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 3]; acc_0.get_thread_buffer()[4 * i + 3] += acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 3];
} }
} }
block_sync_lds(); block_sync_lds();
......
...@@ -17,7 +17,7 @@ fi ...@@ -17,7 +17,7 @@ fi
cmake \ cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm/ \ -D CMAKE_PREFIX_PATH=/opt/rocm/ \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -g -v --save-temps -Wno-gnu-line-marker " \ -D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker " \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \ -D GPU_TARGETS=$GPU_TARGETS \
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment