Commit b2030e34 authored by letaoqin's avatar letaoqin
Browse files

s_acc data to lds to shuffle

parent 1d89463c
...@@ -122,7 +122,7 @@ struct FusedMoeGemmPipeline_General ...@@ -122,7 +122,7 @@ struct FusedMoeGemmPipeline_General
g_window_.get_window_origin(), g_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_G<Problem>()); Policy::template MakeGlobalTileDistribution_G<Problem>());
// Block GEMM // gemm gate
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>(); constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{}; auto s_acc = SaccBlockTileType{};
...@@ -138,7 +138,6 @@ struct FusedMoeGemmPipeline_General ...@@ -138,7 +138,6 @@ struct FusedMoeGemmPipeline_General
constexpr index_t kK0 = BlockShape::Block_K0; constexpr index_t kK0 = BlockShape::Block_K0;
const index_t k0_loops = ck_tile::integer_divide_ceil(intermediate_size, kK0); const index_t k0_loops = ck_tile::integer_divide_ceil(intermediate_size, kK0);
index_t iCounter = k0_loops - 1; index_t iCounter = k0_loops - 1;
//gemm 0
while(iCounter > 0) while(iCounter > 0)
{ {
block_sync_lds(); block_sync_lds();
...@@ -160,11 +159,19 @@ struct FusedMoeGemmPipeline_General ...@@ -160,11 +159,19 @@ struct FusedMoeGemmPipeline_General
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block); gemm_0(s_acc, a_lds_win, g_dram_block);
} }
// move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
auto bridge_lds_win =
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0});
//move sacc to LDS auto y_pre = cast_tile<YDataType>(s_acc);
store_tile(bridge_lds_win, y_pre);
ignore = g_dram_block; // gemm1 down
ignore = bridge_lds_win;
store_tile(o_window_, a_dram_block); store_tile(o_window_, a_dram_block);
#if 0 #if 0
//check a matrix gather right or not //check a matrix gather right or not
......
...@@ -102,11 +102,8 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -102,11 +102,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
{ {
constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>(); constexpr auto bridge_lds_desc = MakeBridgeLdsBlockDesc<Problem>();
constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>(); return bridge_lds_desc.get_element_space_size();
static_assert(bridge_sld_desc.get_element_space_size() ==
bridge_sst_desc.get_element_space_size());
return bridge_sld_desc.get_element_space_size();
} }
template <typename Problem> template <typename Problem>
...@@ -296,30 +293,13 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -296,30 +293,13 @@ struct FusedMoeGemmPipelineGeneralPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc() CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsBlockDesc()
{ {
constexpr index_t Block_M = Problem::BlockShape::Block_M0; constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0; constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword constexpr index_t KVector = GetSmemKPack_Y<Problem>();
constexpr index_t KPad = 0; // pad between warps constexpr index_t KPad = 0;
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = 0; // KVector; // pad between warps
constexpr auto desc = constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}), make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment