".github/workflows/release-docker-deepep.yml" did not exist on "18317ddc13bc403749fe9f99ef5726796f855b0e"
Commit 67ab3896 authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

Merge branch 'develop' into gemm_getname

parents 8adaf418 d5c8a334
......@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
......@@ -36,6 +37,8 @@ struct Layernorm2dFwdPipelineTwoPass
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
......@@ -53,6 +56,7 @@ struct Layernorm2dFwdPipelineTwoPass
template <typename XWindow,
typename XResidualWindow,
typename XBiasWindow,
typename GammaWindow,
typename BetaWindow,
typename YWindow,
......@@ -64,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_,
const BetaWindow& beta_window_,
YWindow& y_window,
......@@ -77,8 +82,11 @@ struct Layernorm2dFwdPipelineTwoPass
void* smem,
Epilogue) const
{
static_assert(kWelford == true, "2 pass only supports welford merge");
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x_bias_window = make_tile_window(
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto beta_window = make_tile_window(
......@@ -102,24 +110,35 @@ struct Layernorm2dFwdPipelineTwoPass
int max_count =
(num_n_tile_iteration - 1) * count_per_iter +
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
auto block_welford = Policy::template GetBlockWelford<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
auto block_welford_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
auto block_norm_reduce_cross_warp_sync =
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
const auto x_bias = load_tile(x_bias_window);
move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N});
move_tile_window(x_bias_window, {Block_N});
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
......@@ -133,11 +152,11 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(y_residual_window, {0, Block_N});
}
}
block_welford(acc, mean, var, cur_count, max_count);
block_norm_reduce(acc, mean, var, cur_count, max_count);
}
block_welford_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem);
block_norm_reduce_sync(mean, var, cur_count);
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
// compute inv-std
......@@ -165,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window});
......@@ -172,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x);
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
const auto x_bias = load_tile(x_bias_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
......@@ -207,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N});
......
......@@ -7,6 +7,19 @@
namespace ck_tile {
enum class Layernorm2dXBiasEnum
{
NO_BIAS = 0,
// add bias before fused add
ADD_BIAS = 1,
};
// clang-format off
template<Layernorm2dXBiasEnum> struct Layernorm2dXBiasEnumName;
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::NO_BIAS> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::ADD_BIAS> { static constexpr const char * name = "xbias"; };
// clang-format on
enum class Layernorm2dFusedAddEnum
{
NO_ADD = 0,
......@@ -40,7 +53,9 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
template <bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kWelford_,
bool kTwoPass_,
Layernorm2dXBiasEnum kXbias_,
Layernorm2dFusedAddEnum kFusedAdd_,
Layernorm2dFusedQuantEnum kFusedQuant_>
struct Layernorm2dFwdTraits
......@@ -48,7 +63,9 @@ struct Layernorm2dFwdTraits
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Layernorm2dXBiasEnum kXbias = kXbias_;
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
};
......
......@@ -3,9 +3,8 @@
#pragma once
#include "ck_tile/ops/welford/block/block_welford.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -4,22 +4,23 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = void>
struct BlockWelford
struct BlockNormReduce
{
using Problem = remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using ComputeDataType = typename Problem::ComputeDataType;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
CK_TILE_DEVICE constexpr BlockWelford() {}
CK_TILE_DEVICE constexpr BlockNormReduce() {}
// [CAUSION] - max_count_ is to deal with the padding problem
// max_count_ is depend on caller, eg: naive and splitN welford will have different
// max_count_ is depend on caller, eg: naive and splitN norm_reduce will have different
// calculation of max_count_
// -> use block_welford_calculate_max_count to compute
template <typename XDistributedTensor_,
......@@ -40,18 +41,24 @@ struct BlockWelford
if(cur_count_ < max_count_)
{
++cur_count_;
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
welford_update(mean_tensor(out_dstr_idx),
var_tensor(out_dstr_idx),
x,
cur_count_,
constant<kFastFDiv>{});
if(kWelford)
{
welford_update(mean_tensor(out_dstr_idx),
var_tensor(out_dstr_idx),
x,
cur_count_,
constant<kFastFDiv>{});
}
else
{
mean_tensor(out_dstr_idx) += x;
var_tensor(out_dstr_idx) += x * x;
}
});
}
});
......@@ -91,10 +98,11 @@ struct BlockWelford
};
template <typename Problem_, typename Policy_ = void>
struct BlockWelfordSync
struct BlockNormReduceSync
{
using Problem = remove_cvref_t<Problem_>;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void
......@@ -152,36 +160,48 @@ struct BlockWelfordSync
(number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
// welford merge
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
if(kWelford)
{
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
// norm_reduce merge
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
}
else
{
v_local_mean += v_remote_mean;
v_local_var += v_remote_var;
}
});
}
});
mean_tensor.get_thread_buffer()(i) = v_local_mean;
var_tensor.get_thread_buffer()(i) = v_local_var;
count = v_local_count;
if(kWelford)
{
count = v_local_count;
}
});
}
};
template <typename Problem_, typename Policy_ = void>
struct BlockWelfordCrossWarpSync
struct BlockNormReduceCrossWarpSync
{
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
static constexpr bool kWelford = Problem::kWelford;
using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
template <typename MeanDistributedTensor_>
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
......@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
// Note: we always pack everything into fp32x4
fp32x4_t* smem_ptr = reinterpret_cast<fp32x4_t*>(smem);
smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
......@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync
if(lane_id == 0)
{
static_for<0, thread_buf_size, 1>{}([&](auto i) {
fp32x4_t local_scratch_;
smem_dtype local_scratch_;
local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
local_scratch_[2] = bit_cast<float>(count);
if(kWelford)
{
local_scratch_[2] = bit_cast<float>(count);
}
smem_ptr[smem_offset + i * num_warps] = local_scratch_;
});
}
......@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync
// load from smem. here we let everythread to do compute :)
index_t local_warp_id = warp_id / num_reduce_warps;
index_t local_smem_os = local_warp_id * num_reduce_warps;
fp32x4_t all_scratch[thread_buf_size * num_reduce_warps];
smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
all_scratch[i_0 * num_reduce_warps + i_1] =
......@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
// TODO: use descriptor for this
auto v_local = all_scratch[i_0 * num_reduce_warps];
auto v_local_mean = bit_cast<DataType>(v_local[0]);
auto v_local_var = bit_cast<DataType>(v_local[1]);
auto v_local_count = bit_cast<int>(v_local[2]);
auto v_local = all_scratch[i_0 * num_reduce_warps];
auto v_local_mean = bit_cast<DataType>(v_local[0]);
auto v_local_var = bit_cast<DataType>(v_local[1]);
int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;
// further reduce mean/var
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{};
const fp32x4_t v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
const auto v_remote_count = bit_cast<int>(v_remote[2]);
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
if(kWelford)
{
const auto v_remote_count = bit_cast<int>(v_remote[2]);
welford_merge(v_local_mean,
v_local_var,
v_local_count,
v_remote_mean,
v_remote_var,
v_remote_count,
constant<kFastFDiv>{});
}
else
{
v_local_mean += v_remote_mean;
v_local_var += v_remote_var;
}
});
mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
var_tensor.get_thread_buffer()(i_0) = v_local_var;
count = v_local_count;
if(kWelford)
count = v_local_count;
});
}
};
......
......@@ -7,13 +7,18 @@
namespace ck_tile {
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_, bool kFastFDiv_>
struct BlockWelfordProblem
template <typename XDataType_,
typename ComputeDataType_,
typename BlockShape_,
bool kFastFDiv_,
bool kWelford_>
struct BlockNormReduceProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
};
} // namespace ck_tile
......@@ -13,13 +13,18 @@ namespace ck_tile {
enum class naive_attention_layout_enum
{
BSHD, // [batch, seqlen, nhead, hdim]
BHSD, // [batch, nhead, seqlen, hdim]
BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed
PHSD, // [pages, nhead, page_size, hdim]
DEFAULT, // maybe this tensor is not used, set some irrelevant value
BSHD, // [batch, seqlen, nhead, hdim]
BHSD, // [batch, nhead, seqlen, hdim]
BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed
PHSD, // [pages, nhead, page_size, hdim]
// PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
PHDS, // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen
// scale layout used for dynamic dequant
SCALE_HS, // [nhead, tokens] or [nhead, tokens-per-group], nhe KVCache quant
SCALE_SH, // [tokens, nhead]
};
// will used to specialize kernel variation
......@@ -30,6 +35,15 @@ enum class naive_attention_variation_enum
DECODE_PAGED, // decode attn, where kv token from another buffer called kvcache
};
enum class naive_attention_quant_algo
{
NO = 0,
KV_8BIT_PERHEAD = 1,
// FP8/INT8 quant for KVCache, per-token quant
// [num_tokens, nhead, hdim] -> [nhead, num_tokens]
KV_8BIT_PERTOKEN = 2,
};
// TODO: for simplicity, this will be used as host/device arg
struct naive_attention_fwd_args
{
......@@ -40,7 +54,8 @@ struct naive_attention_fwd_args
void* context_len_ptr; // [batch] used when seqlen kv come from a pointer(each element is a
// number, not cumsum)
void* page_table_ptr; // [batch, max_pages_per_seq] seqlen_kv is in different block(paged attn)
void* kvscale_ptr; // [nhead, 2(kv), hdim] used for kvcache dequant
void* kscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
void* vscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
float scale_s;
int hdim;
int hdim_v; // could be cross-attn, where V and Q/K hdim are different
......@@ -54,6 +69,7 @@ struct naive_attention_fwd_args
int nhead_ratio_kv; // nhead_q / nhead_kv
int page_size; // if paged, the seqlen-kv per each block
int max_pages_per_seq;
int max_kv_tokens; // used as stride to access kv scale ptr
};
// this is trait for host API
......@@ -67,14 +83,16 @@ struct naive_attention_fwd_traits
std::string k_layout;
std::string v_layout;
std::string o_layout;
int variation; // sync with naive_attention_variation_enum
int variation; // sync with naive_attention_variation_enum
int quant_algo; // sync with naive_attention_quant_algo
};
// this is trait for kernel template
template <naive_attention_variation_enum variation_>
template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
struct naive_attention_fwd_kernel_traits
{
static constexpr naive_attention_variation_enum variation = variation_;
static constexpr naive_attention_quant_algo quant_algo = quant_algo_;
};
// for simplicity, please do not use const-reference type for the template type
......@@ -83,28 +101,39 @@ template <typename QType,
typename VType,
typename OType,
typename AccType,
typename KVScaleType,
naive_attention_layout_enum QLayout,
naive_attention_layout_enum KLayout,
naive_attention_layout_enum VLayout,
naive_attention_layout_enum OLayout,
naive_attention_layout_enum KScaleLayout,
naive_attention_layout_enum VScaleLayout,
typename Traits>
struct naive_attention_fwd_kernel
{
static constexpr bool is_kvcache_i8 =
std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t> && sizeof(QType) != 1;
std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
static constexpr bool is_kvcache_fp8 =
std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
// kvcache-i8 will have per head scale, we apply this scale to Q/P matrix instead of original
// K/V matrix. This can speed up conversion since Q/P usually is fp16/bf16/fp32
static constexpr bool is_kvcache_i8_forward_quant = is_kvcache_i8;
static constexpr int v_per_token_quant_group_size = 64;
// TODO: hardcode
using KVScaleType = float;
using SoftmaxType = float;
using PType = VType; // src A of gemm2, same type as V
using SoftmaxType = float; // always using float to do softmax compute
using QuantComputeType = float; // used for quant/dequant scale compute
using QCompute = KType; // src A of gemm1, same type as K
using PType = VType; // src A of gemm2, same type as V
using OAccType = float; // always float, in case int8 FA
using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>;
static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size;
// clang-format off
template <typename T_> struct scale_max { static constexpr float value = 1; /* dummy code */ };
template <> struct scale_max<int8_t> { static constexpr float value = 127.0; };
template <> struct scale_max<fp8_t> { static constexpr float value = 240.0; };
// clang-format on
__host__ __device__ naive_attention_fwd_kernel() {}
template <typename T, naive_attention_layout_enum Layout>
......@@ -198,24 +227,31 @@ struct naive_attention_fwd_kernel
__device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
};
template <typename T>
template <typename T, naive_attention_layout_enum Layout>
struct kvscale_addresser
{
int h, d; // nhead, hdim
int s, h, d; // seqlen(tokens), nhead, hdim
T* base_ptr;
__device__ kvscale_addresser(int h_, int d_, void* p_)
: h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
__device__ kvscale_addresser(int s_, int h_, int d_, void* p_)
: s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
{
}
__device__ int get_offset(int i_h, int i_d, int i_kv /*0 or 1*/)
__device__ int get_offset(int i_s, int i_h, int i_d)
{
if constexpr(Layout == naive_attention_layout_enum::SCALE_HS)
{
// [nhead, tokens]
(void)i_d;
return i_h * s + i_s;
}
else if constexpr(Layout == naive_attention_layout_enum::DEFAULT)
{
return 0;
}
// [h, 2, d]
return i_h * 2 * d + i_kv * d + i_d;
}
__device__ T load(int i_h, int i_d, int i_kv)
{
return base_ptr[get_offset(i_h, i_d, i_kv)];
// return i_h * 2 * d + i_kv * d + i_d;
}
__device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
};
__device__ __host__ static constexpr int get_block_size() { return 256; }
......@@ -282,12 +318,13 @@ struct naive_attention_fwd_kernel
__device__ void operator()(naive_attention_fwd_args args)
{
constexpr int wg_size = get_block_size();
__shared__ char smem[wg_size * 4 * sizeof(float)]; // should enough
int i_dv = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
int i_sq = blockIdx.y; // index of seqlen_q
int i_batch = blockIdx.z; // index of batch_q * nhead_q
int i_bq = i_batch / args.nhead_q; // index of batch_q
int i_hq = i_batch % args.nhead_q; // index of nhead_q
__shared__ char smem[wg_size * 4 * sizeof(float)]; // should enough
char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // second half, should enough
int i_dv = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
int i_sq = blockIdx.y; // index of seqlen_q
int i_batch = blockIdx.z; // index of batch_q * nhead_q
int i_bq = i_batch / args.nhead_q; // index of batch_q
int i_hq = i_batch % args.nhead_q; // index of nhead_q
int i_bk = i_bq / args.batch_ratio_kv;
int i_hk = i_hq / args.nhead_ratio_kv;
......@@ -360,9 +397,10 @@ struct naive_attention_fwd_kernel
auto f_max = [](auto x_, auto y_) { return max(x_, y_); };
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
auto f_absmax_f32 = [](float v_0_, float v_1_) {
float rtn;
asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
return rtn;
// float rtn;
// asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
// return rtn;
return max(abs(v_0_), abs(v_1_));
};
int seqlen_kv = [&]() {
......@@ -378,45 +416,82 @@ struct naive_attention_fwd_kernel
SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
SoftmaxType l{0};
AccType o_acc = {0};
// AccType o_acc = {0};
OAccType o_acc = {0};
int sk_loops = (seqlen_kv + wg_size - 1) / wg_size;
float qf_scale = .0f;
kvscale_addresser<KVScaleType> kvscale_addr{args.nhead_kv, args.hdim, args.kvscale_ptr};
int sk_loops = (seqlen_kv + wg_size - 1) / wg_size;
QuantComputeType q_dequant_scale = .0f;
kvscale_addresser<KVScaleType, KScaleLayout> kscale_addr{
args.max_kv_tokens, args.nhead_kv, args.hdim, args.kscale_ptr};
kvscale_addresser<KVScaleType, VScaleLayout> vscale_addr{
args.max_kv_tokens, args.nhead_kv, args.hdim_v, args.vscale_ptr};
if constexpr(is_kvcache_i8_forward_quant)
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{
// AccType is i32 now, seqlen_q = 1, hdim up to 256
float q = 0;
float k_s = 0;
AccType q = 0;
AccType k_s = 0;
if(static_cast<int>(threadIdx.x) < args.hdim)
{
q = type_convert<float>(q_addr.load(0, threadIdx.x));
k_s = type_convert<float>(kvscale_addr.load(i_hk, threadIdx.x, 0));
q = type_convert<AccType>(q_addr.load(0, threadIdx.x));
k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
}
// 1) we apply the k scale to q
float q_forwarded = q * k_s;
AccType q_forwarded = q * k_s;
// 2) apply smooth-quant
// find absmax
float qf_max = wave_reduce(q_forwarded, f_absmax_f32);
qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<float*>(smem));
AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
// per-token scale
qf_scale = qf_max / 127.0;
q_dequant_scale = type_convert<QuantComputeType>(qf_max) / scale_max<QCompute>::value;
// devide by scale
q = q / qf_scale;
q = q / q_dequant_scale;
// fp32->i8
int8_t quantized_q = static_cast<int8_t>(q);
QCompute quantized_q = static_cast<QCompute>(q);
__syncthreads();
reinterpret_cast<int8_t*>(smem)[threadIdx.x] = quantized_q;
reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
__syncthreads();
// after above process, we have 2 data
// 1) int8 q data stored in smem(no need to reload)
// 2) per-token scale qf_scale, to be mul after 1st gemm
// 2) per-token scale q_dequant_scale, to be mul after 1st gemm
}
else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
{
// dyanmic quant q here
float q = 0;
if(static_cast<int>(threadIdx.x) < args.hdim)
{
q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
}
// apply smooth-quant
// find absmax
float q_max = wave_reduce(q, f_absmax_f32);
q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));
// per-token scale
q_dequant_scale =
type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;
// devide by scale
q = q / q_dequant_scale;
QCompute quantized_q = type_convert<QCompute>(q);
__syncthreads();
reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
__syncthreads();
// after above process, we have 2 data
// 1) fp8 q data stored in smem(no need to reload from global)
// 2) per-token scale q_dequant_scale, to be mul after 1st gemm
}
}
for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
......@@ -429,33 +504,41 @@ struct naive_attention_fwd_kernel
AccType s_acc{0}; // clear for every loop
for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
{
if constexpr(is_kvcache_i8_forward_quant)
{
int8_t q = reinterpret_cast<int8_t*>(smem)[i_dq];
auto k = k_addr.load(i_sk, i_dq);
s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
}
else
{
auto q = q_addr.load(i_sq, i_dq); // q will have duplicate load
auto k = k_addr.load(i_sk, i_dq);
auto q = [&]() {
if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERHEAD ||
Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
}
else
return q_addr.load(i_sq, i_dq); // q will have duplicate load
}();
auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
}
s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
}
// scale
s_softmax = type_convert<SoftmaxType>(s_acc);
s_softmax *=
type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
if constexpr(is_kvcache_i8_forward_quant)
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{
s_softmax *= q_dequant_scale; // post scale the per-token factor
}
else if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
s_softmax *= qf_scale; // post scale the per-token factor
SoftmaxType k_per_token_scale =
type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
s_softmax *= q_dequant_scale;
s_softmax *= k_per_token_scale;
}
}
// s->p
float pf_scale = 0.; // used for i8 quant
QuantComputeType p_dequant_scale = 1.;
{
// softmax, find max
SoftmaxType old_max = row_max;
......@@ -473,41 +556,69 @@ struct naive_attention_fwd_kernel
// l, pre-scall o_acc
SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
l = tmp * l + row_sum;
o_acc = type_convert<AccType>(type_convert<SoftmaxType>(o_acc) * tmp);
o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
// prepare the p_compute into smem, to let every thread read same p_compute and do
// 2nd gemm
if constexpr(is_kvcache_i8_forward_quant)
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{
float v_s = 0;
QuantComputeType v_s = 0;
if(static_cast<int>(threadIdx.x) < args.hdim_v)
{
v_s = type_convert<float>(kvscale_addr.load(i_hk, threadIdx.x, 1));
v_s =
type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
}
// 1) we apply the v scale to p
float p_forwarded = p_compute * v_s;
QuantComputeType p_forwarded = p_compute * v_s;
// 2) apply smooth-quant
// find absmax
float pf_max = wave_reduce(p_forwarded, f_absmax_f32);
pf_max =
cross_wave_reduce(pf_max, f_absmax_f32, reinterpret_cast<float*>(smem));
QuantComputeType pf_max = wave_reduce(p_forwarded, f_absmax_f32);
pf_max = cross_wave_reduce(
pf_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
// per-token scale
pf_scale = pf_max / 127.0;
p_dequant_scale = pf_max / scale_max<PType>::value; // 127.0;
// devide by scale
p_compute = p_compute / pf_scale;
p_compute = p_compute / p_dequant_scale;
// fp32->i8
int8_t quantized_p = static_cast<int8_t>(p_compute);
PType quantized_p = static_cast<PType>(p_compute);
__syncthreads();
reinterpret_cast<int8_t*>(smem)[threadIdx.x] = quantized_p;
reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
__syncthreads();
// after above process, we have 2 data
// 1) int8 p data stored in smem(no need to reload)
// 2) per-token scale pf_scale, to be mul after 2nd gemm
// 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
}
else if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
// forward apply the v scale to p_compute, this is compute friendly
auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
p_compute *= v_scale;
// smooth-quant
// find absmax
QuantComputeType p_max = wave_reduce(p_compute, f_absmax_f32);
p_max = cross_wave_reduce(
p_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
// per-token scale
p_dequant_scale = p_max / scale_max<PType>::value; // 240.0;
// devide by scale
p_compute = p_compute / p_dequant_scale;
// fp32->i8
PType quantized_p = type_convert<PType>(p_compute);
__syncthreads();
reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
__syncthreads();
// after above process, we have 2 data
// 1) fp8_t p data stored in smem(no need to reload)
// 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
}
else
{
......@@ -531,29 +642,45 @@ struct naive_attention_fwd_kernel
int sv_offset = i_loop2 * p_vec_elem + i_j;
int i_sv = sk_start + sv_offset;
VType v = 0.f;
VType v = 0;
if(i_dv < args.hdim_v && i_sv < seqlen_kv)
{
v = v_addr.load(i_sv, i_dv);
}
o_acc_local += type_convert<AccType>(p_vec[i_j]) * type_convert<AccType>(v);
AccType v_compute = [&]() { return type_convert<AccType>(v); }();
o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
}
}
if constexpr(is_kvcache_i8_forward_quant)
{
// apply pr scale to local acc
o_acc_local =
type_convert<AccType>(type_convert<float>(o_acc_local) * pf_scale);
}
o_acc += o_acc_local;
OAccType post_scale_o_acc_local = [&]() {
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{
// apply pr scale to local acc
return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
p_dequant_scale);
}
else if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
// apply pr scale to local acc
return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
p_dequant_scale);
}
else
{
return type_convert<OAccType>(o_acc_local);
}
}();
o_acc += post_scale_o_acc_local;
}
}
// post scale o_acc
{
SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking
o_acc = type_convert<AccType>(type_convert<SoftmaxType>(o_acc) * tmp);
o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
}
// store O
......@@ -564,18 +691,21 @@ struct naive_attention_fwd_kernel
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_() \
{ \
using ktraits_ = \
naive_attention_fwd_kernel_traits<static_cast<naive_attention_variation_enum>( \
variation_)>; \
using ktraits_ = naive_attention_fwd_kernel_traits< \
static_cast<naive_attention_variation_enum>(variation_), \
static_cast<naive_attention_quant_algo>(quant_algo_)>; \
using k_ = naive_attention_fwd_kernel<q_type_, \
k_type_, \
v_type_, \
o_type_, \
acc_type_, \
kvscale_type_, \
q_layout_, \
k_layout_, \
v_layout_, \
o_layout_, \
k_scale_layout_, \
v_scale_layout_, \
ktraits_>; \
dim3 grids = k_::get_grid_size(a); \
r = ck_tile::launch_kernel(s, \
......@@ -586,31 +716,37 @@ struct naive_attention_fwd_kernel
if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
t.o_layout == "bshd") \
{ \
constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \
constexpr int variation_ = 0; \
constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr int variation_ = 0; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
} \
else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" && \
t.v_layout == "bhsd" && t.o_layout == "bhsd") \
{ \
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
constexpr int variation_ = 0; \
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr int variation_ = 0; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
} \
else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" && \
t.v_layout == "phds" && t.o_layout == "bhsd") \
{ \
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \
constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
constexpr int variation_ = 2; \
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \
constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
constexpr int variation_ = 2; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
}
......@@ -621,40 +757,64 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
{
float r = -1;
// TODO: do not explicitly create too much instance!
if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16")
if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16" &&
t.quant_algo == 0)
{
using q_type_ = fp16_t;
using k_type_ = fp16_t;
using v_type_ = fp16_t;
using o_type_ = fp16_t;
using acc_type_ = float;
using kvscale_type_ = float;
constexpr int quant_algo_ = 0;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16" &&
t.quant_algo == 0)
{
using q_type_ = fp16_t;
using k_type_ = fp16_t;
using v_type_ = fp16_t;
using o_type_ = fp16_t;
using acc_type_ = float;
using q_type_ = bf16_t;
using k_type_ = bf16_t;
using v_type_ = bf16_t;
using o_type_ = bf16_t;
using acc_type_ = float;
using kvscale_type_ = float;
constexpr int quant_algo_ = 0;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16")
else if(t.q_type == "bf16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "bf16" &&
t.quant_algo == 2)
{
using q_type_ = bf16_t;
using k_type_ = bf16_t;
using v_type_ = bf16_t;
using o_type_ = bf16_t;
using acc_type_ = float;
using q_type_ = bf16_t;
using k_type_ = fp8_t;
using v_type_ = fp8_t;
using o_type_ = bf16_t;
using acc_type_ = float; // NOTE!
using kvscale_type_ = float;
constexpr int quant_algo_ = 2;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16")
else if(t.q_type == "fp16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "fp16" &&
t.quant_algo == 2)
{
using q_type_ = bf16_t;
using k_type_ = int8_t;
using v_type_ = int8_t;
using o_type_ = bf16_t;
using acc_type_ = int32_t; // NOTE!
using q_type_ = fp16_t;
using k_type_ = fp8_t;
using v_type_ = fp8_t;
using o_type_ = fp16_t;
using acc_type_ = float; // NOTE!
using kvscale_type_ = float;
constexpr int quant_algo_ = 2;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "fp16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "fp16")
else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16" &&
t.quant_algo == 2)
{
using q_type_ = fp16_t;
using k_type_ = int8_t;
using v_type_ = int8_t;
using o_type_ = fp16_t;
using acc_type_ = int32_t; // NOTE!
using q_type_ = bf16_t;
using k_type_ = int8_t;
using v_type_ = int8_t;
using o_type_ = bf16_t;
using acc_type_ = int32_t; // NOTE!
using kvscale_type_ = float;
constexpr int quant_algo_ = 2;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
return r;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ADataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
index_t ScaleBlockK>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmV2BScale<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
BScaleDataType,
CDataType,
1,
ScaleBlockK,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGemmV2BScale<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
BScaleDataType,
CDataType,
1,
ScaleBlockK,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -238,6 +238,403 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpaddin
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Col,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
......@@ -527,6 +924,109 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
op_ptrs);
}
}
#endif
#if(defined(CK_ENABLE_FP8))
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>)
......
......@@ -183,6 +183,10 @@ FOREACH(subdir_path ${dir_list})
message("bf8 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16")
message("bf16 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
message("fp16 instance found!")
set(add_inst 1)
......@@ -195,10 +199,6 @@ FOREACH(subdir_path ${dir_list})
message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
message("bf16 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
message("int8 instance found!")
set(add_inst 1)
......
# ONLY XDL_KERNELS
set(GEMM_B_SCALE_INSTANCES)
list(APPEND GEMM_B_SCALE_INSTANCES
device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
)
set_source_files_properties(device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES})
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I4 = pk_i4_t;
using F16 = half_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
#if 0
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_b_scale_f16_i4_f16_mk_nk_mn_comp_instances = std::tuple<
#endif
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//Compute friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
//Latency friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
// Memory friendly v3
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
// Memory friendly v4
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
//new Compute friendly kernel
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
//new Memory friendly kernel
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 64, 256, 8, 32, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -64,6 +64,43 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp)
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp)
add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMKPadding = GemmSpecialization::MKPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// Can we support this kind of odd case? 224(256) = 28*8 + (4*8)
//DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmMNPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col,
Row,
Row,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment