Commit 7018dfb2 authored by letaoqin's avatar letaoqin
Browse files

start gemm0

parent 9ec586fc
......@@ -107,7 +107,7 @@ struct FusedMoeGemmPipeline_General
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>());
smem_0, Policy::template MakeLdsBlockDesc_A<Problem>());
auto a_lds_win = make_tile_window(
a_lds_view,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
......@@ -130,12 +130,18 @@ struct FusedMoeGemmPipeline_General
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// save tokens to lds
auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block);
// load g to register
auto g_dram_block = load_tile(g_global_to_dram_window);
ignore = g_dram_block;
ignore = s_acc;
clear_tile(s_acc); // initialize C
gemm_0(s_acc, a_lds_win, g_dram_block);
ignore = g_dram_block;
store_tile(o_window_, a_dram_block);
#if 0
......
......@@ -17,9 +17,6 @@ namespace ck_tile {
struct FusedMoeGemmPipelineGeneralPolicy
{
static constexpr int kKIter = 2;
static constexpr int kKPerBlock = 32;
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
// TODO: always 1 dword
......@@ -98,10 +95,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
return a_sld_desc.get_element_space_size();
constexpr auto a_lds_desc = MakeLdsBlockDesc_A<Problem>();
return a_lds_desc.get_element_space_size();
}
template <typename Problem>
......@@ -198,20 +193,20 @@ struct FusedMoeGemmPipelineGeneralPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{
using S_ = typename Problem::BlockShape;
constexpr index_t K2 = S_::Warp_K0;
constexpr index_t K1 = get_warp_size() / S_::Warp_N0;
constexpr index_t K0 = S_::Repeat_K0;
using WG = decltype(GetWarpGemm0<Problem>());
using S_ = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0, S_::Warp_N0>,
sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
constexpr auto g_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>,
tuple<sequence<1>>,
tuple<sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
g_outer_dstr_enc, typename WG::BWarpDstrEncoding{});
return make_static_tile_distribution(g_block_dstr_encode);
}
template <typename Problem>
......@@ -275,7 +270,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_A()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
......@@ -300,101 +295,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
return a_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
// A async->LDS
// Note that, this descriptor is only to construct the layout inside LDS
// in real Gemm pipeline, ds_read may not follow this pattern
// (may follow that in tile_distribution)
// below code is almost the same as SmemStore dist, with difference:
// 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
// 2). return discriptor is in NxK 2d layout
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK >= NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = Block_M / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
make_merge_transform(make_tuple(
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment