Commit 6a03c66f authored by letaoqin's avatar letaoqin
Browse files

start gemm down

parent b2030e34
...@@ -319,13 +319,11 @@ struct FusedMoeGemmGlKernel ...@@ -319,13 +319,11 @@ struct FusedMoeGemmGlKernel
const auto d_window = [&]() { const auto d_window = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) + const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 + static_cast<long_index_t>(expert_id) * expert_stride_1;
idx_n0;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm // note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>( const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr, d_ptr,
make_tuple(kargs.hidden_size, BlockShape::Block_K1), make_tuple(kargs.hidden_size, kargs.intermediate_size),
make_tuple(kargs.intermediate_size, 1), make_tuple(kargs.intermediate_size, 1),
number<Pipeline::kAlignmentD>{}, number<Pipeline::kAlignmentD>{},
number<1>{}); number<1>{});
...@@ -333,7 +331,7 @@ struct FusedMoeGemmGlKernel ...@@ -333,7 +331,7 @@ struct FusedMoeGemmGlKernel
const auto d_window_ = make_tile_window( const auto d_window_ = make_tile_window(
d_view_, d_view_,
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}), make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
{0, 0}); {0, idx_n0});
return d_window_; return d_window_;
}(); }();
......
...@@ -99,8 +99,6 @@ struct FusedMoeGemmPipeline_General ...@@ -99,8 +99,6 @@ struct FusedMoeGemmPipeline_General
index_t intermediate_size) index_t intermediate_size)
{ {
ignore = d_window_; ignore = d_window_;
ignore = o_window_;
ignore = hidden_size;
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem); CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>( auto a_lds_view = make_tensor_view<address_space_enum::lds>(
...@@ -137,8 +135,8 @@ struct FusedMoeGemmPipeline_General ...@@ -137,8 +135,8 @@ struct FusedMoeGemmPipeline_General
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
constexpr index_t kK0 = BlockShape::Block_K0; constexpr index_t kK0 = BlockShape::Block_K0;
const index_t k0_loops = ck_tile::integer_divide_ceil(intermediate_size, kK0); const index_t k0_loops = ck_tile::integer_divide_ceil(intermediate_size, kK0);
index_t iCounter = k0_loops - 1; index_t iCounter0 = k0_loops - 1;
while(iCounter > 0) while(iCounter0 > 0)
{ {
block_sync_lds(); block_sync_lds();
...@@ -152,7 +150,7 @@ struct FusedMoeGemmPipeline_General ...@@ -152,7 +150,7 @@ struct FusedMoeGemmPipeline_General
g_dram_block = load_tile(g_global_to_dram_window); g_dram_block = load_tile(g_global_to_dram_window);
store_tile(a_lds_win, a_dram_block); store_tile(a_lds_win, a_dram_block);
iCounter--; iCounter0--;
} }
// tail // tail
{ {
...@@ -162,16 +160,45 @@ struct FusedMoeGemmPipeline_General ...@@ -162,16 +160,45 @@ struct FusedMoeGemmPipeline_General
// move sacc to LDS // move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>( auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>()); smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
auto bridge_lds_win = auto bridge_slds_win =
make_tile_window(bridge_lds_view, make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(), Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0}); {0, 0});
auto y_pre = cast_tile<YDataType>(s_acc); auto y_pre = cast_tile<YDataType>(s_acc);
store_tile(bridge_lds_win, y_pre); store_tile(bridge_slds_win, y_pre);
// gemm1 down block_sync_lds();
// gemm down
constexpr auto gemm_1 = Policy::template GetBlockGemm1<Problem>();
using SaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
auto o_acc = SaccBlockTileType{};
// y data
auto bridge_llds_win =
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0},
Policy::template MakeYTileDistribution<Problem>());
auto y = load_tile(bridge_llds_win);
// d data
auto d_global_to_dram_window = make_tile_window(
d_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
d_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_D<Problem>());
auto d = load_tile(d_global_to_dram_window);
constexpr index_t kN1 = BlockShape::Block_N1;
const index_t n1_loops = ck_tile::integer_divide_ceil(hidden_size, kN1);
index_t iCounter1 = n1_loops - 1;
while(iCounter1 > 0)
{
block_sync_lds();
ignore = bridge_lds_win; iCounter1--;
}
ignore = y;
ignore = d;
store_tile(o_window_, a_dram_block); store_tile(o_window_, a_dram_block);
#if 0 #if 0
//check a matrix gather right or not //check a matrix gather right or not
......
...@@ -158,25 +158,6 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -158,25 +158,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
} }
} }
template <index_t WarpPerBlock_N_,
index_t WarpPerBlock_K_,
index_t Repeat_N_,
index_t Repeat_K_,
index_t WarpSize_,
index_t Alignment_>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_Nr_Kr_W()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Repeat_N_, WarpPerBlock_N_>,
sequence<Repeat_K_, WarpPerBlock_K_>,
sequence<WarpSize_, Alignment_>>,
tuple<sequence<1, 2>, sequence<3>>,
tuple<sequence<1, 1>, sequence<0>>,
sequence<1, 2, 3>,
sequence<0, 0, 1>>{});
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
{ {
...@@ -231,17 +212,21 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -231,17 +212,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{ {
constexpr auto PermuteEnum = Problem::Traits::PermuteEnum; using S_ = remove_cvref_t<typename Problem::BlockShape>;
using S_ = typename Problem::BlockShape; using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{ constexpr auto d_outer_dstr_enc =
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N1, tile_distribution_encoding<sequence<S_::WarpPerBlock_N1>,
S_::WarpPerBlock_K1, tuple<sequence<S_::Repeat_N1>, sequence<S_::Repeat_K1>>,
S_::Repeat_N1, tuple<sequence<0>>,
S_::Repeat_K1, tuple<sequence<0>>,
get_warp_size(), sequence<1, 2>,
GetAlignment_D<Problem>()>(); sequence<0, 0>>{};
}
constexpr auto d_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
d_outer_dstr_enc, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto d_block_dstr = make_static_tile_distribution(d_block_dstr_encode);
return d_block_dstr;
} }
template <typename Problem> template <typename Problem>
...@@ -375,50 +360,26 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -375,50 +360,26 @@ struct FusedMoeGemmPipelineGeneralPolicy
} }
} }
// this is used as A matrix for 2nd gemm
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0() CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
{ {
using S_ = remove_cvref_t<typename Problem::BlockShape>; using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>; using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M0, S_::WarpPerBlock_M0>,
sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding = // TODO: all waves a along different N, but same M
tile_distribution_encoding<sequence<>, constexpr auto y_outer_dstr_enc =
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>, tile_distribution_encoding<sequence<S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>, tuple<sequence<S_::Repeat_M1>, sequence<S_::Repeat_K1>>,
tuple<sequence<1, 2>>, tuple<sequence<0>>,
tuple<sequence<1, 1>>, tuple<sequence<0>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr); return y_block_dstr;
return c_block_tensor;
} }
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment