Commit ce97a2af authored by letaoqin's avatar letaoqin
Browse files

Rewrite GetSmemSize

parent e1b457ec
...@@ -228,7 +228,7 @@ struct FusedMoeGemmGlKernel ...@@ -228,7 +228,7 @@ struct FusedMoeGemmGlKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
// allocate LDS // allocate LDS
// __shared__ char smem_ptr[GetSmemSize()]; __shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane( IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr)); *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2; constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
...@@ -236,8 +236,6 @@ struct FusedMoeGemmGlKernel ...@@ -236,8 +236,6 @@ struct FusedMoeGemmGlKernel
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size; index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size; index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index(block_m and // note this is in unit of tile, need multiple tile size to get the index(block_m and
// block_n) // block_n)
const auto [sorted_tile_id, intermediate_tile_id] = const auto [sorted_tile_id, intermediate_tile_id] =
......
...@@ -77,13 +77,14 @@ struct FusedMoeGemmPipeline_General ...@@ -77,13 +77,14 @@ struct FusedMoeGemmPipeline_General
{ {
// matrix a or tokens smem // matrix a or tokens smem
constexpr index_t smem_mat_a = GetSmemSizeA(); constexpr index_t smem_mat_a = GetSmemSizeA();
constexpr index_t smem_mat_d = constexpr index_t smem_mat_d = Policy::template GetSmemSize_G<Problem>();
BlockShape::Block_N0 * BlockShape::Block_K0 * sizeof(GDataType);
// shuffle C matrix // shuffle C matrix
constexpr index_t smem_bridge = constexpr index_t smem_bridge = Policy::template GetSmemSize_Bridge<Problem>();
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_mat_a + smem_mat_d, smem_bridge); constexpr index_t smem_mat_o =
BlockShape::Block_N1 * BlockShape::Block_K1 * sizeof(float);
return max(smem_mat_a + smem_mat_d, smem_bridge, smem_mat_o);
// return Policy::template GetSmemSize<Problem>(); // return Policy::template GetSmemSize<Problem>();
} }
...@@ -131,19 +132,19 @@ struct FusedMoeGemmPipeline_General ...@@ -131,19 +132,19 @@ struct FusedMoeGemmPipeline_General
index_t /*intermediate_size*/, index_t /*intermediate_size*/,
CWindow& /*c_window_*/) CWindow& /*c_window_*/)
{ {
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem); CK_TILE_LDS_ADDR ADataType* smem_a = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>( CK_TILE_LDS_ADDR GDataType* smem_g = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>(
smem_0 + GetSmemSizeA() / sizeof(ADataType)); smem_a + GetSmemSizeA() / sizeof(ADataType));
auto a_lds_view = make_tensor_view<address_space_enum::lds>( auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsBlockDesc_A<Problem>()); smem_a, Policy::template MakeLdsBlockDesc_A<Problem>());
auto a_lds_win = make_tile_window( auto a_lds_win = make_tile_window(
a_lds_view, a_lds_view,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}), make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0}); {0, 0});
auto g_lds_view = make_tensor_view<address_space_enum::lds>( auto g_lds_view = make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsBlockDesc_G<Problem>()); smem_g, Policy::template MakeLdsBlockDesc_G<Problem>());
auto g_lds_win = make_tile_window( auto g_lds_win = make_tile_window(
g_lds_view, g_lds_view,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}), make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
...@@ -235,8 +236,9 @@ struct FusedMoeGemmPipeline_General ...@@ -235,8 +236,9 @@ struct FusedMoeGemmPipeline_General
// store_tile(c_window_, y_pre); // store_tile(c_window_, y_pre);
// } // }
// save to lds // save to lds
CK_TILE_LDS_ADDR ADataType* smem_y = reinterpret_cast<CK_TILE_LDS_ADDR YDataType*>(smem);
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>( auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>()); smem_y, Policy::template MakeBridgeLdsBlockDesc<Problem>());
auto bridge_slds_win = auto bridge_slds_win =
make_tile_window(bridge_lds_view, make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(), Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
...@@ -285,7 +287,7 @@ struct FusedMoeGemmPipeline_General ...@@ -285,7 +287,7 @@ struct FusedMoeGemmPipeline_General
{ {
for(int i = 0; i < 16; i++) for(int i = 0; i < 16; i++)
{ {
printf("\n smem_0[%d]: %f ", i, type_convert<float>(smem_0[i])); printf("\n smem_a[%d]: %f ", i, type_convert<float>(smem_a[i]));
} }
} }
//store_tile(c_window_, y); //store_tile(c_window_, y);
...@@ -301,10 +303,10 @@ struct FusedMoeGemmPipeline_General ...@@ -301,10 +303,10 @@ struct FusedMoeGemmPipeline_General
PrintMem(d,"D",0); PrintMem(d,"D",0);
#endif #endif
// add to LDS // add to LDS
CK_TILE_LDS_ADDR float* smem_3 = reinterpret_cast<CK_TILE_LDS_ADDR float*>(smem); CK_TILE_LDS_ADDR float* smem_o = reinterpret_cast<CK_TILE_LDS_ADDR float*>(smem);
auto o_lds_view = auto o_lds_view =
make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::set>( make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::set>(
smem_3, smem_o,
make_tuple(number<128>{}, number<32>{}), make_tuple(number<128>{}, number<32>{}),
make_tuple(32, 1), make_tuple(32, 1),
number<8>{}, number<8>{},
......
...@@ -94,14 +94,21 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -94,14 +94,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{ {
constexpr auto a_lds_desc = MakeLdsBlockDesc_A<Problem>(); constexpr auto a_lds_desc = MakeLdsBlockDesc_A<Problem>();
return a_lds_desc.get_element_space_size(); return a_lds_desc.get_element_space_size() * sizeof(typename Problem::ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_G()
{
constexpr auto g_lds_desc = MakeLdsBlockDesc_G<Problem>();
return g_lds_desc.get_element_space_size() * sizeof(typename Problem::GDataType);
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
{ {
constexpr auto bridge_lds_desc = MakeBridgeLdsBlockDesc<Problem>(); constexpr auto bridge_lds_desc = MakeBridgeLdsBlockDesc<Problem>();
return bridge_lds_desc.get_element_space_size(); return bridge_lds_desc.get_element_space_size() * sizeof(typename Problem::YDataType);
} }
template <typename Problem> template <typename Problem>
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment