Commit fe44e66e authored by letaoqin
Browse files

add gemm0 for tokens*G

parent f363ec7f
...@@ -57,7 +57,7 @@ struct indexing_adaptor_onshot_cached ...@@ -57,7 +57,7 @@ struct indexing_adaptor_onshot_cached
return ck_tile::is_known_at_compile_time<IndexingType>::value; return ck_tile::is_known_at_compile_time<IndexingType>::value;
} }
}; };
#define Using_Gather 1 #define Using_Gather 1
template <typename IndexingType> template <typename IndexingType>
struct indexing_adaptor struct indexing_adaptor
{ {
......
...@@ -125,11 +125,17 @@ struct FusedMoeGemmPipeline_General ...@@ -125,11 +125,17 @@ struct FusedMoeGemmPipeline_General
g_window_.get_window_origin(), g_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_G<Problem>()); Policy::template MakeGlobalTileDistribution_G<Problem>());
// Block GEMM
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
auto a_dram_block = load_tile(a_global_to_dram_window); auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block); store_tile(a_lds_win, a_dram_block);
auto g_dram_block = load_tile(g_global_to_dram_window); auto g_dram_block = load_tile(g_global_to_dram_window);
ignore = g_dram_block; ignore = g_dram_block;
ignore = s_acc;
store_tile(o_window_, a_dram_block); store_tile(o_window_, a_dram_block);
......
...@@ -8,6 +8,10 @@ ...@@ -8,6 +8,10 @@
#include "ck_tile/ops/flatmm.hpp" #include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -205,6 +209,30 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -205,6 +209,30 @@ struct FusedMoeGemmPipelineGeneralPolicy
sequence<0, 0, 2>>{}); sequence<0, 0, 2>>{});
} }
// Assembles the block-level GEMM type for the first GEMM stage (gemm-0) and
// returns a default-constructed instance of it.
//
// The element types (A, G, accumulator) come from Problem; the block size,
// block tile, warps-per-block layout and warp tile come from
// Problem::BlockShape (the "_0" members). The warp-level GEMM is whatever
// GetWarpGemm0<Problem>() yields.
// NOTE(review): judging by the type name, BlockGemmASmemBRegCRegV1 expects A
// in shared memory and B/C in registers -- confirm against its header.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0()
{
    using Shape      = typename Problem::BlockShape;
    using AType      = typename Problem::ADataType;
    using GType      = typename Problem::GDataType;
    using AccType    = typename Problem::AccDataType;
    using WarpLayout = typename Shape::WarpPerBlock_0;

    // Tile geometry of gemm-0: block tile, warp layout, warp tile.
    using TileShape =
        TileGemmShape<typename Shape::BlockTile_0, WarpLayout, typename Shape::WarpTile_0>;

    using BlockProblem = BlockGemmProblem<AType, GType, AccType, Shape::BlockSize, TileShape>;

    // Only the type of the warp GEMM object is consumed below; it is kept as a
    // constexpr local so that decltype sees exactly the same (const) type.
    constexpr auto warp_gemm_0 = GetWarpGemm0<Problem>();

    using BlockPolicy = BlockGemmASmemBRegCRegV1CustomPolicy<AType,
                                                             GType,
                                                             AccType,
                                                             WarpLayout,
                                                             decltype(warp_gemm_0)>;

    using BlockGemm = BlockGemmASmemBRegCRegV1<BlockProblem, BlockPolicy>;
    return BlockGemm{};
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{ {
...@@ -474,7 +502,7 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -474,7 +502,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
// TODO: ugly // TODO: ugly
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> && if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> && std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16) S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
{ {
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>, WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment