Commit 15e76415 authored by letaoqin's avatar letaoqin
Browse files

add padding to O

parent b885995c
...@@ -362,9 +362,14 @@ struct FusedMoeGemmGlKernel ...@@ -362,9 +362,14 @@ struct FusedMoeGemmGlKernel
make_tuple(sequence<0>{}, sequence<1>{}), make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
auto o_window_ = make_tile_window( auto o_padd_view_ = pad_tensor_view(
o_scatter_view_, o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}), make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
sequence<true, 0>{});
auto o_window_ = make_tile_window(
o_padd_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{idx_m0, 0}); {idx_m0, 0});
return o_window_; return o_window_;
}(); }();
......
...@@ -308,6 +308,27 @@ struct FusedMoeGemmPipeline_General ...@@ -308,6 +308,27 @@ struct FusedMoeGemmPipeline_General
Policy::template MakeGlobalTileDistribution_O<Problem>()); Policy::template MakeGlobalTileDistribution_O<Problem>());
ignore = o_alds_win; ignore = o_alds_win;
auto save_o = [&]() {
if(blockIdx.x == 0 && (blockIdx.y == 0 || blockIdx.y == 1) && blockIdx.z == 0)
{
if(threadIdx.x < 64)
{
auto o0 = load_tile(o_olds_win);
for(int step = 1; step < 4; step++)
{
move_tile_window(o_olds_win, {32, 0});
auto o1 = load_tile(o_olds_win);
for(int i = 0; i < 16; i++)
{
o0.get_thread_buffer()(i) = type_convert<ODataType>(
type_convert<float>(o0.get_thread_buffer()[i]) +
type_convert<float>(o1.get_thread_buffer()[i]));
}
}
update_tile(o_window_, o0);
}
}
};
constexpr index_t kN1 = BlockShape::Block_N1; constexpr index_t kN1 = BlockShape::Block_N1;
const index_t n1_loops = ck_tile::integer_divide_ceil(hidden_size, kN1); const index_t n1_loops = ck_tile::integer_divide_ceil(hidden_size, kN1);
index_t iCounter1 = n1_loops - 1; index_t iCounter1 = n1_loops - 1;
...@@ -336,30 +357,9 @@ struct FusedMoeGemmPipeline_General ...@@ -336,30 +357,9 @@ struct FusedMoeGemmPipeline_General
// tile_elementwise_inout( // tile_elementwise_inout(
// [&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc); // [&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc);
auto o = cast_tile<ODataType>(o_acc); auto o = cast_tile<ODataType>(o_acc);
#if 0
PrintMem(o, "O", 65);
#endif
store_tile(o_alds_win, o); store_tile(o_alds_win, o);
block_sync_lds(); block_sync_lds();
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) save_o();
{
if(threadIdx.x < 64)
{
auto o0 = load_tile(o_olds_win);
for(int step = 1; step < 4; step++)
{
move_tile_window(o_olds_win, {32, 0});
auto o1 = load_tile(o_olds_win);
for(int i = 0; i < 16; i++)
{
o0.get_thread_buffer()(i) = type_convert<ODataType>(
type_convert<float>(o0.get_thread_buffer()[i]) +
type_convert<float>(o1.get_thread_buffer()[i]));
}
}
update_tile(o_window_, o0);
}
}
// store_tile(o_window_, o); // store_tile(o_window_, o);
#if 0 #if 0
PrintMem(o,"O"); PrintMem(o,"O");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment