Unverified Commit ec959387 authored by rocking's avatar rocking Committed by GitHub
Browse files

Merge branch 'develop' into ck_tile/fmha_receipt_aiter

parents c1e2fef7 0e5e29c4
......@@ -7,6 +7,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
......@@ -14,6 +15,6 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
......
......@@ -15,6 +15,10 @@ namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
......@@ -28,7 +32,7 @@ namespace ck_tile {
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
......@@ -55,6 +59,34 @@ namespace ck_tile {
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens)
// if enabled, the expert with no tokens will be skipped, in stead of padding to at least 1 unit_size(M_a)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5]
// num_tokens_post_padded_ptr : [24]
//
// * local_expert_mask : indicate local expert mask used on current GPU (used for EP case)
// and modify the output expert-ID, because we will only have enbaled expert on specific GPU.
// we call expert input to this kernel as "global expert id", output as "local expert id"
//
// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note original it was exper-id= 0, 2, 3, 5, but we produce "local expert id")
// num_tokens_post_padded_ptr : [20]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
......@@ -67,10 +99,80 @@ namespace ck_tile {
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_)
{
/* num_experts + 1
* +--------------------------------------+
* | |
* | |
* | | * -> sub-tokens
* | |
* | |
* +--------------------------------------+
* | | 2 -> cumsum buffer
* +--------------------------------------+
*
*/
int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here
int smem_rows = [&](){
index_t target_occupancy_ = 2;
constexpr index_t total_ = 65536 / sizeof(int);
constexpr index_t sub_unroll = 8;
constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
// at lease 2 lines, one for sub_token unroll, one for cumsum
// should be enough
if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) {
if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols))
throw std::runtime_error("too many num_experts, can't allocate smem");
target_occupancy_ = 1;
}
int r = total_ / target_occupancy_ / smem_cols;
// round to sub_unroll multipl
int r_for_sub_token = r - cumsum_bufs;
r_for_sub_token = min(r_for_sub_token, num_tokens_);
r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
r_for_sub_token = max(r_for_sub_token, 1);
if(r_for_sub_token > 1)
{
int r_unroll_ = r_for_sub_token / sub_unroll;
// round to 1x/2x/4x/8x number of sub_unroll
int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29
int mask_ = (1 << (31 - clz_)) - 1;
mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most
mask_ = ~mask_;
//printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout);
r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
}
// final check
if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) {
throw std::runtime_error("can't run this kernel, request LDS over size");
}
return r_for_sub_token + cumsum_bufs;
}();
// printf("r:%d, c:%d\n", smem_rows, smem_cols);
return ck_tile::make_tuple(smem_rows, smem_cols);
}
struct MoeSortingHostArgs
{
const void* p_topk_ids; // [token, topk]
const void* p_weights; // [token, topk]
const void* p_local_expert_mask;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
......@@ -101,6 +203,7 @@ struct MoeSortingKernel
{
const void* p_topk_ids;
const void* p_weights;
const void* p_local_expert_mask;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
......@@ -111,8 +214,11 @@ struct MoeSortingKernel
index_t moe_buf_bytes;
index_t tokens_per_thread;
index_t smem_rows;
mdiv unit_size_mdiv;
mdiv topk_mdiv;
mdiv expert_mdiv;
// mdiv sub_tokens_mdiv;
};
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
......@@ -123,15 +229,25 @@ struct MoeSortingKernel
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{
#if MOE_SORTING_USE_EX_KERNEL
(void)h;
return dim3(256);
#else
return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
#endif
}
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{
#if MOE_SORTING_USE_EX_KERNEL
auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
return smem_rows * smem_cols * sizeof(int);
#else
const auto blocks = BlockSize(h);
// usually num_experts is power of 2, we pad 1 dword here for the row-size
return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
#endif
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
......@@ -139,6 +255,7 @@ struct MoeSortingKernel
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_weights = h.p_weights;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_sorted_token_ids = h.p_sorted_token_ids;
k.p_sorted_weights = h.p_sorted_weights;
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
......@@ -152,10 +269,18 @@ struct MoeSortingKernel
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.smem_rows = [&](){
auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
(void) c_;
return r_;
}();
k.expert_mdiv = mdiv{static_cast<uint32_t>(h.num_experts)};
// k.sub_tokens_mdiv = mdiv{static_cast<uint32_t>(k.smem_rows - 1)};
return k;
}
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
// NOTE: wave_size need at least be 16!! dpp 16 is one row
template <typename data_t, int wave_size>
__device__ inline void wave_cumsum(data_t& thread_data) const
{
......@@ -196,6 +321,40 @@ struct MoeSortingKernel
bank_mask,
bound_ctrl))); // row_shr:4
}
if constexpr(wave_size == 8) {
// wave-size=8 need one extra shift
thread_data =
reduce_op(thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x118,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:8
#if 0
constexpr int bank_mask_0_7 = 0b1100;
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
__builtin_amdgcn_update_dpp(0, /* old value */
__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask_0_7,
bound_ctrl))// row_newbcast:7
);
#else
data_t xxx =__builtin_bit_cast(data_t,
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask,
bound_ctrl)); // row_newbcast:7
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
thread_data = thread_data - yyy;
#endif
}
if constexpr(wave_size > 8)
{
thread_data =
......@@ -224,6 +383,36 @@ struct MoeSortingKernel
}
}
// reduce single pixel within a wave
template <typename T, typename F, index_t wave_size_ = warpSize>
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
// constexpr int reduce_stage = 6; // 1<<6=64
// clang-format off
constexpr int reduce_stage = [](){
if constexpr(wave_size_ == 2) return 1;
else if constexpr(wave_size_ == 4) return 2;
else if constexpr(wave_size_ == 8) return 3;
else if constexpr(wave_size_ == 16) return 4;
else if constexpr(wave_size_ == 32) return 5;
else if constexpr(wave_size_ == 64) return 6;
else return 0;
}();
// clang-format on
T v_local = local;
#pragma unroll reduce_stage
for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
{
int src_lane = __lane_id() ^ (1 << i_stage);
int32_t v_remote_tmp =
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
T v_remote = bit_cast<T>(v_remote_tmp);
v_local = reduce_f(v_local, v_remote);
}
return v_local;
}
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
{
return row * total_col + col;
......@@ -257,37 +446,37 @@ struct MoeSortingKernel
index_t* shared_mem = reinterpret_cast<index_t*>(smem);
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts+1); // 1: (num_experts + 1)
index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)
for(int i = 0; i < num_experts; ++i)
{
tokens_cnts[calc_index(num_experts+1, tid + 1, i)] = 0;
tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
}
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
++tokens_cnts[calc_index(num_experts+1, tid + 1, topk_id[i])];
++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
}
__syncthreads();
#if 1
if(tid < num_experts)
{
tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0;
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
index_t local_c[8];
index_t prev_c = 0;
// TODO: manually unroll. pragma unroll does not work well when we have dependency
for(int i = 1; i <= static_cast<index_t>(blockDim.x); i+= 8)
for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
{
local_c[0] = tokens_cnts[calc_index(num_experts+1, i + 0, tid)];
local_c[1] = tokens_cnts[calc_index(num_experts+1, i + 1, tid)];
local_c[2] = tokens_cnts[calc_index(num_experts+1, i + 2, tid)];
local_c[3] = tokens_cnts[calc_index(num_experts+1, i + 3, tid)];
local_c[4] = tokens_cnts[calc_index(num_experts+1, i + 4, tid)];
local_c[5] = tokens_cnts[calc_index(num_experts+1, i + 5, tid)];
local_c[6] = tokens_cnts[calc_index(num_experts+1, i + 6, tid)];
local_c[7] = tokens_cnts[calc_index(num_experts+1, i + 7, tid)];
local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
local_c[0] += prev_c;
local_c[1] += local_c[0];
......@@ -299,51 +488,57 @@ struct MoeSortingKernel
local_c[7] += local_c[6];
prev_c = local_c[7];
tokens_cnts[calc_index(num_experts+1, i + 0, tid)] = local_c[0];
tokens_cnts[calc_index(num_experts+1, i + 1, tid)] = local_c[1];
tokens_cnts[calc_index(num_experts+1, i + 2, tid)] = local_c[2];
tokens_cnts[calc_index(num_experts+1, i + 3, tid)] = local_c[3];
tokens_cnts[calc_index(num_experts+1, i + 4, tid)] = local_c[4];
tokens_cnts[calc_index(num_experts+1, i + 5, tid)] = local_c[5];
tokens_cnts[calc_index(num_experts+1, i + 6, tid)] = local_c[6];
tokens_cnts[calc_index(num_experts+1, i + 7, tid)] = local_c[7];
tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
}
}
#else
// TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic
// TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future
// heuristic
{
if(tid < num_experts)
tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0;
for(int i = 0; i < num_experts; i+=8) {
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
for(int i = 0; i < num_experts; i += 8)
{
index_t local_c[8];
#pragma unroll
for(int j = 0; j < 8; j++) {
local_c[j] = tokens_cnts[calc_index(num_experts+1, tid+1, i+j)];
#pragma unroll
for(int j = 0; j < 8; j++)
{
local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
}
#pragma unroll
for(int j = 0; j < 8; j++) {
#pragma unroll
for(int j = 0; j < 8; j++)
{
wave_cumsum<int, 64>(local_c[j]);
}
#pragma unroll
for(int j = 0; j < 8; j++) {
tokens_cnts[calc_index(num_experts+1, tid+1, i+j)] = local_c[j];
#pragma unroll
for(int j = 0; j < 8; j++)
{
tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
}
}
}
#endif
__syncthreads();
if constexpr (Problem::ExpertTile == 0) {
if constexpr(Problem::ExpertTile == 0)
{
if(tid == 0)
{
cumsum[0] = 0;
for(int i = 1; i <= num_experts; ++i)
{
auto current_units = [&]() {
index_t x_ = tokens_cnts[calc_index(num_experts+1, blockDim.x, i - 1)] +
unit_size_mdiv.divisor - 1;
index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
unit_size_mdiv.divisor - 1;
index_t y_ = unit_size_mdiv.div(x_);
return max(y_, 1) * unit_size_mdiv.divisor;
}();
......@@ -351,20 +546,24 @@ struct MoeSortingKernel
}
*p_total_tokens_post_pad = cumsum[num_experts];
}
} else {
// TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert)
// for simplicity, not check experts here.
int local_cnt = tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)];
}
else
{
// TODO: we have out-of-bound read here. But result is still OK (will ignore tid >=
// expert) for simplicity, not check experts here.
int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
int local_cumsum = padded_tokens_per_expert;
int local_cumsum = padded_tokens_per_expert;
wave_cumsum<int, 64>(local_cumsum);
if(tid == (num_experts - 1)) {
cumsum[0] = 0;
if(tid == (num_experts - 1))
{
cumsum[0] = 0;
*p_total_tokens_post_pad = local_cumsum;
}
if(tid < num_experts) {
if(tid < num_experts)
{
cumsum[tid + 1] = local_cumsum;
}
}
......@@ -373,7 +572,7 @@ struct MoeSortingKernel
if(tid < num_experts)
{
int e_start = cumsum[tid];
int e_end = cumsum[tid + 1];
int e_end = cumsum[tid + 1];
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
......@@ -383,8 +582,8 @@ struct MoeSortingKernel
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
index_t expert_id = topk_id[i];
index_t local_cnt = tokens_cnts[calc_index(num_experts+1, tid, expert_id)];
index_t expert_id = topk_id[i];
index_t local_cnt = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
index_t rank_post_pad = local_cnt + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
uint32_t curr_token_id, curr_topk_id;
......@@ -393,16 +592,17 @@ struct MoeSortingKernel
#else
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
#endif
p_sorted_weights[rank_post_pad] = weights[i];
tokens_cnts[calc_index(num_experts+1, tid, expert_id)] = local_cnt+1;
p_sorted_weights[rank_post_pad] = weights[i];
tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
}
if constexpr (Problem::ExpertTile == 0) {
if constexpr(Problem::ExpertTile == 0)
{
const index_t prefill_token = topk_mdiv.div(numel);
if(tid < num_experts)
{
index_t expert_offset =
cumsum[tid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)];
cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
index_t expert_end = cumsum[tid + 1];
while(expert_offset < expert_end)
{
......@@ -417,16 +617,19 @@ struct MoeSortingKernel
}
}
}
else {
else
{
const index_t prefill_token = topk_mdiv.div(numel);
// TODO: only support expert-tile like 8, 16, 32
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
{
index_t eid = tid / experts_per_wave;
index_t expert_offset =
cumsum[eid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, eid)] + tid % experts_per_wave;
index_t eid = tid / experts_per_wave;
index_t expert_offset = cumsum[eid] +
tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
tid % experts_per_wave;
index_t expert_end = cumsum[eid + 1];
if(eid < num_experts) {
if(eid < num_experts)
{
while(expert_offset < expert_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
......@@ -436,10 +639,363 @@ struct MoeSortingKernel
p_sorted_token_ids[expert_offset] = prefill_token;
#endif
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset+=experts_per_wave;
expert_offset += experts_per_wave;
}
}
}
}
}
// only support index_t, and single pixel access
struct simple_smem_indexer
{
index_t* smem;
index_t row_stride;
// this is 2D
CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_)
: smem(smem_), row_stride(row_stride_)
{
}
CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const
{
return smem[i_row * row_stride + i_col];
}
CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col)
{
return smem[i_row * row_stride + i_col];
}
// this is 1D or linear
CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {}
CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; }
CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; }
};
CK_TILE_DEVICE void
moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights,
const IndexType* __restrict__ local_expert_mask,
index_t* p_sorted_token_ids,
WeightType* p_sorted_weights,
index_t* p_sorted_expert_ids,
index_t* p_total_tokens_post_pad,
const index_t num_experts,
const index_t tokens,
const mdiv unit_size_mdiv,
const mdiv topk_mdiv,
const mdiv expert_mdiv,
const index_t smem_rows,
void* smem) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
const index_t lid = __lane_id();
constexpr index_t block_size = 256; // blockDim.x;
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
const index_t topk = topk_mdiv.divisor;
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
const index_t smem_cols = num_experts + 1;
simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols,
smem_cols};
// #pragma unroll 8
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
{
uint32_t curr_token_id, curr_expert_id;
expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
smem_tokens(curr_token_id, curr_expert_id) = 0;
}
__syncthreads();
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
{
// NOTE: below for loop can't have barrier inside!!
for(int i = tid; i < (sub_tokens * topk); i += block_size)
{
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
int i_t = i_token + curr_token_id;
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
if constexpr(Problem::SubTokenOneShot)
smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
else
smem_tokens(curr_token_id, eid)++;
}
__builtin_amdgcn_s_waitcnt(0xc07f);
}
__syncthreads(); // make sure different i_token iteration not overlap by different wave
}
// counting
if(tid == 0)
{
smem_cumsum(0) = 0;
// smem_cumdup(0) = 0;
}
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
int lane_group_os = tid % lane_group_sz;
constexpr int lane_group_nm = block_size / lane_group_sz;
for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
{
index_t local_c[Problem::SubTokenTile];
index_t cnt = 0;
for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
{
#pragma unroll Problem::SubTokenTile
for(int j = 0; j < Problem::SubTokenTile; j++)
{
local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
if constexpr(Problem::SubTokenOneShot)
{
local_c[j] = local_c[j] != 0 ? 1 : 0;
}
}
#pragma unroll Problem::SubTokenTile
for(int j = 0; j < Problem::SubTokenTile; j++)
{
cnt += wave_reduce(local_c[j], f_sum, number<8>{});
}
}
if(lane_group_os == 0)
smem_cumsum(i_e + 1) = cnt;
}
}
if constexpr(Problem::LocalExpertMasking)
{
smem_cumdup(0) = 0;
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
// reuse this buffer
smem_cumdup(i_e + 1) = local_expert_mask[i_e];
}
}
__syncthreads();
{
if(wid == 0)
{
// NOTE: under this block can never use __syncthreads!
int i_e_ = 0;
int local_cumsum_ = 0;
for(; i_e_ < num_experts; i_e_ += warpSize)
{
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
int local_cnt = smem_cumsum(i_e_ + lid + 1);
int blocks_pers_expert =
unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
int pre_cumsum_masking = [&]() {
if constexpr(Problem::LocalExpertMasking)
return smem_cumdup(lid == 0 ? i_e_ : 0);
else
return 0; // not used
}();
int local_masking = [&]() {
if constexpr(Problem::LocalExpertMasking)
return smem_cumdup(i_e_ + lid + 1);
else
return 0; // not used
}();
int padded_tokens_per_expert = [&]() {
int x_ = [&]() {
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
// if local_cnt is zero, blocks_pers_expert will be zero
// this is what we want to achieve
return blocks_pers_expert * unit_size_mdiv.divisor;
}
else
{
return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
}
}();
if constexpr(Problem::LocalExpertMasking)
{
return local_masking ? x_ : 0;
}
else
return x_;
}();
local_cumsum_ = padded_tokens_per_expert;
local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local
// cumsum padded in case local cumsum is zero, but
// pre_sumsum has value, which will result int
// zero local cumsum(but we want at least padded)
wave_cumsum<int, warpSize>(local_cumsum_);
if((i_e_ + lid) < num_experts)
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
if constexpr(Problem::LocalExpertMasking)
{
local_masking += pre_cumsum_masking;
wave_cumsum<int, warpSize>(local_masking);
if((i_e_ + lid) < num_experts)
smem_cumdup(i_e_ + lid + 1) = local_masking;
}
// NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt()
// for above write however __syncthreads will cause barrier with waves other
// than 0(which is not we want)
__builtin_amdgcn_s_waitcnt(0xc07f);
}
if((lid + i_e_ - warpSize) == (num_experts - 1))
{
*p_total_tokens_post_pad = local_cumsum_;
}
}
__syncthreads();
}
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
int e_start = smem_cumsum(i_e);
int e_end = smem_cumsum(i_e + 1);
int expert_id = [&]() {
if constexpr(Problem::LocalExpertMasking)
{
// local expert id from cumsum
return smem_cumdup(i_e);
}
else
return i_e;
}();
smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
continue;
}
if constexpr(Problem::LocalExpertMasking)
{
if(local_expert_mask[i_e] == 0)
continue;
}
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
}
}
smem_cumdup(num_experts) = smem_cumsum(num_experts);
// fill the p_sorted_token_ids/p_sorted_weights
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
{
if constexpr(!Problem::SubTokenOneShot)
{
// clear every time
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
{
uint32_t curr_token_id, curr_expert_id;
expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
smem_tokens(curr_token_id, curr_expert_id) = 0;
}
__syncthreads();
// load again
for(int i = tid; i < (sub_tokens * topk); i += block_size)
{
uint32_t curr_token_id_, curr_topk_id_;
topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_);
int curr_token_id = static_cast<int>(curr_token_id_);
int curr_topk_id = static_cast<int>(curr_topk_id_);
int i_t = i_token + curr_token_id;
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
}
}
__syncthreads();
}
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
int lane_group_os = tid % lane_group_sz;
constexpr int lane_group_nm = block_size / lane_group_sz;
for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
{
if constexpr(Problem::LocalExpertMasking)
{
if(local_expert_mask[eid] == 0)
continue;
}
int position = smem_cumsum(eid);
for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens;
i_sub_token += lane_group_sz)
{
auto x = smem_tokens(i_sub_token, eid);
int local_cnt_cache = x != 0 ? 1 : 0;
int local_cnt = local_cnt_cache;
wave_cumsum<int, lane_group_sz>(local_cnt);
if(x != 0)
{
// now x is topk value
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[position + local_cnt - 1] =
MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
#else
p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token;
#endif
p_sorted_weights[position + local_cnt - 1] =
weights[(i_token + i_sub_token) * topk + x - 1];
}
int remote_cnt = __builtin_amdgcn_ds_bpermute(
(lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
position += remote_cnt;
}
smem_cumsum(eid) = position;
}
}
}
__syncthreads();
}
// add the skip number
for(int eid = tid; eid < num_experts; eid += block_size)
{
int e_start = smem_cumsum(eid);
int e_end = smem_cumdup(eid + 1);
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
continue;
}
while(e_start < e_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk);
#else
p_sorted_token_ids[e_start] = tokens;
#endif
p_sorted_weights[e_start] = static_cast<WeightType>(0.0);
e_start++;
}
}
}
......@@ -456,6 +1012,24 @@ struct MoeSortingKernel
}
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
extern __shared__ char smem[];
#if MOE_SORTING_USE_EX_KERNEL
(void)numel;
return moe_align_block_size_kernel_ex(
static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
static_cast<const IndexType*>(kargs.p_local_expert_mask),
static_cast<IndexType*>(kargs.p_sorted_token_ids),
static_cast<WeightType*>(kargs.p_sorted_weights),
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
kargs.num_experts,
kargs.tokens,
kargs.unit_size_mdiv,
kargs.topk_mdiv,
kargs.expert_mdiv,
kargs.smem_rows,
smem);
#else
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType*>(kargs.p_sorted_token_ids),
......@@ -468,6 +1042,7 @@ struct MoeSortingKernel
kargs.unit_size_mdiv,
kargs.topk_mdiv,
smem);
#endif
}
};
......
......@@ -25,4 +25,28 @@ struct MoeSortingProblem
InternalLoadUnroll_; // TODO: need better design(like tile size)
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
template <typename IndexType_,
typename WeightType_,
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
bool SubTokenOneShot_, // if we only loop over once or not
bool LocalExpertMasking_, // used in EP case
bool SkipExpertsWithZeroTokens_ = true,
index_t ExpertTile_ = 0>
struct MoeSortingProblemEx
{
// TODO: this kernel only support warp per row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
} // namespace ck_tile
......@@ -29,6 +29,8 @@
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
......@@ -46,3 +48,4 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -14,24 +14,54 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
private:
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t KPack = WarpGemm::kKPerThread;
};
public:
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using Traits = GemmTraits_<Problem, Policy>;
using WarpGemm = typename Traits::WarpGemm;
using BlockGemmShape = typename Traits::BlockGemmShape;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
......@@ -43,7 +73,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
......@@ -58,7 +88,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
......@@ -73,7 +103,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
return c_block_dstr_encode;
}
......@@ -112,13 +142,13 @@ struct BlockGemmARegBRegCRegV1
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
......@@ -157,7 +187,7 @@ struct BlockGemmARegBRegCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
......@@ -180,7 +210,7 @@ struct BlockGemmARegBRegCRegV1
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
......
......@@ -79,8 +79,11 @@ struct BlockUniversalGemmAsBsCr
// TODO: Should we have two policies? Interwave & Intrawave ??
static constexpr index_t InterWaveSchedulingMacClusters = 1;
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KPerBlock / WarpGemm::kK * KPack;
// should be at least equal to: WarpGemm::Impl::kABKPerLane
// and the question is how to assess upper limit or exact value?
// TODO: Should we introduce AK1/BK1 parameters ?
static constexpr index_t KPack = 8;
static constexpr index_t KPerThread = KIterPerWarp * KPack;
static constexpr index_t KRepeat = KPerThread / KPack;
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
......@@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using BLayout = typename Base::BLayout;
using CLayout = typename Base::CLayout;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
struct BatchedGemmKernelArgs : GemmKernelArgs
{
index_t batch_stride_A;
......@@ -70,7 +84,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
__host__ static constexpr auto
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
{
return TilePartitioner::GridSize(M, N, KBatch * batch_count);
return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
}
__host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
......@@ -101,14 +115,14 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
{
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.y);
const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
// options
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
......@@ -128,7 +142,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
if(kargs.KBatch == 1)
if(kargs.k_batch == 1)
{
this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
......
......@@ -8,7 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
......@@ -69,18 +69,26 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
return TilePartitioner::GridSize(M, N, KBatch);
// clang-format off
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
// clang-format on
}
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmKernelArgs
{
......@@ -93,7 +101,7 @@ struct GemmKernel
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t KBatch;
index_t k_batch;
};
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
......@@ -121,7 +129,7 @@ struct GemmKernel
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = kargs.KBatch * K1;
const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
......@@ -142,13 +150,13 @@ struct GemmKernel
b_k_split_offset = k_id * KRead;
}
if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
{
splitted_k = KRead;
}
else
{
splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
}
}
......@@ -159,15 +167,12 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
constexpr bool is_output_c_reg_transposed =
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 &&
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
is_output_c_reg_transposed) ||
!(std::is_same_v<CDataType, fp16_t> || std::is_same_v<CDataType, bf16_t>)))
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)
{
if(kargs.KBatch != 1)
if(kargs.k_batch != 1)
{
std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
return false;
}
}
......@@ -176,10 +181,14 @@ struct GemmKernel
{
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
{
std::cerr << "Can't support K that is not a multiple of KPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.K % GemmPipeline::VectorSizeA != 0)
if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
{
std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
return false;
}
}
......@@ -187,10 +196,14 @@ struct GemmKernel
{
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
std::cerr << "Can't support M that is not a multiple of MPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.M % GemmPipeline::VectorSizeA != 0)
if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
{
std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
return false;
}
}
......@@ -199,10 +212,14 @@ struct GemmKernel
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
std::cerr << "Can't support N that is not a multiple of NPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.N % GemmPipeline::VectorSizeB != 0)
if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
{
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
return false;
}
}
......@@ -210,10 +227,14 @@ struct GemmKernel
{
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
{
std::cerr << "Can't support K that is not a multiple of KPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.K % GemmPipeline::VectorSizeB != 0)
if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
{
std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
return false;
}
}
......@@ -222,10 +243,14 @@ struct GemmKernel
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
std::cerr << "Can't support N that is not a multiple of NPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.N % GemmPipeline::VectorSizeC != 0)
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false;
}
}
......@@ -233,10 +258,14 @@ struct GemmKernel
{
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
std::cerr << "Can't support M that is not a multiple of MPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.M % GemmPipeline::VectorSizeC != 0)
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false;
}
}
......@@ -257,16 +286,16 @@ struct GemmKernel
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{},
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(1, kargs.stride_A),
number<1>{},
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
......@@ -276,9 +305,9 @@ struct GemmKernel
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(1, kargs.stride_B),
number<1>{},
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
......@@ -287,11 +316,12 @@ struct GemmKernel
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{},
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
......@@ -299,7 +329,7 @@ struct GemmKernel
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
number<1>{});
}
else
......@@ -331,9 +361,9 @@ struct GemmKernel
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
}();
......@@ -349,12 +379,13 @@ struct GemmKernel
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadN, false>{});
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
}();
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
......@@ -380,20 +411,45 @@ struct GemmKernel
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
const auto& b_pad_view = views.at(I1);
const auto& b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
const auto& a_pad_view = views.at(I0);
const auto& b_pad_view = views.at(I1);
const auto& c_pad_view = views.at(I2);
auto c_block_window = make_tile_window(
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& b_block_window = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
......@@ -407,7 +463,9 @@ struct GemmKernel
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
......@@ -417,7 +475,7 @@ struct GemmKernel
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
void* smem_ptr,
void* smem_ptr_0,
const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
......@@ -435,28 +493,72 @@ struct GemmKernel
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
constexpr bool is_output_c_reg_transposed =
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
(GemmPipeline::VectorSizeC % 2 == 0 &&
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
is_output_c_reg_transposed))
{
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
c_block_window, c_block_tile);
}
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
c_block_window, c_block_tile, smem_ptr_0);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
c_block_window, c_block_tile, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
......@@ -469,16 +571,53 @@ struct GemmKernel
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr_1[GetSmemSize()];
if(kargs.KBatch == 1)
if(kargs.k_batch == 1)
{
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
RunGemm2LDS(a_ptr,
b_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
}
else
{
RunGemm<memory_operation_enum::atomic_add>(
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
b_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
RunGemm<memory_operation_enum::atomic_add>(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
}
}
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
/**
* @file
* GemmTilePartitioner allows customized mapping between a workgroup and the C-tile it computes.
*/
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/** @brief Struct representing 2D block index mapping into 3D output tile space. */
/**
* @brief Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
*
*/
template <typename BlockGemmShapeType>
struct GemmTile2DPartitioner
{
......@@ -17,21 +25,32 @@ struct GemmTile2DPartitioner
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
/** @brief Returns 3D grid size. */
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept(
noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept = delete;
CK_TILE_HOST_DEVICE GemmTile2DPartitioner([[maybe_unused]] index_t M,
[[maybe_unused]] index_t N) noexcept;
/**
* @brief Calculates GEMM kernel grid size.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/
CK_TILE_HOST static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
{
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
const index_t GridDimZ = batch_size;
return dim3(GridDimX, GridDimY, GridDimZ);
return dim3(GridDimX, GridDimY, 1);
}
/**
* @brief Returns the number of loops.
* @param [in] K is dimension
* @brief Calculate number of loop iterations over GEMM's K dimension.
*
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
{
return integer_divide_ceil(K, KPerBlock);
}
......@@ -42,8 +61,15 @@ struct GemmTile2DPartitioner
* @param [in] blockIdy is blockIdx.y
* @return Returns the output tile indexes.
*/
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx,
index_t blockIdy) noexcept
/**
* @brief Calculate workgroup 2D index mapping into 2D output C-tile space.
*
* @param blockIdx WGP's X index.
* @param blockIdy WGP's Y index.
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
-> const tuple<index_t, index_t>
{
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
......@@ -53,61 +79,71 @@ struct GemmTile2DPartitioner
};
/**
* @brief Struct representing 1D block index mapping into 2D output tile space.
* @brief Class providing 1D WGP index mapping into 2D output C-tile space.
*
* @tparam BlockGemmShape_ A class providing basic GEMM parameters. \link TileGemmShape
*/
template <typename BlockGemmShapeType>
template <typename BlockGemmShape_>
struct GemmTile1DPartitioner
{
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
/** @brief delete default ctr with no any object */
constexpr GemmTile1DPartitioner() noexcept = delete;
/** @brief constructs an object that does contain a N value. */
constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept = delete;
/** @brief Returns 1D grid size. */
CK_TILE_HOST static constexpr auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
/**
* @brief Construct a new GemmTile1DPartitioner object.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
*/
CK_TILE_HOST_DEVICE GemmTile1DPartitioner([[maybe_unused]] index_t M, index_t N) noexcept
{
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
return dim3(GridDimX * GridDimY, 1, 1);
N_ = N;
}
/**
* @brief Returns the number of blocks in N.
* @param [in] N is dimension
* @brief Calculates GEMM kernel grid size.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/
CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t
CK_TILE_HOST static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
{
return integer_divide_ceil(N, NPerBlock);
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
return GridDimX * GridDimY;
}
/**
* @brief Returns the number of loops.
* @param [in] K is dimension
* @brief Calculate number of loop iterations over GEMM's K dimension.
*
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
{
return integer_divide_ceil(K, KPerBlock);
}
/**
* @brief The function returns 2D output tile space.
* @param [in] blockIdx is blockIdx.x - block_start.
* */
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
*
* @param blockIdx WGP's index.
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept
-> const tuple<index_t, index_t>
{
const index_t NBlock = GetNBlock(N_);
const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock);
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
return make_tuple(iM, iN);
}
......@@ -141,21 +177,176 @@ struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIn
* enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
* otherwise std::false_type.
*/
template <typename PartitionerFn,
typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>>
template <typename TilePartitioner,
typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
struct OffsettedTile1DPartitioner
{
/**
* @brief The function subtracts the block's start (offset) from 1D raw-indexes.
* @param [in] block_start is `blockIdx.x - block_start`.
* @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
* @param [in] block_start Workgroup offset.
* @param [in] M Gemm's M dimension.
* @param [in] N Gemm's N dimension.
* @return Returns a `tuple` [Im, In] with shifted index.
*/
[[nodiscard]] CK_TILE_DEVICE static constexpr auto GetOffsetedTileIndex(index_t block_start,
index_t N) noexcept
[[nodiscard]] CK_TILE_DEVICE static auto
GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept
-> const tuple<index_t, index_t>
{
const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start);
const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
return make_tuple(iM, iN);
}
};
/**
* @brief Class mapping 1D block index into 2D output tile space.
*
* @note It groups spatially workgroups in order to better utilize caches.
* It is using grouped Rows of column-vectors WGP pattern. It's optimized
* for gfx94x-like multiple-die chip.
*
* @tparam GroupNum - The number of big groups.
* @tparam M01 - The number of groups in M dim within spatially local WGPs,
*
*/
template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
struct GemmSpatiallyLocalTilePartitioner
{
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept = delete;
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner(index_t M_, index_t N_) noexcept
: M(M_), N(N_)
{
}
/**
* @brief Calculates GEMM kernel grid size.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return index_t A total number of workgroups.
*/
CK_TILE_HOST static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
{
const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
return GridDimX * GridDimY;
}
/**
* @brief Calculate number of loop iterations over GEMM's K dimension.
*
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/
CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
{
return integer_divide_ceil(K, KPerBlock);
}
/**
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
*
* @param [in] block_1d_id WGP's index.
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept
-> const tuple<index_t, index_t>
{
const auto M0 = integer_divide_ceil(M, MPerBlock);
const auto N0 = integer_divide_ceil(N, NPerBlock);
if(M0 == 1)
{
return make_tuple(0, block_1d_id);
}
else if(N0 == 1)
{
return make_tuple(block_1d_id, 0);
}
// block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
else
{
const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
const auto group_id_y = block_1d_id / GroupNum;
const auto group_id_x = block_1d_id - group_id_y * GroupNum;
const auto remap_block_1d_id =
group_id_x <= big_group_num
? group_id_x * group_size + group_id_y
: group_id_x * group_size + big_group_num - group_id_x + group_id_y;
const index_t idx_M0 = remap_block_1d_id / N0;
const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;
const index_t M0_tmp = M0 / M01;
const index_t M0_mod_M01 = M0 - M0_tmp * M01;
const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
const index_t idx_M00 = idx_M0 / M01;
const index_t idx_M01 = idx_M0 - idx_M00 * M01;
const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
const index_t N_out = idx_N0_M01_local / M01_adapt;
const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
}
}
private:
index_t M;
index_t N;
};
} // namespace ck_tile
......@@ -50,7 +50,6 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using GemmKernelArgs = typename Base::GemmKernelArgs;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t KBatch = 1;
struct GemmTransKernelArg
{
......@@ -65,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
}
};
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
__host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::size_t
{
......@@ -78,8 +89,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
{
const auto dim3 = TilePartitioner::GridSize(it_desc.M, it_desc.N);
grid_size += dim3.x * dim3.y * 1;
const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
grid_size += local_grid_size * it_desc.k_batch;
}
return dim3(grid_size, 1, 1);
}
......@@ -107,8 +118,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const index_t stride_b = gemm_descs[i].stride_B;
const index_t stride_c = gemm_descs[i].stride_C;
const auto dim3 = TilePartitioner::GridSize(M, N);
const index_t grid_size_grp = dim3.x;
const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
const index_t block_start = grid_size;
const index_t block_end = grid_size + grid_size_grp;
......@@ -124,7 +134,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
stride_a,
stride_b,
stride_c,
KBatch};
gemm_descs[i].k_batch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
......@@ -139,8 +149,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
{
const auto [iM, iN] =
OffsetTile1DPartitioner::GetOffsetedTileIndex(kargs.block_start, kargs.group_karg.N);
const auto [iM, iN] = OffsetTile1DPartitioner::GetOffsetedTileIndex(
kargs.block_start, kargs.group_karg.M, kargs.group_karg.N);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
......@@ -12,18 +13,23 @@ struct GemmPipelineAgBgCrImplBase
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
move_tile_window(dram_tile_window, dram_tile_window_step);
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
......@@ -35,20 +41,26 @@ struct GemmPipelineAgBgCrImplBase
store_tile(lds_tile_window, block_tile_tmp);
}
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
const SrcTileWindow& lds_tile_window) const
{
load_tile(dst_block_tile, lds_tile_window);
}
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
......@@ -60,19 +72,21 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view) const
{
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
make_tuple(YPerTile{}, XPerTile{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block_view,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
auto a_copy_lds_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto a_lds_gemm_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
......@@ -86,18 +100,22 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view) const
{
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
make_tuple(YPerTile{}, XPerTile{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// TODO: Do we really need those two tile windows???
// They're exactly same...
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
auto b_copy_lds_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_lds_gemm_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
......@@ -20,6 +24,8 @@ struct BaseGemmPipelineAgBgCrCompV3
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
......@@ -37,7 +43,7 @@ struct BaseGemmPipelineAgBgCrCompV3
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
......@@ -62,27 +68,85 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA;
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
// Where is the right place for HasHotLoop and TailNum ???
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AgBgCrCompV3", BlockSize,
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
CK_TILE_HOST static std::string Print()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
auto str = std::stringstream{};
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
<< "\n"
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
......@@ -96,29 +160,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = KPerXDL;
constexpr index_t B_LDS_Read_Width = KPerXDL;
// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
......@@ -248,11 +318,22 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
......@@ -287,23 +368,51 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
ABlockTile a_block_tile;
BBlockTile b_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
......@@ -318,11 +427,31 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
block_sync_lds();
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV4
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
return TailNumber::Three;
}
else
{
return TailNumber::Two;
}
}
};
/**
* @brief Compute optimized pipeline version 4
*
* This version introduces a dual LDS window mechanism using a ping-pong buffer approach
* for more efficient data handling from global memory. Unlike compute version 3, this method
* allows one LDS to fetch data from global memory while the other LDS executes warps for MFMA
* matrix multiplication. This dual operation helps in keeping the Warp unit continuously busy,
* thereby significantly reducing memory load times and enhancing overall performance.
*
* @note This version shows improved performance over Compute Version 3 with the same block tile.
* It is particularly more efficient for large matrices where M, N, and K are greater than 8K,
* even when Compute Version 3's block size is twice that of Compute Version 4.
*/
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV4<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = KPerXDL;
constexpr index_t B_LDS_Read_Width = KPerXDL;
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b;
constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
constexpr auto num_issue = num_buffer_load_inst;
static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
__builtin_amdgcn_sched_group_barrier(
0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
__builtin_amdgcn_sched_group_barrier(
0x008, C_MFMA_Inst_Num / num_issue - 3, 0); // MFMA : 5
});
__builtin_amdgcn_sched_barrier(0);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"Data Type conflict on A and B matrix input data type.");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register /////////////////
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// A register tile for global load
constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr));
ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// global prefetch 0
// global read 0
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
////////////// LDS desc, window & register /////////////////
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
auto a_copy_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto a_copy_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_copy_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_copy_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
}
// global read 1
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0;
ALdsTile a_block_tile1;
BLdsTile b_block_tile0;
BLdsTile b_block_tile1;
auto a_lds_ld_window0 =
make_tile_window_linear(a_lds_block0,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsTileDistr);
auto a_lds_ld_window1 =
make_tile_window_linear(a_lds_block1,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsTileDistr);
auto b_lds_ld_window0 =
make_tile_window_linear(b_lds_block0,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsTileDistr);
auto b_lds_ld_window1 =
make_tile_window_linear(b_lds_block1,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsTileDistr);
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
if(HasHotLoop)
{
// minus 2 because we have ping-pong double buffer.
index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
do
{
// ping
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window0, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
// gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
// pong
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window1, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window1, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
// gemm
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
iCounter -= 2;
} while(iCounter > 1);
}
// tail 3
if(TailNum == TailNumber::Three)
{
// 3
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
}
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// 2
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
// 1
{
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
}
else
{
// 2
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
static_for<0, 8, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
});
__builtin_amdgcn_sched_barrier(0);
}
// 1
{
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
__builtin_amdgcn_sched_barrier(0);
}
}
return c_block_tile;
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
}
public:
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem_0,
p_smem_1);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
// Default policy for GemmPipelineAGmemBGmemCregComputeV4, except the block gemm method, it shares
// the same vector size implementation, SmemSize, Global memory tile distiribution as the
// UniversalGemm Pipeline Policy.
// Default policy class should not be templated, put template on
// member functions instead.
struct GemmPipelineAgBgCrCompV4DefaultPolicy
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV4DefaultPolicy>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / KPack>{}, number<kMPerBlock>{}, number<KPack>{}),
make_tuple(number<kMPerBlock * KPack>{}, number<KPack>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock>{} / KPack, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / KPack>{}, number<kNPerBlock>{}, number<KPack>{}),
make_tuple(number<(kNPerBlock)*KPack>{}, number<KPack>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -7,6 +7,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
......@@ -20,6 +21,8 @@ struct BaseGemmPipelineAgBgCrMem
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
......@@ -88,7 +91,7 @@ struct BaseGemmPipelineAgBgCrMem
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
......@@ -113,19 +116,31 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA;
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
// Where is the right place for HasHotLoop and TailNum ???
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AgBgCrMe",
concat('x', MPerBlock, NPerBlock, KPerBlock),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
using Base::PrefetchStages;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
......@@ -133,8 +148,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
......@@ -165,11 +178,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
......@@ -213,25 +237,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
Base::GlobalPrefetch(
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
}
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window,
b_dram_tile_window_step);
});
// main body
......@@ -247,19 +305,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds();
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(
a_shuffle_tmp,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(
b_shuffle_tmp,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
}
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window);
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window);
b_copy_dram_window,
b_dram_tile_window_step);
});
i += PrefetchStages;
......@@ -275,12 +359,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds();
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
}
});
block_sync_lds();
......@@ -352,11 +456,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
......@@ -400,25 +515,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
Base::GlobalPrefetch(
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
}
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window,
b_dram_tile_window_step);
});
// main body
......@@ -432,19 +580,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(
a_shuffle_tmp,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(
b_shuffle_tmp,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
}
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window);
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window);
b_copy_dram_window,
b_dram_tile_window_step);
});
i += PrefetchStages;
......@@ -457,12 +631,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
}
});
block_sync_lds();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <sstream>
#include "ck_tile/core.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
......@@ -31,32 +32,36 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA;
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
static constexpr index_t kLdsAlignmentInBytes = 16;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV1",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
static constexpr bool DoubleSmemBuffer = false;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
......@@ -86,8 +91,9 @@ struct GemmPipelineAGmemBGmemCRegV1
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
kLdsAlignmentInBytes) *
kLdsAlignmentInBytes;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
......@@ -150,7 +156,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegBlockDescriptor<Problem>());
Policy::template MakeShuffledARegBlockDistribution<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
......@@ -164,7 +170,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
Policy::template MakeShuffledBRegBlockDistribution<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
......@@ -201,7 +207,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
Policy::template MakeShuffledBRegBlockDistribution<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -16,39 +16,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr bool TransposeC = true;
#if 0
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{});
return a_lds_block_desc;
}
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
return b_lds_block_desc;
}
#elif 1
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
......@@ -58,7 +25,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// TODO: this 8 is AK1! should be a policy parameter!
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
......@@ -127,87 +93,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
return Problem::VectorLoadSize / sizeof(ADataType);
return Problem::VectorLoadSize;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType);
return Problem::VectorLoadSize;
}
#elif 1
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(ADataType);
constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
a_lds_block_desc_d1_d2_d3,
make_tuple(
make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
make_pass_through_transform(2)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
make_pass_through_transform(kKPerBlock)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc_m_k;
}
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(BDataType);
constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
b_lds_block_desc_d1_d2_d3,
make_tuple(
make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
make_pass_through_transform(2)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
make_pass_through_transform(kKPerBlock)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc_n_k;
}
#endif
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
......@@ -273,7 +166,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
......@@ -394,7 +286,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDistribution()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
......@@ -442,11 +334,11 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>);
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
......@@ -489,8 +381,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
......@@ -503,7 +393,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
TransposeC>;
Problem::TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
......@@ -25,6 +26,15 @@ struct GemmPipelineAGmemBGmemCRegV2
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV2",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize));
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(
......@@ -36,8 +46,6 @@ struct GemmPipelineAGmemBGmemCRegV2
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment