Commit 7f4d6f08 authored by letaoqin
Browse files

add MakeLdsBlockDesc_O

parent ce97a2af
......@@ -81,8 +81,7 @@ struct FusedMoeGemmPipeline_General
// shuffle C matrix
constexpr index_t smem_bridge = Policy::template GetSmemSize_Bridge<Problem>();
constexpr index_t smem_mat_o =
BlockShape::Block_N1 * BlockShape::Block_K1 * sizeof(float);
constexpr index_t smem_mat_o = BlockShape::Block_N1 * BlockShape::Block_K1 * sizeof(float);
return max(smem_mat_a + smem_mat_d, smem_bridge, smem_mat_o);
// return Policy::template GetSmemSize<Problem>();
......@@ -304,18 +303,15 @@ struct FusedMoeGemmPipeline_General
#endif
// add to LDS
CK_TILE_LDS_ADDR float* smem_o = reinterpret_cast<CK_TILE_LDS_ADDR float*>(smem);
auto o_lds_view =
make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::set>(
smem_o,
make_tuple(number<128>{}, number<32>{}),
make_tuple(32, 1),
number<8>{},
number<1>{});
auto o_alds_win =
make_tile_window(o_lds_view, make_tuple(number<128>{}, number<32>{}), {0, 0});
auto o_olds_win =
make_tile_window(o_lds_view,
make_tuple(number<32>{}, number<32>{}),
auto o_lds_view = make_tensor_view<address_space_enum::lds>(
smem_o, Policy::template MakeLdsBlockDesc_O<Problem>());
auto o_alds_win = make_tile_window(
o_lds_view,
make_tuple(number<BlockShape::Block_K1>{}, number<BlockShape::Block_N1>{}),
{0, 0});
auto o_olds_win = make_tile_window(
o_lds_view,
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{0, 0},
Policy::template MakeGlobalTileDistribution_O<Problem>());
......@@ -338,7 +334,8 @@ struct FusedMoeGemmPipeline_General
auto o = cast_tile<ODataType>(o0);
update_tile(o_window_, o);
// restore pos
move_tile_window(o_olds_win, {-32 * (BlockShape::Repeat_K1 - 1), 0});
move_tile_window(o_olds_win,
{-BlockShape::Block_M1 * (BlockShape::Repeat_K1 - 1), 0});
}
}
};
......
......@@ -134,7 +134,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
static_assert(M_rep <= 2);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
......@@ -152,6 +152,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
static_assert(M_rep <= 2);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
......@@ -354,6 +355,20 @@ struct FusedMoeGemmPipelineGeneralPolicy
return d_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_O()
{
constexpr index_t Block_N1 = Problem::BlockShape::Block_N1;
constexpr index_t Block_K1 = Problem::BlockShape::Block_N1;
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_K1>{}, number<Block_N1>{}),
make_tuple(number<Block_N1>{}, number<1>{}),
number<4>{},
number<1>{});
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsBlockDesc()
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment