Commit 727f201d authored by letaoqin's avatar letaoqin
Browse files

change save o to lds data type to float

parent 28252273
...@@ -256,14 +256,14 @@ struct FusedMoeGemmPipeline_General ...@@ -256,14 +256,14 @@ struct FusedMoeGemmPipeline_General
constexpr auto w_dstr = constexpr auto w_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
s_acc.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{})); s_acc.get_tile_distribution().get_static_tile_distribution_encoding(),
auto w_global_to_dram_window = make_tile_window( sequence<1>{}));
w_window_.get_bottom_tensor_view(), auto w_global_to_dram_window = make_tile_window(w_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}), make_tuple(number<BlockShape::Block_M0>{}),
w_window_.get_window_origin(), w_window_.get_window_origin(),
w_dstr); w_dstr);
auto w = load_tile(w_global_to_dram_window); auto w = load_tile(w_global_to_dram_window);
float weight = type_convert<float>(w.get_thread_buffer()[0]); float weight = type_convert<float>(w.get_thread_buffer()[0]);
#if 0 #if 0
constexpr index_t w_buffer_size = decltype(w)::get_thread_buffer_size(); constexpr index_t w_buffer_size = decltype(w)::get_thread_buffer_size();
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
...@@ -307,9 +307,10 @@ struct FusedMoeGemmPipeline_General ...@@ -307,9 +307,10 @@ struct FusedMoeGemmPipeline_General
PrintMem(d,"D",0); PrintMem(d,"D",0);
#endif #endif
// add to LDS // add to LDS
CK_TILE_LDS_ADDR float* smem_3 = reinterpret_cast<CK_TILE_LDS_ADDR float*>(smem);
auto o_lds_view = auto o_lds_view =
make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::atomic_add>( make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::set>(
smem_0, smem_3,
make_tuple(number<128>{}, number<32>{}), make_tuple(number<128>{}, number<32>{}),
make_tuple(32, 1), make_tuple(32, 1),
number<8>{}, number<8>{},
...@@ -333,12 +334,16 @@ struct FusedMoeGemmPipeline_General ...@@ -333,12 +334,16 @@ struct FusedMoeGemmPipeline_General
move_tile_window(o_olds_win, {32, 0}); move_tile_window(o_olds_win, {32, 0});
auto o1 = load_tile(o_olds_win); auto o1 = load_tile(o_olds_win);
static_for<0, thread_buffer_size, 1>{}([&](auto i) { static_for<0, thread_buffer_size, 1>{}([&](auto i) {
o0.get_thread_buffer()(i) = type_convert<ODataType>( o0.get_thread_buffer()(i) =
type_convert<float>(o0.get_thread_buffer()[i]) + type_convert<float>(type_convert<float>(o0.get_thread_buffer()[i]) +
type_convert<float>(o1.get_thread_buffer()[i])); type_convert<float>(o1.get_thread_buffer()[i]));
}); });
}); });
update_tile(o_window_, o0); // tile_elementwise_inout([&weight](auto& x) { x = x *
// type_convert<float>(weight); },
// o0);
auto o = cast_tile<ODataType>(o0);
update_tile(o_window_, o);
// restore pos // restore pos
move_tile_window(o_olds_win, {-32 * (BlockShape::Repeat_K1 - 1), 0}); move_tile_window(o_olds_win, {-32 * (BlockShape::Repeat_K1 - 1), 0});
} }
...@@ -359,8 +364,8 @@ struct FusedMoeGemmPipeline_General ...@@ -359,8 +364,8 @@ struct FusedMoeGemmPipeline_General
// move out window and save data // move out window and save data
tile_elementwise_inout([&weight](auto& x) { x = x * type_convert<float>(weight); }, tile_elementwise_inout([&weight](auto& x) { x = x * type_convert<float>(weight); },
o_acc); o_acc);
auto o = cast_tile<ODataType>(o_acc); // auto o = cast_tile<ODataType>(o_acc);
store_tile(o_alds_win, o); store_tile(o_alds_win, o_acc);
block_sync_lds(); block_sync_lds();
save_o(); save_o();
...@@ -375,10 +380,10 @@ struct FusedMoeGemmPipeline_General ...@@ -375,10 +380,10 @@ struct FusedMoeGemmPipeline_General
gemm_1(o_acc, y, d); gemm_1(o_acc, y, d);
// block_sync_lds(); // block_sync_lds();
tile_elementwise_inout( tile_elementwise_inout([&weight](auto& x) { x = x * type_convert<float>(weight); },
[&weight](auto& x) { x = x * type_convert<float>(weight); }, o_acc); o_acc);
auto o = cast_tile<ODataType>(o_acc); // auto o = cast_tile<ODataType>(o_acc);
store_tile(o_alds_win, o); store_tile(o_alds_win, o_acc);
block_sync_lds(); block_sync_lds();
save_o(); save_o();
// store_tile(o_window_, o); // store_tile(o_window_, o);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment