Commit 727f201d authored by letaoqin's avatar letaoqin
Browse files

change save o to lds data type to float

parent 28252273
......@@ -256,14 +256,14 @@ struct FusedMoeGemmPipeline_General
constexpr auto w_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
s_acc.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));
auto w_global_to_dram_window = make_tile_window(
w_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}),
w_window_.get_window_origin(),
w_dstr);
auto w = load_tile(w_global_to_dram_window);
float weight = type_convert<float>(w.get_thread_buffer()[0]);
s_acc.get_tile_distribution().get_static_tile_distribution_encoding(),
sequence<1>{}));
auto w_global_to_dram_window = make_tile_window(w_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}),
w_window_.get_window_origin(),
w_dstr);
auto w = load_tile(w_global_to_dram_window);
float weight = type_convert<float>(w.get_thread_buffer()[0]);
#if 0
constexpr index_t w_buffer_size = decltype(w)::get_thread_buffer_size();
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
......@@ -307,9 +307,10 @@ struct FusedMoeGemmPipeline_General
PrintMem(d,"D",0);
#endif
// add to LDS
CK_TILE_LDS_ADDR float* smem_3 = reinterpret_cast<CK_TILE_LDS_ADDR float*>(smem);
auto o_lds_view =
make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::atomic_add>(
smem_0,
make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::set>(
smem_3,
make_tuple(number<128>{}, number<32>{}),
make_tuple(32, 1),
number<8>{},
......@@ -333,12 +334,16 @@ struct FusedMoeGemmPipeline_General
move_tile_window(o_olds_win, {32, 0});
auto o1 = load_tile(o_olds_win);
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
o0.get_thread_buffer()(i) = type_convert<ODataType>(
type_convert<float>(o0.get_thread_buffer()[i]) +
type_convert<float>(o1.get_thread_buffer()[i]));
o0.get_thread_buffer()(i) =
type_convert<float>(type_convert<float>(o0.get_thread_buffer()[i]) +
type_convert<float>(o1.get_thread_buffer()[i]));
});
});
update_tile(o_window_, o0);
// tile_elementwise_inout([&weight](auto& x) { x = x *
// type_convert<float>(weight); },
// o0);
auto o = cast_tile<ODataType>(o0);
update_tile(o_window_, o);
// restore pos
move_tile_window(o_olds_win, {-32 * (BlockShape::Repeat_K1 - 1), 0});
}
......@@ -359,8 +364,8 @@ struct FusedMoeGemmPipeline_General
// move out window and save data
tile_elementwise_inout([&weight](auto& x) { x = x * type_convert<float>(weight); },
o_acc);
auto o = cast_tile<ODataType>(o_acc);
store_tile(o_alds_win, o);
// auto o = cast_tile<ODataType>(o_acc);
store_tile(o_alds_win, o_acc);
block_sync_lds();
save_o();
......@@ -375,10 +380,10 @@ struct FusedMoeGemmPipeline_General
gemm_1(o_acc, y, d);
// block_sync_lds();
tile_elementwise_inout(
[&weight](auto& x) { x = x * type_convert<float>(weight); }, o_acc);
auto o = cast_tile<ODataType>(o_acc);
store_tile(o_alds_win, o);
tile_elementwise_inout([&weight](auto& x) { x = x * type_convert<float>(weight); },
o_acc);
// auto o = cast_tile<ODataType>(o_acc);
store_tile(o_alds_win, o_acc);
block_sync_lds();
save_o();
// store_tile(o_window_, o);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment