Unverified Commit cabbacb6 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #260 from ROCm/merge_from_public

Merge from public
parents 5e93fa9e f3ff55b6
...@@ -19,7 +19,8 @@ struct SmoothquantHostArgs ...@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t x_stride; // input row_stride
index_t y_stride; // output row_stride
}; };
// TODO: Extract some type to wrapper class // TODO: Extract some type to wrapper class
...@@ -58,14 +59,21 @@ struct Smoothquant ...@@ -58,14 +59,21 @@ struct Smoothquant
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t x_stride; // input row_stride
index_t y_stride; // out row_stride
}; };
using Hargs = SmoothquantHostArgs; using Hargs = SmoothquantHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{ {
return Kargs{ return Kargs{hargs.p_x,
hargs.p_x, hargs.p_xscale, hargs.p_yscale, hargs.p_qy, hargs.m, hargs.n, hargs.stride}; hargs.p_xscale,
hargs.p_yscale,
hargs.p_qy,
hargs.m,
hargs.n,
hargs.x_stride,
hargs.y_stride};
} }
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
...@@ -116,7 +124,7 @@ struct Smoothquant ...@@ -116,7 +124,7 @@ struct Smoothquant
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x), static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1), make_tuple(kargs.x_stride, 1),
number<Vector_N>{}, number<Vector_N>{},
number<1>{}); number<1>{});
...@@ -157,7 +165,7 @@ struct Smoothquant ...@@ -157,7 +165,7 @@ struct Smoothquant
auto tmp_ = make_naive_tensor_view<address_space_enum::global>( auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<QYDataType*>(kargs.p_qy), static_cast<QYDataType*>(kargs.p_qy),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1), make_tuple(kargs.y_stride, 1),
number<Vector_N>{}, number<Vector_N>{},
number<1>{}); number<1>{});
......
# reference
This folder contains reference implementations of specific ops. Note that by including one of these headers you are including the implementation (especially the GPU implementation) into your source code and compiling that kernel into the fatbin, which may increase your kernel object code size. Headers whose names start with `reference_` are usually CPU reference implementations. Headers whose names start with `naive_` contain a GPU implementation with a small launcher.
TODO: move `host/reference` under this folder
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <thread>
#include <string>
namespace ck_tile {
// Memory layouts understood by the naive attention reference kernel.
// In this header, `addresser` implements BSHD / BHSD and `page_addresser`
// implements PHSD / PHDSX / PHDS; BS3HD has no addressing code in this header.
// `x` in the paged layouts is the pack width used by `page_addresser`
// (16 / sizeof(T) elements).
enum class naive_attention_layout_enum
{
BSHD, // [batch, seqlen, nhead, hdim]
BHSD, // [batch, nhead, seqlen, hdim]
// NOTE(review): the enumerator name reads as [batch, seqlen, 3, nhead, hdim], which does not
// match the dimension order written below -- confirm which one is intended.
BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed
PHSD, // [pages, nhead, page_size, hdim]
// PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
PHDS, // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen
};
// Used to specialize the kernel variation at compile time (see
// naive_attention_fwd_kernel_traits). The numeric values matter: the host-side
// trait `naive_attention_fwd_traits::variation` is a plain int that is cast to this enum.
// Only FLASH_BATCHED (0) and DECODE_PAGED (2) are dispatched by naive_attention_fwd in this
// header; the kernel body has no FLASH_GROUPED branch.
enum class naive_attention_variation_enum
{
FLASH_BATCHED = 0, // standard flash attention, or xformer/sdpa, used for training
FLASH_GROUPED,
DECODE_PAGED, // decode attn, where kv token from another buffer called kvcache
};
// TODO: for simplicity, this will be used as host/device arg
//
// Argument pack for naive_attention_fwd / naive_attention_fwd_kernel. It is passed by value
// to the kernel, so every pointer must be dereferenceable on the device.
// Which fields are read depends on the kernel variation:
//   - FLASH_BATCHED reads seqlen_kv and ignores context_len_ptr / page_table_ptr / page_size.
//   - DECODE_PAGED reads the kv length from context_len_ptr[i_batch_q] and resolves kv tokens
//     through page_table_ptr / page_size.
// batch_ratio_kv and nhead_ratio_kv are used as integer divisors in the kernel and therefore
// must be non-zero.
struct naive_attention_fwd_args
{
void* q_ptr;
void* k_ptr;
void* v_ptr;
void* o_ptr;
void* context_len_ptr; // [batch] used when seqlen kv come from a pointer(each element is a
// number, not cumsum)
void* page_table_ptr; // [batch, max_pages_per_seq] seqlen_kv is in different block(paged attn)
void* kvscale_ptr; // [nhead, 2(kv), hdim] used for kvcache dequant
float scale_s; // softmax scale applied to the Q*K^T scores
int hdim;
int hdim_v; // could be cross-attn, where V and Q/K hdim are different
int batch_q;
int batch_kv;
int batch_ratio_kv; // batch_q / batch_kv
int seqlen_q; // in decode case, this should be 1
int seqlen_kv; // if context_len_ptr is not nullptr, ignore this field
int nhead_q;
int nhead_kv;
int nhead_ratio_kv; // nhead_q / nhead_kv
int page_size; // if paged, the seqlen-kv per each block
int max_pages_per_seq; // row stride (in entries) of page_table_ptr
};
// this is trait for host API
//
// Runtime (string-based) description of the kernel instance to launch; naive_attention_fwd
// matches these fields against a fixed list of supported combinations.
// Type strings matched in this header: "fp16", "bf16", "int8".
// Layout strings matched in this header: "bshd", "bhsd", "phdsx", "phds".
struct naive_attention_fwd_traits
{
std::string q_type;
std::string k_type;
std::string v_type;
std::string o_type;
std::string q_layout;
std::string k_layout;
std::string v_layout;
std::string o_layout;
int variation; // sync with naive_attention_variation_enum
};
// this is trait for kernel template
template <naive_attention_variation_enum variation_>
struct naive_attention_fwd_kernel_traits
{
static constexpr naive_attention_variation_enum variation = variation_;
};
// for simplicity, please do not use const-reference type for the template type
//
// Naive (reference) attention forward kernel.
//   QType/KType/VType/OType : element types of the Q, K, V and O buffers
//   AccType                 : accumulator type used by both gemms
//   QLayout..OLayout        : memory layout of each buffer (naive_attention_layout_enum)
//   Traits                  : must expose `static constexpr naive_attention_variation_enum
//                             variation` (see naive_attention_fwd_kernel_traits)
template <typename QType,
typename KType,
typename VType,
typename OType,
typename AccType,
naive_attention_layout_enum QLayout,
naive_attention_layout_enum KLayout,
naive_attention_layout_enum VLayout,
naive_attention_layout_enum OLayout,
typename Traits>
struct naive_attention_fwd_kernel
{
// true when both K and V are int8 while Q is wider than one byte
static constexpr bool is_kvcache_i8 =
std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t> && sizeof(QType) != 1;
// kvcache-i8 will have per head scale, we apply this scale to Q/P matrix instead of original
// K/V matrix. This can speed up conversion since Q/P usually is fp16/bf16/fp32
static constexpr bool is_kvcache_i8_forward_quant = is_kvcache_i8;
// TODO: hardcode
using KVScaleType = float; // element type read through kvscale_ptr
using SoftmaxType = float; // type used for the online-softmax statistics
using PType = VType; // src A of gemm2, same type as V
// 16-byte vector of PType, used to read P back from shared memory in gemm-2
using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>;
static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size;
// The kernel object is stateless; it is default-constructed on the host and passed to the
// launcher (see the dispatch macro that builds `k_{}`).
__host__ __device__ naive_attention_fwd_kernel() {}
// Addresser for densely packed 4-D tensors in BSHD or BHSD layout.
//
// Usage: construct with the tensor extents and the raw buffer pointer, call init() once to
// fold the batch / head offset into base_ptr, then use load() / store() with
// (seqlen index, hdim index).
//
// Offsets are computed in 64-bit: batch * seqlen * nhead * hdim easily exceeds the range of
// a 32-bit int for large problems. Any other layout is rejected at compile time instead of
// silently falling off the end of a non-void function.
template <typename T, naive_attention_layout_enum Layout>
struct addresser
{
    int b, s, h, d; // batch, seqlen, nhead, hdim
    T* base_ptr;

    __device__ addresser(int b_, int s_, int h_, int d_, void* base_ptr_)
        : b(b_), s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_))
    {
    }

    // TODO: all the batch/nhead offset will accumulate to the base pointer
    // Returns the pointer to element (i_b, seq 0, i_h, dim 0) relative to the current base_ptr.
    __device__ T* get_base(int i_b, int i_h)
    {
        static_assert(Layout == naive_attention_layout_enum::BSHD ||
                          Layout == naive_attention_layout_enum::BHSD,
                      "addresser only supports the BSHD and BHSD layouts");
        // both layouts share the same per-batch stride
        const int64_t batch_offset = static_cast<int64_t>(i_b) * s * h * d;
        if constexpr(Layout == naive_attention_layout_enum::BSHD)
            return base_ptr + batch_offset + static_cast<int64_t>(i_h) * d;
        else
            return base_ptr + batch_offset + static_cast<int64_t>(i_h) * s * d;
    }

    // Element offset of (i_s, i_d) relative to the pointer produced by get_base().
    __device__ int64_t get_offset(int i_s, int i_d)
    {
        static_assert(Layout == naive_attention_layout_enum::BSHD ||
                          Layout == naive_attention_layout_enum::BHSD,
                      "addresser only supports the BSHD and BHSD layouts");
        if constexpr(Layout == naive_attention_layout_enum::BSHD)
            return static_cast<int64_t>(i_s) * h * d + i_d; // a seqlen step skips all heads
        else
            return static_cast<int64_t>(i_s) * d + i_d;
    }

    // below set of API will directly use pointer inside this struct
    // init() rebases base_ptr, so it must be called exactly once per addresser instance.
    __device__ void init(int i_b, int i_h) { base_ptr = get_base(i_b, i_h); }
    __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
    __device__ void store(T value, int i_s, int i_d) { base_ptr[get_offset(i_s, i_d)] = value; }
};
// Addresser for a paged kv-cache (PHSD / PHDSX / PHDS layouts).
//
// A logical kv token index is split into (logical page, slot inside the page); the logical
// page is translated to a physical page through page_table_ptr. The head index is latched
// by init() and used by every later get_offset() / load().
template <typename T, naive_attention_layout_enum Layout>
struct page_addresser
{
    int s, h, d;                             // page_size, nhead, hdim
    static constexpr int x = 16 / sizeof(T); // elements per 16-byte pack (4 dwords)
    T* base_ptr;
    int* page_table_ptr; // TODO: page table always int
    int i_h;             // head selected by init(); not set by the constructor

    __device__ page_addresser(int s_, int h_, int d_, void* base_ptr_, void* pptr_)
        : s(s_),
          h(h_),
          d(d_),
          base_ptr(reinterpret_cast<T*>(base_ptr_)),
          page_table_ptr(reinterpret_cast<int*>(pptr_))
    {
    }

    // Physical page that holds kv token i_s (computed on every call: simple but slow).
    __device__ int64_t get_phy_page_idx(int i_s)
    {
        return static_cast<int64_t>(page_table_ptr[i_s / s]);
    }

    // Slot of kv token i_s inside its page (computed on every call: simple but slow).
    __device__ int get_phy_page_offset(int i_s) { return i_s % s; }

    // Element offset of (kv token i_s, hdim index i_d) for the latched head, from base_ptr.
    __device__ int64_t get_offset(int i_s, int i_d)
    {
        const int slot          = get_phy_page_offset(i_s);
        const int64_t page_base = get_phy_page_idx(i_s) * h * s * d;
        if constexpr(Layout == naive_attention_layout_enum::PHSD)
        {
            const int within_page = i_h * s * d + slot * d + i_d;
            return page_base + static_cast<int64_t>(within_page);
        }
        else if constexpr(Layout == naive_attention_layout_enum::PHDSX)
        {
            // hdim is split into packs of x elements; the page slot sits between pack and lane
            const int pack        = i_d / x;
            const int lane        = i_d % x;
            const int within_page = i_h * d * s + pack * s * x + slot * x + lane;
            return page_base + static_cast<int64_t>(within_page);
        }
        else if constexpr(Layout == naive_attention_layout_enum::PHDS)
        {
            const int within_page = i_h * d * s + i_d * s + slot;
            return page_base + static_cast<int64_t>(within_page);
        }
    }

    // below set of API will directly use pointer inside this struct
    // The batch index is unused here: the caller passes an already batch-offset page table.
    __device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
    __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
    // Intentionally a no-op: this addresser never writes.
    __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
};
// Read-only addresser for the kv dequant scale tensor, laid out as [nhead, 2, hdim]
// where the middle index selects the K scale (0) or the V scale (1).
template <typename T>
struct kvscale_addresser
{
    int h, d; // nhead, hdim
    T* base_ptr;

    __device__ kvscale_addresser(int h_, int d_, void* p_)
        : h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
    {
    }

    // Linear element offset of (head i_h, k-or-v selector i_kv, hdim index i_d).
    __device__ int get_offset(int i_h, int i_d, int i_kv /*0 or 1*/)
    {
        // [h, 2, d] -> row (i_h, i_kv) has d elements
        const int row = i_h * 2 + i_kv;
        return row * d + i_d;
    }

    __device__ T load(int i_h, int i_d, int i_kv)
    {
        const int offset = get_offset(i_h, i_d, i_kv);
        return base_ptr[offset];
    }
};
// Number of threads per workgroup. cross_wave_reduce hard-codes 4 waves of 64 lanes, i.e.
// the same 256 threads, so the two must be changed together.
__device__ __host__ static constexpr int get_block_size() { return 256; }
// Launch grid for the kernel.
// For simplicity one workgroup always handles a single q token and walks every kv token;
// it reads all of hdim from q and produces get_block_size() elements of hdim_v.
// 1) in prefill case, seqlen_q >= 1, seqlen_kv >= 1, batch_q=batch_kv
// 2) in decode case, seqlen_q = 1, batch_q is input num-tokens, batch_kv is 1
// 3) in paged-attn case, we still use 1 WG compute all the seqlen-kv for simplicity
// TODO: could support split-kv to validate intermediate logsum
//
// grid.x = ceil(hdim_v / block size), grid.y = seqlen_q, grid.z = batch_q * nhead_q
__host__ static dim3 get_grid_size(naive_attention_fwd_args args)
{
    constexpr int threads_per_wg = get_block_size();
    const int blocks_along_dv    = (args.hdim_v + threads_per_wg - 1) / threads_per_wg;
    return dim3(blocks_along_dv, args.seqlen_q, args.batch_q * args.nhead_q);
}
// Reduce one value per lane across a wave.
// In stage k every lane fetches the current value of the lane whose id differs in bit k
// (lane id XOR (1 << k)) through ds_bpermute and combines it with its own value using
// reduce_f. Six stages cover 1 << 6 = 64 lanes.
// The value travels as a 32-bit integer, so T is bit-cast to/from int32_t.
template <typename T, typename F>
__device__ constexpr T wave_reduce(T local, F reduce_f)
{
    constexpr int num_stages = 6; // 1<<6=64 lanes per wave
    T acc                    = local;
#pragma unroll
    for(int stage = 0; stage < num_stages; stage++)
    {
        const int partner_lane = __lane_id() ^ (1 << stage);
        // ds_bpermute takes a byte address, hence lane index * 4
        const int32_t partner_bits =
            __builtin_amdgcn_ds_bpermute(partner_lane << 2, bit_cast<int32_t>(acc));
        acc = reduce_f(acc, bit_cast<T>(partner_bits));
    }
    return acc;
}
// Combine per-wave results across the 4 waves of a workgroup through shared memory.
// Note: this function must be called after wave_reduce
// Note: better not use this under if...else... with thread divergence (syncthreads)
// `smem` must provide room for one T per thread of the workgroup; its previous content is
// overwritten.
template <typename T, typename F>
__device__ constexpr T cross_wave_reduce(T local, F reduce_f, T* smem)
{
    constexpr int num_waves      = 4;
    constexpr int lanes_per_wave = 64;
    const int lane               = threadIdx.x % lanes_per_wave;

    // first barrier: earlier readers of smem must be done before it is overwritten
    __syncthreads();
    smem[threadIdx.x] = local;
    // second barrier: every thread's value must be visible before reading
    __syncthreads();

    // the data within single wave is the same
    // but for simplicity, we still use data from each lane.
    T acc = smem[lane];
#pragma unroll
    for(int wave = 1; wave < num_waves; wave++)
    {
        acc = reduce_f(acc, smem[wave * lanes_per_wave + lane]);
    }
    return acc;
}
// kernel entry point
//
// Work split (see get_grid_size):
//   blockIdx.x * wg_size + threadIdx.x -> hdim_v element accumulated by this thread
//   blockIdx.y                         -> q token (i_sq)
//   blockIdx.z                         -> fused batch_q * nhead_q index
// The kv sequence is processed in chunks of wg_size tokens. Per chunk:
//   gemm-1  : thread t computes the score of kv token (chunk start + t) against the q row
//   softmax : online softmax (running max `row_max`, running sum `l`, rescale of o_acc)
//   gemm-2  : P of the chunk is staged in shared memory, every thread accumulates its own
//             hdim_v element over the chunk
//
// Shared memory layout (wg_size * 4 * sizeof(float) bytes):
//   first half  : scratch, reused by cross_wave_reduce and by the staged P values
//   second half : int8 quantized Q row (kvcache-i8 path only); it is read again in every
//                 kv chunk, so it must not share bytes with the scratch area
__device__ void operator()(naive_attention_fwd_args args)
{
    constexpr int wg_size = get_block_size();
    __shared__ char smem[wg_size * 4 * sizeof(float)]; // should enough
    // Keep the quantized Q away from the scratch area: cross_wave_reduce and the P staging
    // overwrite the beginning of smem in every loop iteration.
    char* smem_quant_q = smem + wg_size * 2 * sizeof(float);

    int i_dv    = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
    int i_sq    = blockIdx.y;                         // index of seqlen_q
    int i_batch = blockIdx.z;                         // index of batch_q * nhead_q
    int i_bq    = i_batch / args.nhead_q;             // index of batch_q
    int i_hq    = i_batch % args.nhead_q;             // index of nhead_q
    int i_bk    = i_bq / args.batch_ratio_kv;
    int i_hk    = i_hq / args.nhead_ratio_kv;

    // paged decode: row of the page table that belongs to this q batch entry
    void* page_table_ptr = [&]() {
        if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
        {
            return reinterpret_cast<int*>(args.page_table_ptr) + i_bq * args.max_pages_per_seq;
        }
        else
        {
            return nullptr;
        }
    }();

    auto q_addr = [&]() {
        if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
        {
            return addresser<QType, QLayout>{
                args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
        }
        else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
        {
            return addresser<QType, QLayout>{
                args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
        }
    }();
    auto k_addr = [&]() {
        if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
        {
            return addresser<KType, KLayout>{
                args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim, args.k_ptr};
        }
        else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
        {
            return page_addresser<KType, KLayout>{
                args.page_size, args.nhead_kv, args.hdim, args.k_ptr, page_table_ptr};
        }
    }();
    auto v_addr = [&]() {
        if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
        {
            return addresser<VType, VLayout>{
                args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim_v, args.v_ptr};
        }
        else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
        {
            return page_addresser<VType, VLayout>{
                args.page_size, args.nhead_kv, args.hdim_v, args.v_ptr, page_table_ptr};
        }
    }();
    auto o_addr = [&]() {
        if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
        {
            return addresser<OType, OLayout>{
                args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
        }
        else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
        {
            return addresser<OType, OLayout>{
                args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
        }
    }();

    q_addr.init(i_bq, i_hq);
    k_addr.init(i_bk, i_hk);
    v_addr.init(i_bk, i_hk);
    o_addr.init(i_bq, i_hq);

    auto f_max        = [](auto x_, auto y_) { return max(x_, y_); };
    auto f_sum        = [](auto x_, auto y_) { return x_ + y_; };
    auto f_absmax_f32 = [](float v_0_, float v_1_) {
        float rtn;
        asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
        return rtn;
    };

    int seqlen_kv = [&]() {
        if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
        {
            return args.seqlen_kv;
        }
        else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
        {
            return reinterpret_cast<int*>(args.context_len_ptr)[i_bq];
        }
    }();

    SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
    SoftmaxType l{0};
    AccType o_acc = {0};

    int sk_loops   = (seqlen_kv + wg_size - 1) / wg_size;
    float qf_scale = .0f;
    kvscale_addresser<KVScaleType> kvscale_addr{args.nhead_kv, args.hdim, args.kvscale_ptr};

    if constexpr(is_kvcache_i8_forward_quant)
    {
        // AccType is i32 now, seqlen_q = 1, hdim up to 256
        // (one thread per hdim element, so hdim must not exceed wg_size)
        float q   = 0;
        float k_s = 0;
        if(static_cast<int>(threadIdx.x) < args.hdim)
        {
            // same row the non-quantized path reads; equals row 0 when seqlen_q == 1
            q   = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
            k_s = type_convert<float>(kvscale_addr.load(i_hk, threadIdx.x, 0));
        }
        // 1) we apply the k scale to q
        float q_forwarded = q * k_s;

        // 2) apply smooth-quant
        // find absmax
        float qf_max = wave_reduce(q_forwarded, f_absmax_f32);
        qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<float*>(smem));

        // per-token scale
        qf_scale = qf_max / 127.0;

        // divide by scale. The value being quantized must be the k-scaled q: its absmax is
        // what qf_scale was derived from, so the result is guaranteed to fit into int8 and
        // the k scale is actually part of the product. An all-zero row gives qf_scale == 0;
        // quantize it to 0 instead of dividing by zero.
        float q_scaled = qf_scale == 0.f ? 0.f : q_forwarded / qf_scale;

        // fp32->i8
        int8_t quantized_q = static_cast<int8_t>(q_scaled);
        __syncthreads();
        reinterpret_cast<int8_t*>(smem_quant_q)[threadIdx.x] = quantized_q;
        __syncthreads();

        // after above process, we have 2 data
        // 1) int8 q data stored in smem_quant_q (no need to reload)
        // 2) per-token scale qf_scale, to be mul after 1st gemm
    }

    for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
    {
        int i_sk = i_loop1 * wg_size + threadIdx.x;
        // gemm-1
        // threads past the end of the kv sequence keep -inf so they vanish in the softmax
        SoftmaxType s_softmax = -numeric<SoftmaxType>::infinity();
        if(i_sk < seqlen_kv)
        {
            AccType s_acc{0}; // clear for every loop
            for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
            {
                if constexpr(is_kvcache_i8_forward_quant)
                {
                    int8_t q = reinterpret_cast<int8_t*>(smem_quant_q)[i_dq];
                    auto k   = k_addr.load(i_sk, i_dq);

                    s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
                }
                else
                {
                    auto q = q_addr.load(i_sq, i_dq); // q will have duplicate load
                    auto k = k_addr.load(i_sk, i_dq);

                    s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
                }
            }
            // scale (log2e is folded in because the softmax below uses exp2)
            s_softmax = type_convert<SoftmaxType>(s_acc);
            s_softmax *=
                type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
            if constexpr(is_kvcache_i8_forward_quant)
            {
                s_softmax *= qf_scale; // post scale the per-token factor
            }
        }

        // s->p
        float pf_scale = 0.; // used for i8 quant
        {
            // softmax, find max
            SoftmaxType old_max = row_max;
            SoftmaxType cur_max = wave_reduce(s_softmax, f_max);

            cur_max = cross_wave_reduce(cur_max, f_max, reinterpret_cast<SoftmaxType*>(smem));
            row_max = max(old_max, cur_max); // update row_max
            // softmax, exp(i_elem - max)
            SoftmaxType p_compute = __builtin_amdgcn_exp2f(s_softmax - row_max);

            // compute exp_sum
            SoftmaxType row_sum = wave_reduce(p_compute, f_sum);
            row_sum = cross_wave_reduce(row_sum, f_sum, reinterpret_cast<SoftmaxType*>(smem));

            // l, pre-scale o_acc
            SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
            l               = tmp * l + row_sum;
            o_acc           = type_convert<AccType>(type_convert<SoftmaxType>(o_acc) * tmp);

            // prepare the p_compute into smem, to let every thread read same p_compute and do
            // 2nd gemm
            if constexpr(is_kvcache_i8_forward_quant)
            {
                float v_s = 0;
                if(static_cast<int>(threadIdx.x) < args.hdim_v)
                {
                    v_s = type_convert<float>(kvscale_addr.load(i_hk, threadIdx.x, 1));
                }

                // NOTE(review): threadIdx.x indexes a kv token for p_compute but an hdim_v
                // element for v_s, so the product below mixes two different axes, and the
                // value that gets quantized is p_compute rather than p_forwarded. Left as in
                // the original because the intended math cannot be derived from this file
                // alone -- confirm against the kernel this reference is meant to validate.
                // 1) we apply the v scale to p
                float p_forwarded = p_compute * v_s;

                // 2) apply smooth-quant
                // find absmax
                float pf_max = wave_reduce(p_forwarded, f_absmax_f32);
                pf_max =
                    cross_wave_reduce(pf_max, f_absmax_f32, reinterpret_cast<float*>(smem));

                // per-token scale
                pf_scale = pf_max / 127.0;

                // divide by scale (a zero scale quantizes to 0 instead of dividing by zero)
                p_compute = pf_scale == 0.f ? 0.f : p_compute / pf_scale;

                // fp32->i8
                int8_t quantized_p = static_cast<int8_t>(p_compute);
                __syncthreads();
                reinterpret_cast<int8_t*>(smem)[threadIdx.x] = quantized_p;
                __syncthreads();
                // after above process, we have 2 data
                // 1) int8 p data stored in smem(no need to reload)
                // 2) per-token scale pf_scale, to be mul after 2nd gemm
            }
            else
            {
                __syncthreads();
                reinterpret_cast<PType*>(smem)[threadIdx.x] = type_convert<PType>(p_compute);
                __syncthreads();
            }
        }

        // gemm-2, simple loop over vector by vector
        constexpr int gemm_2_loop = wg_size / p_vec_elem;
        {
            AccType o_acc_local = {0};
            int sk_start = i_loop1 * wg_size; // we start from the first seqlen_kv element
            for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
            {
                p_vec_type p_vec = reinterpret_cast<p_vec_type*>(smem)[i_loop2];
#pragma unroll
                for(int i_j = 0; i_j < p_vec_elem; i_j++)
                {
                    int sv_offset = i_loop2 * p_vec_elem + i_j;
                    int i_sv      = sk_start + sv_offset;

                    // out-of-range threads / tokens contribute a zero
                    VType v = 0.f;
                    if(i_dv < args.hdim_v && i_sv < seqlen_kv)
                    {
                        v = v_addr.load(i_sv, i_dv);
                    }

                    o_acc_local += type_convert<AccType>(p_vec[i_j]) * type_convert<AccType>(v);
                }
            }
            if constexpr(is_kvcache_i8_forward_quant)
            {
                // apply pr scale to local acc
                o_acc_local =
                    type_convert<AccType>(type_convert<float>(o_acc_local) * pf_scale);
            }
            o_acc += o_acc_local;
        }
    }

    // post scale o_acc
    {
        SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking
        o_acc           = type_convert<AccType>(type_convert<SoftmaxType>(o_acc) * tmp);
    }

    // store O
    if(i_dv < args.hdim_v)
        o_addr.store(type_convert<OType>(o_acc), i_sq, i_dv);
}
};
// Instantiates and launches one kernel configuration.
// Expects these names to be in scope at the expansion site:
//   types    : q_type_, k_type_, v_type_, o_type_, acc_type_
//   constants: q_layout_, k_layout_, v_layout_, o_layout_, variation_ (int, cast to the enum)
//   objects  : a (naive_attention_fwd_args), s (ck_tile::stream_config), r (float, receives
//              the value returned by ck_tile::launch_kernel)
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_() \
{ \
using ktraits_ = \
naive_attention_fwd_kernel_traits<static_cast<naive_attention_variation_enum>( \
variation_)>; \
using k_ = naive_attention_fwd_kernel<q_type_, \
k_type_, \
v_type_, \
o_type_, \
acc_type_, \
q_layout_, \
k_layout_, \
v_layout_, \
o_layout_, \
ktraits_>; \
dim3 grids = k_::get_grid_size(a); \
r = ck_tile::launch_kernel(s, \
ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a)); \
}
// Maps the runtime (variation, layout strings) of `t` (naive_attention_fwd_traits) onto
// compile-time layout / variation constants and launches the matching kernel.
// Supported combinations:
//   variation 0 : q/k/v/o all "bshd"
//   variation 0 : q/k/v/o all "bhsd"
//   variation 2 : q "bhsd", k "phdsx", v "phds", o "bhsd"
// Any other combination expands to nothing being launched, leaving `r` untouched.
// The type aliases required by CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_ must already be
// in scope at the expansion site.
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_() \
if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
t.o_layout == "bshd") \
{ \
constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \
constexpr int variation_ = 0; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
} \
else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" && \
t.v_layout == "bhsd" && t.o_layout == "bhsd") \
{ \
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
constexpr int variation_ = 0; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
} \
else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" && \
t.v_layout == "phds" && t.o_layout == "bhsd") \
{ \
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \
constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
constexpr int variation_ = 2; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
}
// Host entry point: selects a kernel instance from the runtime traits and launches it.
//
//   t : element types / layouts / variation of the requested instance
//   a : kernel arguments, forwarded to the kernel by value
//   s : stream configuration handed to ck_tile::launch_kernel
//
// Returns whatever ck_tile::launch_kernel returns for the launched instance, or -1 when no
// supported type / layout / variation combination matches `t` (nothing is launched then).
//
// Supported type combinations (q, k, v, o -> accumulator):
//   fp16, fp16, fp16, fp16 -> float
//   bf16, bf16, bf16, bf16 -> float
//   bf16, int8, int8, bf16 -> int32_t
//   fp16, int8, int8, fp16 -> int32_t
// The local alias names below (q_type_, ..., acc_type_) and the parameter names t / a / s
// and the local r are referenced by the dispatch macros and must not be renamed.
CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
naive_attention_fwd_args a,
ck_tile::stream_config s)
{
// stays -1 unless one of the dispatch branches launches a kernel
float r = -1;
// TODO: do not explicitly create too much instance!
if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16")
{
using q_type_ = fp16_t;
using k_type_ = fp16_t;
using v_type_ = fp16_t;
using o_type_ = fp16_t;
using acc_type_ = float;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16")
{
using q_type_ = bf16_t;
using k_type_ = bf16_t;
using v_type_ = bf16_t;
using o_type_ = bf16_t;
using acc_type_ = float;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16")
{
using q_type_ = bf16_t;
using k_type_ = int8_t;
using v_type_ = int8_t;
using o_type_ = bf16_t;
using acc_type_ = int32_t; // NOTE! integer accumulator selects the kvcache-i8 kernel path
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "fp16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "fp16")
{
using q_type_ = fp16_t;
using k_type_ = int8_t;
using v_type_ = int8_t;
using o_type_ = fp16_t;
using acc_type_ = int32_t; // NOTE! integer accumulator selects the kvcache-i8 kernel path
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
return r;
}
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_
} // namespace ck_tile
...@@ -7,6 +7,7 @@ import copy ...@@ -7,6 +7,7 @@ import copy
NS = 'ck_tile' NS = 'ck_tile'
OPS = 'ops' OPS = 'ops'
REF = 'ref'
OPS_COMMON = 'common' # common header will be duplicated into ops/* other module OPS_COMMON = 'common' # common header will be duplicated into ops/* other module
HEADER_COMMON = f"""// SPDX-License-Identifier: MIT HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
...@@ -29,6 +30,9 @@ class submodule_t: ...@@ -29,6 +30,9 @@ class submodule_t:
def push(self, f): def push(self, f):
if len(f.parents) != 1: # ignore ./xxx.hpp if len(f.parents) != 1: # ignore ./xxx.hpp
mod = get_module(f) mod = get_module(f)
# ref is supposed to include one header on demand
if mod == REF:
return
if mod == OPS: if mod == OPS:
if mod not in self.m.keys(): if mod not in self.m.keys():
self.m[mod] = dict() self.m[mod] = dict()
......
...@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = ...@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 64, 8, 8, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
......
...@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std ...@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef __gfx94__ #ifdef __gfx94__
// Compute friendly // Compute friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
...@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std: ...@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std:
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version| //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
......
...@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
int StrideB, int StrideB,
int StrideC, int StrideC,
int BatchCount, int BatchCount,
int KBatch,
int n_warmup, int n_warmup,
int n_iter, int n_iter,
uint64_t rotating = 0) uint64_t rotating = 0)
...@@ -147,89 +148,100 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -147,89 +148,100 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device op instances // profile device op instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
std::unique_ptr<tensor_operation::device::BaseArgument> argument_ptr; std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
// false branch for multi d dl kernel
argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{},
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
BatchCount,
StrideA,
StrideB,
{},
StrideC,
BatchStrideA,
BatchStrideB,
{},
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run( if(KBatch > 0)
argument_ptr.get(), {
StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count}); kbatch_list = {KBatch};
}
std::size_t flop = std::size_t(2) * BatchCount * M * N * K; for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{},
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
BatchCount,
StrideA,
StrideB,
{},
StrideC,
BatchStrideA,
BatchStrideB,
{},
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
kbatch_curr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string op_name = op_ptr->GetTypeString();
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + float ave_time = invoker_ptr->Run(
sizeof(CDataType) * M * N) * argument_ptr.get(),
BatchCount; StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
float gb_per_sec = num_btype / 1.E6 / ave_time; std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N) *
BatchCount;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
<< " GB/s, " << op_name << std::endl;
if(tflops > best_tflops) float gb_per_sec = num_btype / 1.E6 / ave_time;
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
{ << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl;
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
if(do_log) if(do_verification)
{ {
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
<< std::endl;
LogRangeAsType<float>( if(do_log)
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") {
<< std::endl; LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
<< std::endl;
}
} }
} }
} else
else {
{ std::cout << op_ptr->GetTypeString() << " does not support this problem"
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; << std::endl;
}
} }
} }
...@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K
<< " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC << " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC
<< ": " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " KBatch = " << best_kbatch << ": " << best_ave_time << " ms, " << best_tflops
<< " GB/s, " << best_op_name << std::endl; << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return pass;
} }
......
...@@ -144,6 +144,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -144,6 +144,7 @@ bool profile_gemm_universal_impl(int do_verification,
} }
std::string best_op_name; std::string best_op_name;
std::optional<std::string> best_op_object_name;
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
...@@ -225,7 +226,8 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -225,7 +226,8 @@ bool profile_gemm_universal_impl(int do_verification,
} }
} }
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
std::optional<std::string> op_obj_name = op_ptr->GetObjectName();
float ave_time = invoker_ptr->Run(argument_ptr.get(), float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, StreamConfig{nullptr,
...@@ -251,11 +253,12 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -251,11 +253,12 @@ bool profile_gemm_universal_impl(int do_verification,
if(tflops > best_tflops && ave_time > 1e-10) if(tflops > best_tflops && ave_time > 1e-10)
{ {
best_op_name = op_name; best_op_name = op_name;
best_tflops = tflops; best_op_object_name = op_obj_name;
best_ave_time = ave_time; best_tflops = tflops;
best_gb_per_sec = gb_per_sec; best_ave_time = ave_time;
best_kbatch = kbatch_curr; best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
} }
} }
else else
...@@ -306,6 +309,9 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -306,6 +309,9 @@ bool profile_gemm_universal_impl(int do_verification,
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl; << " GB/s, " << best_op_name << std::endl;
if(best_op_object_name)
std::cout << best_op_object_name.value() << std::endl;
return pass; return pass;
} }
......
...@@ -31,7 +31,7 @@ enum struct GemmDataType ...@@ -31,7 +31,7 @@ enum struct GemmDataType
int profile_batched_gemm_universal(int argc, char* argv[]) int profile_batched_gemm_universal(int argc, char* argv[])
{ {
if(argc != 18 && argc != 21) if(argc != 19 && argc != 22)
{ {
// clang-format off // clang-format off
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
...@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg7: time kernel (0=n0, 1=yes)\n");
printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n"); printf("arg8 to 18: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount, KBatch\n");
printf("optional:\n"); printf("optional:\n");
printf("arg18: number of warm-up cycles (default 1)\n"); printf("arg19: number of warm-up cycles (default 1)\n");
printf("arg19: number of iterations (default 10)\n"); printf("arg20: number of iterations (default 10)\n");
printf("arg20: memory for rotating buffer (default 0, size in MB)\n"); printf("arg21: memory for rotating buffer (default 0, size in MB)\n");
// clang-format on // clang-format on
exit(1); exit(1);
} }
...@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
int n_warmup = 1; int n_warmup = 1;
int n_iter = 10; int n_iter = 10;
uint64_t rotating = 0; uint64_t rotating = 0;
if(argc == 21) if(argc == 22)
{ {
n_warmup = std::stoi(argv[18]); n_warmup = std::stoi(argv[19]);
n_iter = std::stoi(argv[19]); n_iter = std::stoi(argv[20]);
rotating = std::stoull(argv[20]) * 1024 * 1024; rotating = std::stoull(argv[21]) * 1024 * 1024;
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
...@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
const int BatchStrideC = std::stoi(argv[16]); const int BatchStrideC = std::stoi(argv[16]);
const int BatchCount = std::stoi(argv[17]); const int BatchCount = std::stoi(argv[17]);
const int KBatch = std::stoi(argv[18]);
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t; using F8 = ck::f8_t;
...@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
StrideB_, StrideB_,
StrideC_, StrideC_,
BatchCount, BatchCount,
KBatch,
n_warmup, n_warmup,
n_iter, n_iter,
rotating); rotating);
......
...@@ -332,7 +332,7 @@ def main(): ...@@ -332,7 +332,7 @@ def main():
table_name="ck_fmha_bwd_tflops" table_name="ck_fmha_bwd_tflops"
tflops_base = get_baseline(table_name,conn) tflops_base = get_baseline(table_name,conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, conn) store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
conn.close() conn.close()
#compare the results to the baseline if baseline exists #compare the results to the baseline if baseline exists
......
# Currently ck_tile is only built on gfx9 # Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9") if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_gemm_mem_pipeline test_gemm_mem_pipeline.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp)
endif() endif()
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "test_gemm_mem_pipeline_util.hpp" #include "test_gemm_pipeline_util.hpp"
using F16 = ck_tile::half_t; using F16 = ck_tile::half_t;
using F32 = float; using F32 = float;
...@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, ...@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>; ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>; ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave> std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>; >;
// clang-format on // clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes); TYPED_TEST_SUITE(TestCkTileGemmPipeline, KernelTypes);
#include "test_gemm_mem_pipeline_ut_cases.inc" #include "test_gemm_pipeline_ut_cases.inc"
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#pragma once #pragma once
TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) TYPED_TEST(TestCkTileGemmPipeline, SmallM)
{ {
std::vector<int> Ms{1, 2, 3, 4, 5, 6}; std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) ...@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) ...@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) TYPED_TEST(TestCkTileGemmPipeline, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{127};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) ...@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, Regular) TYPED_TEST(TestCkTileGemmPipeline, Regular)
{ {
std::vector<int> Ms{512}; std::vector<int> Ms{512};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular) ...@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, NotSupportedArgument) TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument)
{ {
constexpr int M = 512; constexpr int M = 512;
constexpr int N = 1025; constexpr int N = 1025;
......
...@@ -11,18 +11,24 @@ ...@@ -11,18 +11,24 @@
#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm.hpp"
enum struct GemmPipelineType
{
Mem,
Comp
};
template <typename Tuple> template <typename Tuple>
class TestCkTileGemmMemPipeline : public ::testing::Test class TestCkTileGemmPipeline : public ::testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>; using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>; using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>; using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
// TODO: expose tile size through test t-param ? // TODO: expose tile size through test t-param ?
struct gemm_args struct gemm_args
...@@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = std::conditional_t<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; PipelineType == GemmPipelineType::Mem,
ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<
ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
...@@ -85,15 +96,26 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -85,15 +96,26 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< using GemmPipeline =
ck_tile::UniversalGemmPipelineProblem<ADataType, std::conditional_t<PipelineType == GemmPipelineType::Mem,
BDataType, ck_tile::GemmPipelineAgBgCrMem<
AccDataType, ck_tile::UniversalGemmPipelineProblem<ADataType,
GemmShape, BDataType,
Traits, AccDataType,
Scheduler, GemmShape,
has_hot_loop_v, Traits,
tail_number_v>>; Scheduler,
has_hot_loop_v,
tail_number_v>>,
ck_tile::GemmPipelineAgBgCrCompV3<
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
Traits,
Scheduler,
has_hot_loop_v,
tail_number_v>>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b, args.p_b,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment