"...composable_kernel_rocm.git" did not exist on "501a6b68eb647c76b0de00d03b93bc90f1e93111"
Commit 4be253ee authored by coderfeli's avatar coderfeli
Browse files

Revert to elementwise multiplication and SiLU activation (replacing the temporary add and GELU)

parent e15c6f2d
...@@ -23,12 +23,21 @@ auto get_elimit<ck_tile::bf16_t>() ...@@ -23,12 +23,21 @@ auto get_elimit<ck_tile::bf16_t>()
double atol = 1e-1; double atol = 1e-1;
return ck_tile::make_tuple(rtol, atol); return ck_tile::make_tuple(rtol, atol);
} }
template<typename T> // template<typename T>
void fill(T * x, int len, T val) { // void cleartail(T * x, int len) {
for(int i = 0; i <len; i++){ // int len_32b = len * sizeof(T) / 4;
x[i] = val; // uint32_t *x_u32 = reinterpret_cast<uint32_t *>(x);
} // for(int i = 0; i <len_32b; i++){
} // x_u32[i] = x_u32[i] & 0xfff0fff0;
// }
// }
// template<typename T>
// void fill(T * x, int len, T val) {
// for(int i = 0; i <len; i++){
// x[i] = val;
// }
// }
// mfma_type, 0:32x32, 1:16x16 // mfma_type, 0:32x32, 1:16x16
// TODO: padding? // TODO: padding?
template <typename T> template <typename T>
...@@ -309,15 +318,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -309,15 +318,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
else if(init == 3) else if(init == 3)
{ {
// fill((ADataType *)a_host.mData.data(), a_host.size(), ck_tile::type_convert<ADataType>(0.1f));
// fill((GDataType *)g_host.mData.data(), g_host.size(), ck_tile::type_convert<GDataType>(0.1f));
// fill((DDataType *)d_host.mData.data(), d_host.size(), ck_tile::type_convert<DDataType>(0.1f));
// fill((AScaleDataType *)sa_host.mData.data(), sa_host.size(), ck_tile::type_convert<AScaleDataType>(1.f));
// fill((GScaleDataType *)sg_host.mData.data(), sg_host.size(), ck_tile::type_convert<GScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((YSmoothScaleDataType *)sy_host.mData.data(), sy_host.size(), ck_tile::type_convert<YSmoothScaleDataType>(1.f));
// fill((TopkWeightDataType *)topk_weight_host.mData.data(), topk_weight_host.size(), ck_tile::type_convert<TopkWeightDataType>(1.f));
ck_tile::FillNormalDistribution<ADataType>{0.f, .1f, seed, true}(a_host); ck_tile::FillNormalDistribution<ADataType>{0.f, .1f, seed, true}(a_host);
ck_tile::FillNormalDistribution<GDataType>{0.f, .1f, seed, true}(g_host); ck_tile::FillNormalDistribution<GDataType>{0.f, .1f, seed, true}(g_host);
ck_tile::FillNormalDistribution<DDataType>{0.f, .1f, seed, true}(d_host); ck_tile::FillNormalDistribution<DDataType>{0.f, .1f, seed, true}(d_host);
...@@ -326,6 +326,30 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -326,6 +326,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host); ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host); ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host); ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
// cleartail((ADataType *)a_host.mData.data(), a_host.size());
// cleartail((GDataType *)g_host.mData.data(), g_host.size());
// cleartail((DDataType *)d_host.mData.data(), d_host.size());
// a_host.savetxt("a.txt");
// cleartail((AScaleDataType *)sa_host.mData.data(), sa_host.size());
// cleartail((GScaleDataType *)sg_host.mData.data(), sg_host.size());
// cleartail((DScaleDataType *)sd_host.mData.data(), sd_host.size());
// cleartail((DScaleDataType *)sd_host.mData.data(), sd_host.size());
// cleartail((YSmoothScaleDataType *)sy_host.mData.data(), sy_host.size());
// fill((ADataType *)a_host.mData.data(), a_host.size(), ck_tile::type_convert<ADataType>(.1f));
// fill((GDataType *)g_host.mData.data(), g_host.size(), ck_tile::type_convert<GDataType>(.1f));
// fill((DDataType *)d_host.mData.data(), d_host.size(), ck_tile::type_convert<DDataType>(.1f));
// fill((AScaleDataType *)sa_host.mData.data(), sa_host.size(), ck_tile::type_convert<AScaleDataType>(1.f));
// fill((GScaleDataType *)sg_host.mData.data(), sg_host.size(), ck_tile::type_convert<GScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((YSmoothScaleDataType *)sy_host.mData.data(), sy_host.size(), ck_tile::type_convert<YSmoothScaleDataType>(1.f));
// fill((TopkWeightDataType *)topk_weight_host.mData.data(), topk_weight_host.size(), ck_tile::type_convert<TopkWeightDataType>(1.f));
// cleartail((TopkWeightDataType *)topk_weight_host.mData.data(), topk_weight_host.size());
} }
// permute weight // permute weight
...@@ -484,7 +508,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -484,7 +508,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
experts, experts,
block_m); block_m);
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>( ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Silu>(
a_host, a_host,
g_host, g_host,
d_host, d_host,
......
...@@ -157,7 +157,7 @@ void reference_fused_moe( ...@@ -157,7 +157,7 @@ void reference_fused_moe(
{ {
AccDataType tmp; AccDataType tmp;
Activation{}(tmp, acc_0(0, i_n)); Activation{}(tmp, acc_0(0, i_n));
y(0, i_n) = tmp + acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
} }
} }
......
...@@ -380,10 +380,10 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -380,10 +380,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr auto REPEATS = BlockShape::Repeat_N0 * BlockShape::Repeat_M0; constexpr auto REPEATS = BlockShape::Repeat_N0 * BlockShape::Repeat_M0;
for(auto i = 0; i < REPEATS; i++) for(auto i = 0; i < REPEATS; i++)
{ {
acc_0.get_thread_buffer()[4 * i + 0] += acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 0]; acc_0.get_thread_buffer()[4 * i + 0] *= acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 0];
acc_0.get_thread_buffer()[4 * i + 1] += acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 1]; acc_0.get_thread_buffer()[4 * i + 1] *= acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 1];
acc_0.get_thread_buffer()[4 * i + 2] += acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 2]; acc_0.get_thread_buffer()[4 * i + 2] *= acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 2];
acc_0.get_thread_buffer()[4 * i + 3] += acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 3]; acc_0.get_thread_buffer()[4 * i + 3] *= acc_0_full.get_thread_buffer()[4 * (i + REPEATS) + 3];
} }
} }
block_sync_lds(); block_sync_lds();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment