Commit ce97a2af authored by letaoqin's avatar letaoqin
Browse files

rewrite getsmemsize

parent e1b457ec
......@@ -228,7 +228,7 @@ struct FusedMoeGemmGlKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
__shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
......@@ -236,8 +236,6 @@ struct FusedMoeGemmGlKernel
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index(block_m and
// block_n)
const auto [sorted_tile_id, intermediate_tile_id] =
......
......@@ -77,13 +77,14 @@ struct FusedMoeGemmPipeline_General
{
// matrix a or tokens smem
constexpr index_t smem_mat_a = GetSmemSizeA();
constexpr index_t smem_mat_d =
BlockShape::Block_N0 * BlockShape::Block_K0 * sizeof(GDataType);
constexpr index_t smem_mat_d = Policy::template GetSmemSize_G<Problem>();
// shuffle C matrix
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
constexpr index_t smem_bridge = Policy::template GetSmemSize_Bridge<Problem>();
return max(smem_mat_a + smem_mat_d, smem_bridge);
constexpr index_t smem_mat_o =
BlockShape::Block_N1 * BlockShape::Block_K1 * sizeof(float);
return max(smem_mat_a + smem_mat_d, smem_bridge, smem_mat_o);
// return Policy::template GetSmemSize<Problem>();
}
......@@ -131,19 +132,19 @@ struct FusedMoeGemmPipeline_General
index_t /*intermediate_size*/,
CWindow& /*c_window_*/)
{
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>(
smem_0 + GetSmemSizeA() / sizeof(ADataType));
CK_TILE_LDS_ADDR ADataType* smem_a = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR GDataType* smem_g = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>(
smem_a + GetSmemSizeA() / sizeof(ADataType));
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsBlockDesc_A<Problem>());
smem_a, Policy::template MakeLdsBlockDesc_A<Problem>());
auto a_lds_win = make_tile_window(
a_lds_view,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
auto g_lds_view = make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsBlockDesc_G<Problem>());
smem_g, Policy::template MakeLdsBlockDesc_G<Problem>());
auto g_lds_win = make_tile_window(
g_lds_view,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
......@@ -235,8 +236,9 @@ struct FusedMoeGemmPipeline_General
// store_tile(c_window_, y_pre);
// }
// save to lds
CK_TILE_LDS_ADDR ADataType* smem_y = reinterpret_cast<CK_TILE_LDS_ADDR YDataType*>(smem);
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
smem_y, Policy::template MakeBridgeLdsBlockDesc<Problem>());
auto bridge_slds_win =
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
......@@ -285,7 +287,7 @@ struct FusedMoeGemmPipeline_General
{
for(int i = 0; i < 16; i++)
{
printf("\n smem_0[%d]: %f ", i, type_convert<float>(smem_0[i]));
printf("\n smem_a[%d]: %f ", i, type_convert<float>(smem_a[i]));
}
}
//store_tile(c_window_, y);
......@@ -301,10 +303,10 @@ struct FusedMoeGemmPipeline_General
PrintMem(d,"D",0);
#endif
// add to LDS
CK_TILE_LDS_ADDR float* smem_3 = reinterpret_cast<CK_TILE_LDS_ADDR float*>(smem);
CK_TILE_LDS_ADDR float* smem_o = reinterpret_cast<CK_TILE_LDS_ADDR float*>(smem);
auto o_lds_view =
make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::set>(
smem_3,
smem_o,
make_tuple(number<128>{}, number<32>{}),
make_tuple(32, 1),
number<8>{},
......
......@@ -94,14 +94,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
constexpr auto a_lds_desc = MakeLdsBlockDesc_A<Problem>();
return a_lds_desc.get_element_space_size();
return a_lds_desc.get_element_space_size() * sizeof(typename Problem::ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_G()
{
constexpr auto g_lds_desc = MakeLdsBlockDesc_G<Problem>();
return g_lds_desc.get_element_space_size() * sizeof(typename Problem::GDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
{
constexpr auto bridge_lds_desc = MakeBridgeLdsBlockDesc<Problem>();
return bridge_lds_desc.get_element_space_size();
return bridge_lds_desc.get_element_space_size() * sizeof(typename Problem::YDataType);
}
template <typename Problem>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment