Commit eab497e8 authored by letaoqin's avatar letaoqin
Browse files

format

parent 1476d7bb
...@@ -208,7 +208,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -208,7 +208,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1}); ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
#if 0 #if 0
# if 1 #if 1
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host); ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host); ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host); ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
...@@ -217,7 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -217,7 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host); ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host); ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host); ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
# else #else
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host); ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host); ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host);
...@@ -226,7 +226,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -226,7 +226,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host); ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host); ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host); ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
# endif #endif
// permute weight // permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
...@@ -266,7 +266,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -266,7 +266,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
# if 0 #if 0
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>( ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host, topk_ids_host,
topk_weight_host, topk_weight_host,
...@@ -319,7 +319,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -319,7 +319,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
return 1; return 1;
# endif #endif
#endif #endif
(void)balance; (void)balance;
......
...@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{ {
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
// clang-format on // clang-format on
......
...@@ -34,11 +34,14 @@ struct fmoe_ // traits, ugly name, only used for internal ...@@ -34,11 +34,14 @@ struct fmoe_ // traits, ugly name, only used for internal
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>; using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>; using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token(block_m0, block_m1) static constexpr ck_tile::index_t BT_ =
BlockTIle_::at(ck_tile::number<0>{}); // block token(block_m0, block_m1)
static constexpr ck_tile::index_t BI_ = static constexpr ck_tile::index_t BI_ =
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate (block_n0, block_k1) BlockTIle_::at(ck_tile::number<1>{}); // block intermediate (block_n0, block_k1)
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden(block_k0) static constexpr ck_tile::index_t BH_ =
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down(block_n1) BlockTIle_::at(ck_tile::number<2>{}); // block hidden(block_k0)
static constexpr ck_tile::index_t BD_ =
BlockTIle_::at(ck_tile::number<3>{}); // block down(block_n1)
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>; using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>; using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
// clang-format off // clang-format off
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -216,7 +216,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -216,7 +216,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host); ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f}(topk_weight_host); ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f}(topk_weight_host);
// permute weight // permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment