Commit 7f4d6f08 authored by letaoqin's avatar letaoqin
Browse files

add MakeLdsBlockDesc_O

parent ce97a2af
...@@ -81,8 +81,7 @@ struct FusedMoeGemmPipeline_General ...@@ -81,8 +81,7 @@ struct FusedMoeGemmPipeline_General
// shuffle C matrix // shuffle C matrix
constexpr index_t smem_bridge = Policy::template GetSmemSize_Bridge<Problem>(); constexpr index_t smem_bridge = Policy::template GetSmemSize_Bridge<Problem>();
constexpr index_t smem_mat_o = constexpr index_t smem_mat_o = BlockShape::Block_N1 * BlockShape::Block_K1 * sizeof(float);
BlockShape::Block_N1 * BlockShape::Block_K1 * sizeof(float);
return max(smem_mat_a + smem_mat_d, smem_bridge, smem_mat_o); return max(smem_mat_a + smem_mat_d, smem_bridge, smem_mat_o);
// return Policy::template GetSmemSize<Problem>(); // return Policy::template GetSmemSize<Problem>();
...@@ -237,7 +236,7 @@ struct FusedMoeGemmPipeline_General ...@@ -237,7 +236,7 @@ struct FusedMoeGemmPipeline_General
// } // }
// save to lds // save to lds
CK_TILE_LDS_ADDR ADataType* smem_y = reinterpret_cast<CK_TILE_LDS_ADDR YDataType*>(smem); CK_TILE_LDS_ADDR ADataType* smem_y = reinterpret_cast<CK_TILE_LDS_ADDR YDataType*>(smem);
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>( auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_y, Policy::template MakeBridgeLdsBlockDesc<Problem>()); smem_y, Policy::template MakeBridgeLdsBlockDesc<Problem>());
auto bridge_slds_win = auto bridge_slds_win =
make_tile_window(bridge_lds_view, make_tile_window(bridge_lds_view,
...@@ -304,20 +303,17 @@ struct FusedMoeGemmPipeline_General ...@@ -304,20 +303,17 @@ struct FusedMoeGemmPipeline_General
#endif #endif
// add to LDS // add to LDS
CK_TILE_LDS_ADDR float* smem_o = reinterpret_cast<CK_TILE_LDS_ADDR float*>(smem); CK_TILE_LDS_ADDR float* smem_o = reinterpret_cast<CK_TILE_LDS_ADDR float*>(smem);
auto o_lds_view = auto o_lds_view = make_tensor_view<address_space_enum::lds>(
make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::set>( smem_o, Policy::template MakeLdsBlockDesc_O<Problem>());
smem_o, auto o_alds_win = make_tile_window(
make_tuple(number<128>{}, number<32>{}), o_lds_view,
make_tuple(32, 1), make_tuple(number<BlockShape::Block_K1>{}, number<BlockShape::Block_N1>{}),
number<8>{}, {0, 0});
number<1>{}); auto o_olds_win = make_tile_window(
auto o_alds_win = o_lds_view,
make_tile_window(o_lds_view, make_tuple(number<128>{}, number<32>{}), {0, 0}); make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
auto o_olds_win = {0, 0},
make_tile_window(o_lds_view, Policy::template MakeGlobalTileDistribution_O<Problem>());
make_tuple(number<32>{}, number<32>{}),
{0, 0},
Policy::template MakeGlobalTileDistribution_O<Problem>());
auto save_o = [&]() { auto save_o = [&]() {
// if(blockIdx.x == 0 && (blockIdx.y == 0 || blockIdx.y == 1) && blockIdx.z == 0) // if(blockIdx.x == 0 && (blockIdx.y == 0 || blockIdx.y == 1) && blockIdx.z == 0)
...@@ -338,7 +334,8 @@ struct FusedMoeGemmPipeline_General ...@@ -338,7 +334,8 @@ struct FusedMoeGemmPipeline_General
auto o = cast_tile<ODataType>(o0); auto o = cast_tile<ODataType>(o0);
update_tile(o_window_, o); update_tile(o_window_, o);
// restore pos // restore pos
move_tile_window(o_olds_win, {-32 * (BlockShape::Repeat_K1 - 1), 0}); move_tile_window(o_olds_win,
{-BlockShape::Block_M1 * (BlockShape::Repeat_K1 - 1), 0});
} }
} }
}; };
......
...@@ -134,7 +134,7 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -134,7 +134,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr index_t M_wav = NumWarps / K_wav; constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check"); static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav; constexpr index_t M_rep = MPerBlock / M_wav;
static_assert(M_rep <= 2);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<1>, sequence<1>,
...@@ -152,6 +152,7 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -152,6 +152,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
static_assert(MPerBlock % (M_lan * M_wav) == 0, static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check"); "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav); constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
static_assert(M_rep <= 2);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<1>, sequence<1>,
...@@ -354,6 +355,20 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -354,6 +355,20 @@ struct FusedMoeGemmPipelineGeneralPolicy
return d_lds_block_desc; return d_lds_block_desc;
} }
/// Builds the LDS block descriptor used for the O tile.
///
/// The descriptor is a 2-D, row-major layout of Block_K1 x Block_N1 elements:
/// the row stride is Block_N1 and the innermost stride is 1. Both extents are
/// taken from Problem::BlockShape so the layout matches the
/// (Block_K1, Block_N1) tile window the pipeline opens on this view and the
/// Block_N1 * Block_K1 element buffer it reserves for it.
///
/// @tparam Problem  Pipeline problem type; must expose BlockShape::Block_N1 and
///                  BlockShape::Block_K1.
/// @return A compile-time naive tensor descriptor of shape (Block_K1, Block_N1).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_O()
{
    constexpr index_t Block_N1 = Problem::BlockShape::Block_N1;
    // Fixed: this previously read Block_N1, which silently produced a
    // Block_N1 x Block_N1 descriptor whenever Block_K1 != Block_N1.
    constexpr index_t Block_K1 = Problem::BlockShape::Block_K1;

    // NOTE(review): the trailing number<4>/number<1> arguments are presumably the
    // guaranteed last-dimension vector length and vector stride — confirm against
    // make_naive_tensor_descriptor.
    constexpr auto desc =
        make_naive_tensor_descriptor(make_tuple(number<Block_K1>{}, number<Block_N1>{}),
                                     make_tuple(number<Block_N1>{}, number<1>{}),
                                     number<4>{},
                                     number<1>{});
    return desc;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsBlockDesc() CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsBlockDesc()
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment