Commit 7881eff9 authored by letaoqin
Browse files

gemm down

parent 6a03c66f
......@@ -98,8 +98,6 @@ struct FusedMoeGemmPipeline_General
index_t hidden_size,
index_t intermediate_size)
{
ignore = d_window_;
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsBlockDesc_A<Problem>());
......@@ -194,12 +192,21 @@ struct FusedMoeGemmPipeline_General
while(iCounter1 > 0)
{
block_sync_lds();
gemm_1(o_acc, y, d);
block_sync_lds();
move_tile_window(d_global_to_dram_window, {kN1, 0});
d = load_tile(d_global_to_dram_window);
iCounter1--;
}
ignore = y;
ignore = d;
store_tile(o_window_, a_dram_block);
// tail
{
block_sync_lds();
gemm_1(o_acc, y, d);
}
auto o = cast_tile<ODataType>(o_acc);
store_tile(o_window_, o);
// store_tile(o_window_, a_dram_block);
#if 0
//check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
......
......@@ -12,6 +12,8 @@
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
namespace ck_tile {
......@@ -209,19 +211,41 @@ struct FusedMoeGemmPipelineGeneralPolicy
return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
/// Builds the block-level GEMM object used for the second GEMM stage (the
/// `gemm_1(o_acc, y, d)` call in the pipeline).
///
/// The GEMM is described twice: once as a BlockGemmProblem (element types, block
/// size, tile shape) and once as a custom policy (element types, warp layout, warp
/// GEMM). Both descriptions use the same element types: Problem::YDataType,
/// Problem::DDataType and Problem::AccDataType. The tile shape comes from the
/// "_1" members of Problem::BlockShape, and the warp-level GEMM comes from
/// GetWarpGemm1<Problem>().
///
/// @tparam Problem  Pipeline problem descriptor; must expose BlockShape, YDataType,
///                  DDataType and AccDataType.
/// @return A default-constructed BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm1()
{
    using S_ = typename Problem::BlockShape;

    // Operand / accumulator element types of this GEMM, named once so the problem
    // and the policy below cannot drift apart.
    using YType   = typename Problem::YDataType;
    using DType   = typename Problem::DDataType;
    using AccType = typename Problem::AccDataType;

    using GemmShape   = TileGemmShape<typename S_::BlockTile_1,
                                      typename S_::WarpPerBlock_1,
                                      typename S_::WarpTile_1>;
    using GemmProblem = BlockGemmProblem<YType, DType, AccType, S_::BlockSize, GemmShape>;

    constexpr auto warp_gemm = GetWarpGemm1<Problem>();

    // The policy previously named Problem::ADataType / Problem::GDataType here,
    // i.e. different element types than the GemmProblem above for the same GEMM.
    // Use the Y / D types so both descriptions agree.
    using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<YType,
                                                                DType,
                                                                AccType,
                                                                typename S_::WarpPerBlock_1,
                                                                decltype(warp_gemm)>;

    return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
constexpr auto d_outer_dstr_enc =
tile_distribution_encoding<sequence<S_::WarpPerBlock_N1>,
tuple<sequence<S_::Repeat_N1>, sequence<S_::Repeat_K1>>,
tuple<sequence<0>>,
tuple<sequence<0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto d_outer_dstr_enc = tile_distribution_encoding<
sequence<S_::WarpPerBlock_M1>,
tuple<sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>, sequence<S_::Repeat_K1>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto d_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
d_outer_dstr_enc, typename WarpGemm::BWarpDstrEncoding{});
......@@ -368,13 +392,13 @@ struct FusedMoeGemmPipelineGeneralPolicy
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// TODO: all waves a along different N, but same M
constexpr auto y_outer_dstr_enc =
tile_distribution_encoding<sequence<S_::WarpPerBlock_M1>,
tuple<sequence<S_::Repeat_M1>, sequence<S_::Repeat_K1>>,
tuple<sequence<0>>,
tuple<sequence<0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto y_outer_dstr_enc = tile_distribution_encoding<
sequence<S_::WarpPerBlock_N1>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>, sequence<S_::Repeat_K1>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment