"vscode:/vscode.git/clone" did not exist on "119d734f6ec0b46866aa90b19770dc599fff04ef"
Commit 66efcf96 authored by letaoqin's avatar letaoqin
Browse files

change g tile distribution

parent fe44e66e
...@@ -126,7 +126,7 @@ struct FusedMoeGemmPipeline_General ...@@ -126,7 +126,7 @@ struct FusedMoeGemmPipeline_General
Policy::template MakeGlobalTileDistribution_G<Problem>()); Policy::template MakeGlobalTileDistribution_G<Problem>());
// Block GEMM // Block GEMM
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>(); constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{}; auto s_acc = SaccBlockTileType{};
...@@ -138,7 +138,6 @@ struct FusedMoeGemmPipeline_General ...@@ -138,7 +138,6 @@ struct FusedMoeGemmPipeline_General
ignore = s_acc; ignore = s_acc;
store_tile(o_window_, a_dram_block); store_tile(o_window_, a_dram_block);
#if 0 #if 0
//check a matrix gather right or not //check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans(); constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
......
...@@ -17,7 +17,8 @@ namespace ck_tile { ...@@ -17,7 +17,8 @@ namespace ck_tile {
struct FusedMoeGemmPipelineGeneralPolicy struct FusedMoeGemmPipelineGeneralPolicy
{ {
static constexpr int kKIter = 2; static constexpr int kKIter = 2;
static constexpr int kKPerBlock = 32;
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords() CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{ {
...@@ -197,14 +198,18 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -197,14 +198,18 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{ {
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
constexpr index_t K2 = S_::Warp_K0;
constexpr index_t K1 = get_warp_size() / S_::Warp_N0;
constexpr index_t K0 = kKPerBlock / (K1 * K2);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<1>, sequence<1>,
tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0, S_::Warp_N0>, tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0, S_::Warp_N0>,
sequence<kKIter, get_warp_size() / S_::Warp_N0, S_::Warp_K0>>, sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>, tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<1, 2>>,
sequence<1, 2, 2>, sequence<1, 2, 2>,
sequence<0, 0, 2>>{}); sequence<0, 0, 2>>{});
} }
...@@ -212,23 +217,21 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -212,23 +217,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0()
{ {
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
using GemmProblem = using GemmProblem = BlockGemmProblem<typename Problem::ADataType,
BlockGemmProblem<typename Problem::ADataType, typename Problem::GDataType,
typename Problem::GDataType, typename Problem::AccDataType,
typename Problem::AccDataType, S_::BlockSize,
S_::BlockSize, TileGemmShape<typename S_::BlockTile_0,
TileGemmShape<typename S_::BlockTile_0, typename S_::WarpPerBlock_0,
typename S_::WarpPerBlock_0, typename S_::WarpTile_0>>;
typename S_::WarpTile_0>>;
constexpr auto warp_gemm = GetWarpGemm0<Problem>(); constexpr auto warp_gemm = GetWarpGemm0<Problem>();
using BlockGemmPolicy = using BlockGemmPolicy = BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType,
BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType, typename Problem::GDataType,
typename Problem::GDataType, typename Problem::AccDataType,
typename Problem::AccDataType, typename S_::WarpPerBlock_0,
typename S_::WarpPerBlock_0, decltype(warp_gemm)>;
decltype(warp_gemm)>;
return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{}; return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment