Commit 6a03c66f authored by letaoqin's avatar letaoqin
Browse files

start gemm down

parent b2030e34
......@@ -319,13 +319,11 @@ struct FusedMoeGemmGlKernel
const auto d_window = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
idx_n0;
static_cast<long_index_t>(expert_id) * expert_stride_1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(kargs.hidden_size, BlockShape::Block_K1),
make_tuple(kargs.hidden_size, kargs.intermediate_size),
make_tuple(kargs.intermediate_size, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
......@@ -333,7 +331,7 @@ struct FusedMoeGemmGlKernel
const auto d_window_ = make_tile_window(
d_view_,
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
{0, 0});
{0, idx_n0});
return d_window_;
}();
......
......@@ -99,8 +99,6 @@ struct FusedMoeGemmPipeline_General
index_t intermediate_size)
{
ignore = d_window_;
ignore = o_window_;
ignore = hidden_size;
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
......@@ -137,8 +135,8 @@ struct FusedMoeGemmPipeline_General
clear_tile(s_acc); // initialize C
constexpr index_t kK0 = BlockShape::Block_K0;
const index_t k0_loops = ck_tile::integer_divide_ceil(intermediate_size, kK0);
index_t iCounter = k0_loops - 1;
while(iCounter > 0)
index_t iCounter0 = k0_loops - 1;
while(iCounter0 > 0)
{
block_sync_lds();
......@@ -152,7 +150,7 @@ struct FusedMoeGemmPipeline_General
g_dram_block = load_tile(g_global_to_dram_window);
store_tile(a_lds_win, a_dram_block);
iCounter--;
iCounter0--;
}
// tail
{
......@@ -162,16 +160,45 @@ struct FusedMoeGemmPipeline_General
// move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
auto bridge_lds_win =
auto bridge_slds_win =
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0});
auto y_pre = cast_tile<YDataType>(s_acc);
store_tile(bridge_lds_win, y_pre);
// gemm1 down
store_tile(bridge_slds_win, y_pre);
block_sync_lds();
// gemm down
constexpr auto gemm_1 = Policy::template GetBlockGemm1<Problem>();
using SaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
auto o_acc = SaccBlockTileType{};
// y data
auto bridge_llds_win =
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0},
Policy::template MakeYTileDistribution<Problem>());
auto y = load_tile(bridge_llds_win);
// d data
auto d_global_to_dram_window = make_tile_window(
d_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
d_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_D<Problem>());
auto d = load_tile(d_global_to_dram_window);
constexpr index_t kN1 = BlockShape::Block_N1;
const index_t n1_loops = ck_tile::integer_divide_ceil(hidden_size, kN1);
index_t iCounter1 = n1_loops - 1;
while(iCounter1 > 0)
{
block_sync_lds();
ignore = bridge_lds_win;
iCounter1--;
}
ignore = y;
ignore = d;
store_tile(o_window_, a_dram_block);
#if 0
//check a matrix gather right or not
......
......@@ -158,25 +158,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
}
}
template <index_t WarpPerBlock_N_,
index_t WarpPerBlock_K_,
index_t Repeat_N_,
index_t Repeat_K_,
index_t WarpSize_,
index_t Alignment_>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_Nr_Kr_W()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Repeat_N_, WarpPerBlock_N_>,
sequence<Repeat_K_, WarpPerBlock_K_>,
sequence<WarpSize_, Alignment_>>,
tuple<sequence<1, 2>, sequence<3>>,
tuple<sequence<1, 1>, sequence<0>>,
sequence<1, 2, 3>,
sequence<0, 0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
{
......@@ -231,17 +212,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{
constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
using S_ = typename Problem::BlockShape;
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N1,
S_::WarpPerBlock_K1,
S_::Repeat_N1,
S_::Repeat_K1,
get_warp_size(),
GetAlignment_D<Problem>()>();
}
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
constexpr auto d_outer_dstr_enc =
tile_distribution_encoding<sequence<S_::WarpPerBlock_N1>,
tuple<sequence<S_::Repeat_N1>, sequence<S_::Repeat_K1>>,
tuple<sequence<0>>,
tuple<sequence<0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto d_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
d_outer_dstr_enc, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto d_block_dstr = make_static_tile_distribution(d_block_dstr_encode);
return d_block_dstr;
}
template <typename Problem>
......@@ -375,50 +360,26 @@ struct FusedMoeGemmPipelineGeneralPolicy
}
}
// this is used as A matrix for 2nd gemm
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0()
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M0, S_::WarpPerBlock_M0>,
sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
using CDataType = typename WarpGemm::CDataType;
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
// TODO: all waves a along different N, but same M
constexpr auto y_outer_dstr_enc =
tile_distribution_encoding<sequence<S_::WarpPerBlock_M1>,
tuple<sequence<S_::Repeat_M1>, sequence<S_::Repeat_K1>>,
tuple<sequence<0>>,
tuple<sequence<0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
return y_block_dstr;
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment