Commit b2030e34 authored by letaoqin
Browse files

s_acc data to lds to shuffle

parent 1d89463c
...@@ -122,7 +122,7 @@ struct FusedMoeGemmPipeline_General ...@@ -122,7 +122,7 @@ struct FusedMoeGemmPipeline_General
g_window_.get_window_origin(), g_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_G<Problem>()); Policy::template MakeGlobalTileDistribution_G<Problem>());
// Block GEMM // gemm gate
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>(); constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{}; auto s_acc = SaccBlockTileType{};
...@@ -135,36 +135,43 @@ struct FusedMoeGemmPipeline_General ...@@ -135,36 +135,43 @@ struct FusedMoeGemmPipeline_General
auto g_dram_block = load_tile(g_global_to_dram_window); auto g_dram_block = load_tile(g_global_to_dram_window);
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
constexpr index_t kK0 = BlockShape::Block_K0; constexpr index_t kK0 = BlockShape::Block_K0;
const index_t k0_loops = ck_tile::integer_divide_ceil(intermediate_size, kK0); const index_t k0_loops = ck_tile::integer_divide_ceil(intermediate_size, kK0);
index_t iCounter = k0_loops - 1; index_t iCounter = k0_loops - 1;
//gemm 0
while(iCounter > 0) while(iCounter > 0)
{ {
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block); gemm_0(s_acc, a_lds_win, g_dram_block);
block_sync_lds(); block_sync_lds();
move_tile_window(a_global_to_dram_window, {0, kK0}); move_tile_window(a_global_to_dram_window, {0, kK0});
move_tile_window(g_global_to_dram_window, {0, kK0}); move_tile_window(g_global_to_dram_window, {0, kK0});
a_dram_block = load_tile(a_global_to_dram_window); a_dram_block = load_tile(a_global_to_dram_window);
g_dram_block = load_tile(g_global_to_dram_window); g_dram_block = load_tile(g_global_to_dram_window);
store_tile(a_lds_win, a_dram_block); store_tile(a_lds_win, a_dram_block);
iCounter--; iCounter--;
} }
// tail // tail
{ {
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block); gemm_0(s_acc, a_lds_win, g_dram_block);
} }
// move sacc to LDS
//move sacc to LDS auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
ignore = g_dram_block; auto bridge_lds_win =
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0});
auto y_pre = cast_tile<YDataType>(s_acc);
store_tile(bridge_lds_win, y_pre);
// gemm1 down
ignore = bridge_lds_win;
store_tile(o_window_, a_dram_block); store_tile(o_window_, a_dram_block);
#if 0 #if 0
//check a matrix gather right or not //check a matrix gather right or not
......
...@@ -102,11 +102,8 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -102,11 +102,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
{ {
constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>(); constexpr auto bridge_lds_desc = MakeBridgeLdsBlockDesc<Problem>();
constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>(); return bridge_lds_desc.get_element_space_size();
static_assert(bridge_sld_desc.get_element_space_size() ==
bridge_sst_desc.get_element_space_size());
return bridge_sld_desc.get_element_space_size();
} }
template <typename Problem> template <typename Problem>
...@@ -296,30 +293,13 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -296,30 +293,13 @@ struct FusedMoeGemmPipelineGeneralPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc() CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsBlockDesc()
{ {
constexpr index_t Block_M = Problem::BlockShape::Block_M0; constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0; constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword constexpr index_t KVector = GetSmemKPack_Y<Problem>();
constexpr index_t KPad = 0; // pad between warps constexpr index_t KPad = 0;
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = 0; // KVector; // pad between warps
constexpr auto desc = constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}), make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment