Commit 66efcf96 authored by letaoqin's avatar letaoqin
Browse files

change g tile distribution

parent fe44e66e
......@@ -126,7 +126,7 @@ struct FusedMoeGemmPipeline_General
Policy::template MakeGlobalTileDistribution_G<Problem>());
// Block GEMM
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
......@@ -138,7 +138,6 @@ struct FusedMoeGemmPipeline_General
ignore = s_acc;
store_tile(o_window_, a_dram_block);
#if 0
//check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
......
......@@ -17,7 +17,8 @@ namespace ck_tile {
struct FusedMoeGemmPipelineGeneralPolicy
{
static constexpr int kKIter = 2;
static constexpr int kKIter = 2;
static constexpr int kKPerBlock = 32;
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
......@@ -197,14 +198,18 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{
using S_ = typename Problem::BlockShape;
using S_ = typename Problem::BlockShape;
constexpr index_t K2 = S_::Warp_K0;
constexpr index_t K1 = get_warp_size() / S_::Warp_N0;
constexpr index_t K0 = kKPerBlock / (K1 * K2);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0, S_::Warp_N0>,
sequence<kKIter, get_warp_size() / S_::Warp_N0, S_::Warp_K0>>,
tuple<sequence<1>, sequence<1, 2>>,
sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
......@@ -212,23 +217,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0()
{
using S_ = typename Problem::BlockShape;
using GemmProblem =
BlockGemmProblem<typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
S_::BlockSize,
TileGemmShape<typename S_::BlockTile_0,
typename S_::WarpPerBlock_0,
typename S_::WarpTile_0>>;
using S_ = typename Problem::BlockShape;
using GemmProblem = BlockGemmProblem<typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
S_::BlockSize,
TileGemmShape<typename S_::BlockTile_0,
typename S_::WarpPerBlock_0,
typename S_::WarpTile_0>>;
constexpr auto warp_gemm = GetWarpGemm0<Problem>();
using BlockGemmPolicy =
BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
typename S_::WarpPerBlock_0,
decltype(warp_gemm)>;
using BlockGemmPolicy = BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
typename S_::WarpPerBlock_0,
decltype(warp_gemm)>;
return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment