Commit 1ba8a08f authored by carlushuang

update tmp work

parent bf214665
......@@ -252,7 +252,7 @@ struct FusedMoeKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// __shared__ char smem_ptr[GetSmemSize()];
ck_tile::index_t num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const ck_tile::index_t*>(kargs.num_sorted_tiles_ptr));
ck_tile::index_t tile_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
......@@ -436,8 +436,7 @@ struct FusedMoeKernel
u_gtile_window,
d_gtile_window,
o_gtile_window,
scale,
smem_ptr);
scale);
tile_id += gridDim.x;
}
......
......@@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetAIndex()
{
constexpr auto a_dist = Policy::template MakeAGlobalTileDistribution<Problem>();
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
......@@ -142,7 +142,8 @@ struct BlockFmhaPipelineQRKSVSAsync
OGlobalTensorView& o_gtile_window_tmp,
// const void * sorted_weight_ptr,
ScaleDataType scale,
void* smem_ptr,
CK_TILE_LDS_ADDR void* smem_0,
CK_TILE_LDS_ADDR void* smem_1,
index_t dim_size,
index_t hidden_size)
{
......@@ -153,25 +154,25 @@ struct BlockFmhaPipelineQRKSVSAsync
make_tile_window(a_gtile_window_tmp.get_bottom_tensor_view(),
a_gtile_window_tmp.get_window_lengths(),
a_gtile_window_tmp.get_window_origin(),
Policy::template MakeAGlobalTileDistribution<Problem>());
Policy::template MakeGlobalTileDistribution_A<Problem>());
auto g_gtile_window =
make_tile_window(g_gtile_window_tmp.get_bottom_tensor_view(),
g_gtile_window_tmp.get_window_lengths(),
g_gtile_window_tmp.get_window_origin(),
Policy::template MakeGGlobalTileDistribution<Problem>());
Policy::template MakeGlobalTileDistribution_G<Problem>());
auto u_gtile_window =
make_tile_window(u_gtile_window_tmp.get_bottom_tensor_view(),
u_gtile_window_tmp.get_window_lengths(),
u_gtile_window_tmp.get_window_origin(),
Policy::template MakeUGlobalTileDistribution<Problem>());
Policy::template MakeGlobalTileDistribution_U<Problem>());
auto d_gtile_window =
make_tile_window(d_gtile_window_tmp.get_bottom_tensor_view(),
d_gtile_window_tmp.get_window_lengths(),
d_gtile_window_tmp.get_window_origin(),
Policy::template MakeDGlobalTileDistribution<Problem>());
Policy::template MakeGlobalTileDistribution_D<Problem>());
auto o_gtile_window =
make_tile_window(o_gtile_window_tmp.get_bottom_tensor_view(),
......@@ -187,12 +188,13 @@ struct BlockFmhaPipelineQRKSVSAsync
auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset;
make_tile_window(make_tensor_view<address_space_enum::lds>(
a_smem_ptr, Policy::template MakeALdsStoreBlockDescriptor<Problem>()),
Policy::template MakeALdsStoreBlockDescriptor<Problem>().get_lengths(),
auto smem_0_window = make_tile_window(
make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreBlockDescriptor_A<Problem>()),
Policy::template MakeLdsStoreBlockDescriptor_A<Problem>().get_lengths(),
{0, 0});
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), a_gtile_window);
async_load_tile(k_lds_store(LdsSeq.at(number<0>{})));
for(index_t i_0 = 0; i_0 < loops_0; i_0++) {}
}
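Since the single `smem_ptr` argument is gone (see the kernel hunk above, where the `__shared__` allocation is commented out), the pipeline now takes two LDS regions. A minimal sketch, with hypothetical names and an assumed half/half split (not part of this commit), of how the kernel side could provide them:

```cpp
// Sketch only: kSmemSize and the even split are assumptions, not taken from this commit.
template <int kSmemSize>
__global__ void fused_moe_kernel_sketch(/* Kargs kargs */)
{
    __shared__ char smem[kSmemSize];
    void* smem_0 = smem;                  // first ping-pong LDS buffer
    void* smem_1 = smem + kSmemSize / 2;  // second buffer, assuming an even split
    // pipeline(..., scale, smem_0, smem_1, dim_size, hidden_size);
    (void)smem_0;
    (void)smem_1;
}
```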
......@@ -351,7 +353,7 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
buffer_load_fence_raw(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleared, return it
......@@ -403,7 +405,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
buffer_load_fence_raw(k_dram_window.get_num_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
......@@ -428,7 +430,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access());
async_load_fence_raw(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
......@@ -450,7 +452,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
async_load_fence_raw();
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
......
......@@ -14,7 +14,6 @@ namespace ck_tile {
struct FusedMoePipelinePolicy
{
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
// TODO:
......@@ -22,7 +21,7 @@ struct FusedMoePipelinePolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{
// using async
static constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
......@@ -32,54 +31,27 @@ struct FusedMoePipelinePolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentG()
{
static constexpr index_t copy_bytes = [&]() {
if constexpr(Problem::Traits::GateUpPreShuffled)
{
return 4 * 4;
}
else
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
{
return 4 * GetAsyncCopyDwords();
}
}();
static constexpr index_t copy_bytes = [&]() { return 16; }();
static constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentU()
{
static constexpr index_t copy_bytes = [&]() {
if constexpr(Problem::Traits::GateUpPreShuffled)
{
return 4 * 4;
}
else
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_U()
{
return 4 * GetAsyncCopyDwords();
}
}();
static constexpr index_t copy_bytes = [&]() { return 16; }();
static constexpr index_t data_bytes = sizeof(typename Problem::UDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentD()
{
static constexpr index_t copy_bytes = [&]() {
if constexpr(Problem::Traits::DownPreShuffled)
{
return 4 * 4;
}
else
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
{
return 4 * GetAsyncCopyDwords();
}
}();
static constexpr index_t copy_bytes = [&]() { return 16; }();
static constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
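A standalone arithmetic sketch of what the simplified GetAlignment_G/U/D now compute: the per-dtype/pre-shuffle branches collapse into a fixed 16-byte copy, so the alignment in elements is simply 16 / sizeof(DataType).

```cpp
// Minimal sketch; the 2-byte and 1-byte element sizes are illustrative (fp16/bf16, fp8/int8).
constexpr int alignment_in_elements(int data_bytes) { return 16 / data_bytes; }
static_assert(alignment_in_elements(2) == 8);  // 16-byte copy of fp16 -> 8 elements per lane
static_assert(alignment_in_elements(1) == 16); // 16-byte copy of fp8  -> 16 elements per lane
```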
......@@ -93,29 +65,11 @@ struct FusedMoePipelinePolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackA()
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
{
return GetSmemKPack<typename Problem::ADataType>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackG()
{
return GetSmemKPack<typename Problem::GDataType>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackU()
{
return GetSmemKPack<typename Problem::UDataType>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackD()
{
return GetSmemKPack<typename Problem::DDataType>();
}
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
{
......@@ -206,102 +160,58 @@ struct FusedMoePipelinePolicy
// Caution: this will require global memory pre-shuffled to follow the mfma layout
// to maximize the L1/L2 channels while skipping LDS
/*
(b) n0 n1 n2 k0 k1 k2
klanes
|
nr 4 kr 4 16 8
(b) n0 n1 k0 k1 n2 k2 -> kthreads
| |
V V
waves nlanes
klanes
|
nr kr 4 4 16 8
(b) n0 k0 n1 k1 n2 k2 -> kthreads
| |
V V
waves nlanes
*/
template <typename BlockTile, typename BlockWarps, typename WarpGemm, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled_NxK()
template <index_t NPerBlock,
index_t KPerBlock,
index_t WavesPerBlock_N,
index_t WavesPerBlock_K,
typename WarpGemm,
index_t Alignment,
FusedMoePermuteStyle PermuteStyle = FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled()
{
static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0);
static_assert(BlockWarps{}.at(number<0>{}) == 1 && BlockWarps{}.at(number<2>{}) == 1);
static constexpr index_t NumWarps =
reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
constexpr index_t NPerBlock = BlockTile{}.at(number<1>{});
constexpr index_t KPerBlock = BlockTile{}.at(number<2>{});
constexpr index_t K2 = Alignment;
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t N1 = NumWarps;
if constexpr(PermuteStyle == FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv)
{
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(NPerBlock % (N1 * N2) == 0);
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
constexpr index_t K0 = KPerBlock / (K1 * K2);
constexpr index_t N0 = NPerBlock / (N1 * N2);
constexpr index_t Nr_p = WavesPerBlock_N;
constexpr index_t Kr_p = WavesPerBlock_K;
constexpr index_t Nr_y = Nr / Nr_p;
constexpr index_t Kr_y = Kr / Kr_p;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
if constexpr(get_warp_size() < K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "do not support threads repeating along K yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
// NOTE: no swap, but hard to avoid LDS bank conflict
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr index_t K_lan = K_rem;
constexpr index_t M_lan = get_warp_size() / K_lan;
constexpr index_t M_wav = NumWarps;
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
// NOTE: swapped for LDS load bank conflict free
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_lan, M_wav>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
sequence<1>, // 0
// major 1 2 3
// minor 0 1 0 1 0 1 2
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>,
// Nr_p, Kr_p Kw Nw
tuple<sequence<1, 2>, sequence<3, 3>>,
tuple<sequence<1, 1>, sequence<0, 1>>,
// Nr_y Kr_y Kv
sequence<1, 2, 3>,
sequence<0, 0, 2>>{});
// clang-format on
}
}
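A standalone sketch of the pre-shuffled addressing this distribution assumes (constants are illustrative, not taken from the commit): within one block b, the weight is laid out in (nr, kr, kw, nw, kv) order, where n = nr*Nw + nw and k = (kr*Kw + kw)*Kv + kv.

```cpp
constexpr int Kv = 8;  // Alignment (elements per vector load), assumed
constexpr int Nw = 16; // WarpGemm kAMLane, assumed
constexpr int Kw = 4;  // WarpGemm kABKLane, assumed
constexpr int flat_offset_b_nr_kr_kw_nw_kv(int n, int k, int Kr)
{
    const int nr = n / Nw, nw = n % Nw;
    const int kv = k % Kv, kw = (k / Kv) % Kw, kr = k / (Kv * Kw);
    return (((nr * Kr + kr) * Kw + kw) * Nw + nw) * Kv + kv;
}
// e.g. KPerBlock = 128 gives Kr = 128 / (Kw * Kv) = 4; element (n=17, k=9) then maps to
// nr=1, nw=1, kr=0, kw=1, kv=1:
static_assert(flat_offset_b_nr_kr_kw_nw_kv(17, 9, 4) == (((1 * 4 + 0) * 4 + 1) * 16 + 1) * 8 + 1);
```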
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeAGlobalTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
{
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentA<Problem>();
constexpr index_t Alignment = GetAlignment_A<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<kMPerBlock,
kKPerBlock,
NumWarps,
......@@ -309,42 +219,75 @@ struct FusedMoePipelinePolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGGlobalTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr auto PermuteStyle = Problem::Traits::PermuteStyle;
if constexpr(PermuteStyle == FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv)
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentG<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<kNPerBlock,
constexpr index_t WavesPerBlock_N = Problem::Gemm0BlockWarps::at(number<1>{});
constexpr index_t WavesPerBlock_K = Problem::Gemm0BlockWarps::at(number<2>{});
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
constexpr index_t Alignment = GetAlignment_G<Problem>();
return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment>();
WavesPerBlock_N,
WavesPerBlock_K,
WarpGemm,
Alignment,
PermuteStyle>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeUGlobalTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_U()
{
constexpr auto PermuteStyle = Problem::Traits::PermuteStyle;
if constexpr(PermuteStyle == FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv)
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentU<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<kNPerBlock,
constexpr index_t WavesPerBlock_N = Problem::Gemm0BlockWarps::at(number<1>{});
constexpr index_t WavesPerBlock_K = Problem::Gemm0BlockWarps::at(number<2>{});
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
constexpr index_t Alignment = GetAlignment_U<Problem>();
return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment>();
WavesPerBlock_N,
WavesPerBlock_K,
WarpGemm,
Alignment,
PermuteStyle>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeDGlobalTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{
constexpr auto PermuteStyle = Problem::Traits::PermuteStyle;
if constexpr(PermuteStyle == FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv)
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_d;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_y;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentD<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<kNPerBlock,
constexpr index_t WavesPerBlock_N = Problem::Gemm1BlockWarps::at(number<1>{});
constexpr index_t WavesPerBlock_K = Problem::Gemm1BlockWarps::at(number<2>{});
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
constexpr index_t Alignment = GetAlignment_D<Problem>();
return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment>();
WavesPerBlock_N,
WavesPerBlock_K,
WarpGemm,
Alignment,
PermuteStyle>();
}
}
template <index_t MPerBlock,
......@@ -359,9 +302,7 @@ struct FusedMoePipelinePolicy
constexpr index_t kBlockSize = ck_tile::get_warp_size() * NumWarps; // Problem::kBlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
// constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector =
Alignment; // GetAlignmentK<Problem>(); // this is for global load
constexpr index_t KVector = Alignment; // this is for global load
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
static_assert(warpSize * KVector >= KPerBlock && warpSize * KVector % KPerBlock == 0);
......@@ -402,77 +343,188 @@ struct FusedMoePipelinePolicy
return lds_block_desc;
}
template <index_t MPerBlock,
index_t KPerBlock,
index_t NumWarps,
index_t KPack,
index_t Alignement,
index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeSmemStoreBlockDescriptor_SimpleMxK_Async(number<IBuf> = number<0>{})
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor_A()
{
constexpr index_t kBlockSize = ck_tile::get_warp_size() * NumWarps; // Problem::kBlockSize;
// A async->LDS
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
// constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
// constexpr index_t Alignement = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
static_assert(warpSize * Alignement >= KPerBlock && warpSize * Alignement % KPerBlock == 0);
constexpr index_t LanesPerK =
KPerBlock / Alignement; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
warpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = MPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == MPerBlock * KPerBlock / (BlockSize * Alignement));
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword per lane
constexpr index_t kPad = KPack; // pad between warps
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
static_assert(kKPerBlock % KVector == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // how many lanes load along K
if constexpr(LanesPerK > warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = kMPerBlock / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + kPad)>{}, // m1
number<warpSize * KVector + kPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = kMPerBlock / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<Alignement>{}), // k1
make_tuple(number<NumWarps*(warpSize * Alignement + kPad)>{},
number<KPerBlock>{},
number<warpSize * Alignement + kPad>{},
number<Alignement>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<Alignement>{},
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<kKPerBlock>{}, // m1
number<warpSize * KVector + kPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
k_lds_block_desc_0,
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<Alignement>{}))),
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return k_lds_block_desc_issues_warps_lanes;
return lds_block_desc_issues_warps_lanes;
}
}
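A standalone sketch with illustrative numbers (fp16 A tile of 32x64, one-dword async copy, 4 waves of 64 lanes, KPack = 8, all assumed) of the quantities the descriptor above derives in its "else" branch; MakeSmemLoadTileDistribution_A below merges the same factors back into an MxK view.

```cpp
constexpr int warpSize   = 64, NumWarps   = 4;   // assumed launch shape
constexpr int kMPerBlock = 32, kKPerBlock = 64;  // assumed A tile
constexpr int KVector    = 2;                    // 4-byte async copy / 2-byte element
constexpr int KPack      = 8, kPad = KPack;      // pad inserted between warps
constexpr int LanesPerK  = kKPerBlock / KVector; // 32 lanes cover one K row
static_assert(LanesPerK <= warpSize);            // so the "else" branch above is taken
constexpr int LaneGroups = warpSize / LanesPerK;                 // 2 M rows per wave per issue
constexpr int NumIssues  = kMPerBlock / (LaneGroups * NumWarps); // 4 async issues per lane
constexpr int WarpStride  = warpSize * KVector + kPad;           // 136 elements per wave row
constexpr int IssueStride = NumWarps * WarpStride;               // 544 elements per issue
static_assert(NumIssues == 4 && WarpStride == 136 && IssueStride == 544);
```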
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeASmemLoadTileDistribution()
{
CK_TILE_HOST_DEVICE static constexpr auto MakeSmemLoadTileDistribution_A()
{
// A async->LDS
// Note that, this descriptor is only to construct the layout inside LDS
// in real Gemm pipeline, ds_read may not follow this pattern
// (may follow that in tile_distribution)
// the code below is almost the same as the SmemStore descriptor, with two differences:
// 1). modify the GuaranteedLastDimensionVectorLength of the naive tensor desc
// 2). the returned descriptor is in MxK 2d layout
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentA<Problem>();
constexpr index_t KPack = GetSmemKPackA<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchA;
return MakeSmemLoadTileDescriptor_SimpleMxK_Async<kMPerBlock,
kKPerBlock,
NumWarps,
Alignment,
KPack,
NumPrefetch>();
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword per lane
constexpr index_t kPad = KPack; // pad between warps
static_assert(kKPerBlock % KVector == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // how many lanes load along K
if constexpr(LanesPerK > warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = kMPerBlock / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + kPad)>{}, // m1
number<warpSize * KVector + kPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
make_merge_transform(make_tuple(
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = kMPerBlock / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<kKPerBlock>{}, // m1
number<warpSize * KVector + kPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeASmemStoreTileDistribution()
......@@ -480,8 +532,8 @@ struct FusedMoePipelinePolicy
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentA<Problem>();
constexpr index_t KPack = GetSmemKPackA<Problem>();
constexpr index_t Alignment = GetAlignment_A<Problem>();
constexpr index_t KPack = GetSmemKPack_A<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchA;
return MakeSmemStoreBlockDescriptor_SimpleMxK_Async<kMPerBlock,
......@@ -492,13 +544,14 @@ struct FusedMoePipelinePolicy
Alignment>();
}
#if 0
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGSmemLoadTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentG<Problem>();
constexpr index_t Alignment = GetAlignment_G<Problem>();
constexpr index_t KPack = GetSmemKPackG<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG;
......@@ -515,7 +568,7 @@ struct FusedMoePipelinePolicy
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentG<Problem>();
constexpr index_t Alignment = GetAlignment_G<Problem>();
constexpr index_t KPack = GetSmemKPackG<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG;
......@@ -533,7 +586,7 @@ struct FusedMoePipelinePolicy
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentU<Problem>();
constexpr index_t Alignment = GetAlignment_U<Problem>();
constexpr index_t KPack = GetSmemKPackU<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchU;
......@@ -551,7 +604,7 @@ struct FusedMoePipelinePolicy
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_d;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_y;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentD<Problem>();
constexpr index_t Alignment = GetAlignment_D<Problem>();
constexpr index_t KPack = GetSmemKPackD<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchD;
......@@ -562,7 +615,32 @@ struct FusedMoePipelinePolicy
KPack,
NumPrefetch>();
}
#endif
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
{
return WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<0>{}),
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<1>{}),
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<2>{}),
true /*TransposeC*/>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
return WarpGemmMfmaDispatcher<typename Problem::YDataType,
typename Problem::DDataType,
typename Problem::AccDataType,
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<0>{}),
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<1>{}),
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<2>{}),
true /*TransposeC*/>{};
}
#if 0
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetGemm0()
{
......@@ -628,5 +706,6 @@ struct FusedMoePipelinePolicy
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
#endif
};
} // namespace ck_tile
......@@ -12,21 +12,22 @@
namespace ck_tile {
template <bool GateUpPreShuffled_ = false,
bool DownPreShuffled_ = false,
index_t NumPrefetchA_ = 2,
index_t NumPrefetchG_ = 2,
index_t NumPrefetchU_ = 2,
index_t NumPrefetchD_ = 2,
enum class FusedMoePermuteStyle
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
no_permute = 999,
};
template <bool DownPreShuffled_ = false,
FusedMoePermuteStyle PermuteStyle_ = FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct FusedMoeTraits
{
static constexpr bool GateUpPreShuffled = GateUpPreShuffled_;
static constexpr bool DownPreShuffled = DownPreShuffled_;
static constexpr index_t NumPrefetchA = NumPrefetchA_;
static constexpr index_t NumPrefetchG = NumPrefetchG_;
static constexpr index_t NumPrefetchU = NumPrefetchU_;
static constexpr index_t NumPrefetchD = NumPrefetchD_;
static constexpr FusedMoePermuteStyle PermuteStyle = PermuteStyle_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile
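A minimal usage sketch of the reworked traits, assuming the parameter order shown above (DownPreShuffled_, PermuteStyle_, kBlockPerCu_):

```cpp
using MoeTraits = ck_tile::FusedMoeTraits<
    false,                                                       // DownPreShuffled_
    ck_tile::FusedMoePermuteStyle::permute_b_nr_kr_waveflatten,  // alias of permute_b_nr_kr_kw_nw_kv
    -1>;                                                         // kBlockPerCu_: keep default occupancy
static_assert(MoeTraits::PermuteStyle ==
              ck_tile::FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv);
```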