Commit b2030e34 authored by letaoqin's avatar letaoqin
Browse files

s_acc data to lds to shuffle

parent 1d89463c
......@@ -122,7 +122,7 @@ struct FusedMoeGemmPipeline_General
g_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_G<Problem>());
// Block GEMM
// gemm gate
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
......@@ -138,7 +138,6 @@ struct FusedMoeGemmPipeline_General
constexpr index_t kK0 = BlockShape::Block_K0;
const index_t k0_loops = ck_tile::integer_divide_ceil(intermediate_size, kK0);
index_t iCounter = k0_loops - 1;
//gemm 0
while(iCounter > 0)
{
block_sync_lds();
......@@ -160,11 +159,19 @@ struct FusedMoeGemmPipeline_General
block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block);
}
// move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
auto bridge_lds_win =
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0});
//move sacc to LDS
ignore = g_dram_block;
auto y_pre = cast_tile<YDataType>(s_acc);
store_tile(bridge_lds_win, y_pre);
// gemm1 down
ignore = bridge_lds_win;
store_tile(o_window_, a_dram_block);
#if 0
//check a matrix gather right or not
......
......@@ -102,11 +102,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
{
constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
static_assert(bridge_sld_desc.get_element_space_size() ==
bridge_sst_desc.get_element_space_size());
return bridge_sld_desc.get_element_space_size();
constexpr auto bridge_lds_desc = MakeBridgeLdsBlockDesc<Problem>();
return bridge_lds_desc.get_element_space_size();
}
template <typename Problem>
......@@ -296,30 +293,13 @@ struct FusedMoeGemmPipelineGeneralPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc()
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsBlockDesc()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = 0; // pad between warps
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = 0; // KVector; // pad between warps
constexpr index_t KVector = GetSmemKPack_Y<Problem>();
constexpr index_t KPad = 0;
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment