Commit 66efcf96 authored by letaoqin
Browse files

change g tile distribution

parent fe44e66e
...@@ -138,7 +138,6 @@ struct FusedMoeGemmPipeline_General ...@@ -138,7 +138,6 @@ struct FusedMoeGemmPipeline_General
ignore = s_acc; ignore = s_acc;
store_tile(o_window_, a_dram_block); store_tile(o_window_, a_dram_block);
#if 0 #if 0
//check a matrix gather right or not //check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans(); constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
......
...@@ -18,6 +18,7 @@ namespace ck_tile { ...@@ -18,6 +18,7 @@ namespace ck_tile {
struct FusedMoeGemmPipelineGeneralPolicy struct FusedMoeGemmPipelineGeneralPolicy
{ {
static constexpr int kKIter = 2; static constexpr int kKIter = 2;
static constexpr int kKPerBlock = 32;
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords() CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{ {
...@@ -198,13 +199,17 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -198,13 +199,17 @@ struct FusedMoeGemmPipelineGeneralPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{ {
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
constexpr index_t K2 = S_::Warp_K0;
constexpr index_t K1 = get_warp_size() / S_::Warp_N0;
constexpr index_t K0 = kKPerBlock / (K1 * K2);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<1>, sequence<1>,
tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0, S_::Warp_N0>, tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0, S_::Warp_N0>,
sequence<kKIter, get_warp_size() / S_::Warp_N0, S_::Warp_K0>>, sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>, tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<1, 2>>,
sequence<1, 2, 2>, sequence<1, 2, 2>,
sequence<0, 0, 2>>{}); sequence<0, 0, 2>>{});
} }
...@@ -213,8 +218,7 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -213,8 +218,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0()
{ {
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
using GemmProblem = using GemmProblem = BlockGemmProblem<typename Problem::ADataType,
BlockGemmProblem<typename Problem::ADataType,
typename Problem::GDataType, typename Problem::GDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
S_::BlockSize, S_::BlockSize,
...@@ -223,8 +227,7 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -223,8 +227,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
typename S_::WarpTile_0>>; typename S_::WarpTile_0>>;
constexpr auto warp_gemm = GetWarpGemm0<Problem>(); constexpr auto warp_gemm = GetWarpGemm0<Problem>();
using BlockGemmPolicy = using BlockGemmPolicy = BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType,
BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::GDataType, typename Problem::GDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
typename S_::WarpPerBlock_0, typename S_::WarpPerBlock_0,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment