Commit 593dd7ad authored by letaoqin's avatar letaoqin
Browse files

Clean up some code

parent 6cb91035
...@@ -22,6 +22,12 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -22,6 +22,12 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>; using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
// if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
// t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
// {
// using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
// r = fused_moegemm_<t_>(s, a);
// }
// clang-format on // clang-format on
return r; return r;
} }
...@@ -8,7 +8,10 @@ ...@@ -8,7 +8,10 @@
// clang-format off // clang-format off
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0> fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -429,6 +429,14 @@ int main(int argc, char* argv[]) ...@@ -429,6 +429,14 @@ int main(int argc, char* argv[])
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
// no dynamic quant case // no dynamic quant case
// if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
// {
// return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
// arg_parser)
// ? 0
// : -2;
// }
// else
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32") if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
{ {
return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>( return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
......
...@@ -193,10 +193,6 @@ struct FusedMoeGemmPipeline_General ...@@ -193,10 +193,6 @@ struct FusedMoeGemmPipeline_General
} }
// gelu // gelu
const auto activation = ck_tile::element_wise::Gelu{}; const auto activation = ck_tile::element_wise::Gelu{};
// constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size();
// static_for<0, thread_buffer_size, 1>{}([&](auto i) {
// activation(s_acc.get_thread_buffer()(i),s_acc.get_thread_buffer()[i]);
// });
tile_elementwise_inout(activation, s_acc, s_acc); tile_elementwise_inout(activation, s_acc, s_acc);
#if 0 #if 0
PrintMem(s_acc); PrintMem(s_acc);
...@@ -210,18 +206,7 @@ struct FusedMoeGemmPipeline_General ...@@ -210,18 +206,7 @@ struct FusedMoeGemmPipeline_General
{0, 0}); {0, 0});
// cast data to YDataType // cast data to YDataType
auto y_pre = cast_tile<YDataType>(s_acc); auto y_pre = cast_tile<YDataType>(s_acc);
// constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size(); #if 0
// static_for<0, thread_buffer_size, 1>{}([&](auto i) {
// //y_pre.get_thread_buffer()(i) = type_convert<YDataType>(s_acc.get_thread_buffer()[i]);
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// printf("soure value: %f to value: %f\n",
// s_acc.get_thread_buffer()[i],
// type_convert<float>(y_pre.get_thread_buffer()[i]));
// }
// });
#if 1
PrintMem(y_pre); PrintMem(y_pre);
#endif #endif
// save to lds // save to lds
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment