Unverified commit 3c5717df authored by Illia Silin, committed by GitHub

Merge branch 'develop' into gemm_elementwise_gemm

parents 171b9030 d9f1ead3
@@ -14,11 +14,19 @@
 #include <utility>
 #include <variant>
 
+struct FmhaBwdFp16
+{
+};
+
+struct FmhaBwdBf16
+{
+};
+
 template <typename DataType>
 struct FmhaBwdTypeConfig;
 
 template <>
-struct FmhaBwdTypeConfig<ck_tile::half_t>
+struct FmhaBwdTypeConfig<FmhaBwdFp16>
 {
     using QDataType = ck_tile::half_t;
     using KDataType = ck_tile::half_t;
@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
 };
 
 template <>
-struct FmhaBwdTypeConfig<ck_tile::bf16_t>
+struct FmhaBwdTypeConfig<FmhaBwdBf16>
 {
     using QDataType = ck_tile::bf16_t;
     using KDataType = ck_tile::bf16_t;
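A note on the pattern above: the specialization keys change from raw element types (`ck_tile::half_t`, `ck_tile::bf16_t`) to empty tag structs. Presumably the motivation is that several configs can then share one element type — the forward header later in this commit adds `FmhaFwdFp8Fp16`/`FmhaFwdFp8Bf16` tags, which could not both be keyed on `ck_tile::fp8_t`. A minimal sketch of the idiom (all names here are illustrative, not from the diff):

```
#include <type_traits>

// Empty tag types select a config without being data types themselves.
struct TagFp16 {};    // plain fp16 config
struct TagFp8Fp16 {}; // e.g. fp8 compute with fp16 I/O -- impossible to key on a raw type alone

template <typename Tag>
struct TypeConfig; // primary template, specialized per tag

template <>
struct TypeConfig<TagFp16>
{
    using QDataType = float; // stand-in for ck_tile::half_t
};

template <typename Tag>
typename TypeConfig<Tag>::QDataType make_q() { return {}; }

int main()
{
    auto q = make_q<TagFp16>(); // dispatch through the tag, not the element type
    static_assert(std::is_same_v<decltype(q), float>);
}
```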
@@ -150,113 +158,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
     // create group mode kernel arguments
     if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
     {
-        return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
+        return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
                                                   args.k_ptr,
                                                   args.v_ptr,
                                                   args.bias_ptr,
                                                   args.lse_ptr,
                                                   args.do_ptr,
                                                   args.d_ptr,
                                                   args.rand_val_ptr,
                                                   args.dk_ptr,
                                                   args.dv_ptr,
                                                   args.dbias_ptr,
                                                   args.dq_acc_ptr,
                                                   args.seqstart_q_ptr,
                                                   args.seqstart_k_ptr,
                                                   args.seqlen_k_ptr,
                                                   args.hdim_q,
                                                   args.hdim_v,
                                                   args.nhead_q,
                                                   args.nhead_q / args.nhead_k,
                                                   args.scale,
                                                   args.stride_q,
                                                   args.stride_k,
                                                   args.stride_v,
                                                   args.stride_bias,
                                                   args.stride_randval,
                                                   args.stride_do,
                                                   args.stride_dq_acc,
                                                   args.stride_dk,
                                                   args.stride_dv,
                                                   args.stride_dbias,
                                                   args.nhead_stride_q,
                                                   args.nhead_stride_k,
                                                   args.nhead_stride_v,
                                                   args.nhead_stride_bias,
                                                   args.nhead_stride_randval,
                                                   args.nhead_stride_do,
                                                   args.nhead_stride_lsed,
                                                   args.nhead_stride_dq_acc,
                                                   args.nhead_stride_dk,
                                                   args.nhead_stride_dv,
                                                   args.nhead_stride_dbias,
                                                   args.split_stride_dq_acc,
                                                   args.window_size_left,
                                                   args.window_size_right,
                                                   args.mask_type,
                                                   args.p_drop,
                                                   args.drop_seed_offset);
     }
     else
     { // create batch mode kernel arguments
-        return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
+        return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
                                                   args.k_ptr,
                                                   args.v_ptr,
                                                   args.bias_ptr,
                                                   args.lse_ptr,
                                                   args.do_ptr,
                                                   args.d_ptr,
                                                   args.rand_val_ptr,
                                                   args.dk_ptr,
                                                   args.dv_ptr,
                                                   args.dbias_ptr,
                                                   args.dq_acc_ptr,
                                                   args.seqlen_q,
                                                   args.seqlen_k,
                                                   args.hdim_q,
                                                   args.hdim_v,
                                                   args.nhead_q,
                                                   args.nhead_q / args.nhead_k,
                                                   args.scale,
                                                   args.stride_q,
                                                   args.stride_k,
                                                   args.stride_v,
                                                   args.stride_bias,
                                                   args.stride_randval,
                                                   args.stride_do,
                                                   args.stride_dq_acc,
                                                   args.stride_dk,
                                                   args.stride_dv,
                                                   args.stride_dbias,
                                                   args.nhead_stride_q,
                                                   args.nhead_stride_k,
                                                   args.nhead_stride_v,
                                                   args.nhead_stride_bias,
                                                   args.nhead_stride_randval,
                                                   args.nhead_stride_do,
                                                   args.nhead_stride_lsed,
                                                   args.nhead_stride_dq_acc,
                                                   args.nhead_stride_dk,
                                                   args.nhead_stride_dv,
                                                   args.nhead_stride_dbias,
                                                   args.batch_stride_q,
                                                   args.batch_stride_k,
                                                   args.batch_stride_v,
                                                   args.batch_stride_bias,
                                                   args.batch_stride_randval,
                                                   args.batch_stride_do,
                                                   args.batch_stride_lsed,
                                                   args.batch_stride_dq_acc,
                                                   args.batch_stride_dk,
                                                   args.batch_stride_dv,
                                                   args.batch_stride_dbias,
                                                   args.split_stride_dq_acc,
                                                   args.window_size_left,
                                                   args.window_size_right,
                                                   args.mask_type,
                                                   args.p_drop,
                                                   args.drop_seed_offset);
     }
 }();
...
@@ -3,6 +3,7 @@
 #include "fmha_fwd.hpp"
 #include "ck_tile/host.hpp"
+#include "ck_tile/ref/naive_attention.hpp"
 #include "mask.hpp"
 #include "rotary.hpp"
 #include "utils.hpp"
@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
-    arg_parser.insert("v", "1", "whether to do CPU validation or not")
+    arg_parser.insert("v", "1", "0:no validation, 1:cpu validation, 2:gpu validation(experimental)")
         .insert("mode", "0", "kernel mode. 0:batch, 1:group")
         .insert("b", "2", "batch size")
         .insert("h", "8", "num of head, for q")
@@ -62,7 +63,7 @@ auto create_args(int argc, char* argv[])
                 "-1 to choose s_knew in [1, s] randomly.")
         .insert("s_kpad",
                 "-1",
-                "seqlen_k stride between 2 tokens, currently used in group-mode only\n"
+                "seqlen_k stride between 2 batches, currently used in group-mode only\n"
                 "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
                 "along seqlen, instead of packed. same as xformer kv_padding")
         .insert("d", "128", "head dim for q, k")
@@ -142,7 +143,7 @@ auto create_args(int argc, char* argv[])
 }
 
 // different threshold for different dtype
-template <typename DataType>
+template <typename DataTypeConfig>
 auto get_elimit(std::string /*init_method*/)
 {
     double rtol = 1e-3;
@@ -151,7 +152,7 @@ auto get_elimit(std::string /*init_method*/)
 }
 
 template <>
-auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
+auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
 {
     double rtol = 1e-2;
     double atol = 1e-2;
@@ -159,7 +160,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
 }
 
 template <>
-auto get_elimit<ck_tile::fp8_t>(std::string init_method)
+auto get_elimit<FmhaFwdFp8>(std::string init_method)
 {
     if(init_method == "ui" || init_method == "ni")
     {
@@ -261,7 +262,7 @@ int override_num_splits_if_necessary(
     return num_splits;
 }
 
-template <typename DataType>
+template <typename DataTypeConfig>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
     std::string data_type = arg_parser.get_str("prec");
@@ -294,7 +295,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
 #if !CK_TILE_FMHA_FWD_APPENDKV_API
     if(seqlen_knew != 0)
     {
-        std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
+        std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option"
+                  << std::endl;
         seqlen_knew = 0;
     }
 #endif
@@ -304,8 +306,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }
 
     ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
-    if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
-                   std::is_same_v<DataType, ck_tile::bf16_t>))
+    if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
+                   std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
     {
         if(0 < rotary_dim)
         {
@@ -321,6 +323,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
         rotary_dim = 0;
     }
 #endif
 
+    // to use fmha_fwd_appendkv(), make sure it's in batch mode
+    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
+    if(need_append_kvcache && mode == mode_enum::group)
+    {
+        std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl;
+        mode = mode_enum::batch;
+    }
+
     if(!(rotary_dim <= hdim_q))
     {
         std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
@@ -356,22 +365,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
                   << std::endl;
         use_cache_batch_idx = false;
     }
-#endif
-    if(0 < page_block_size && use_cache_batch_idx)
+#else
+    if(use_cache_batch_idx)
     {
-        std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
-                     "'cache_batch_idx' option"
-                  << std::endl;
-        use_cache_batch_idx = false;
+        if(0 < page_block_size)
+        {
+            std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
+                         "'cache_batch_idx' option"
+                      << std::endl;
+            use_cache_batch_idx = false;
+        }
+        else if(mode == mode_enum::group)
+        {
+            std::cerr << "group mode will not use cache_batch_idx. ignoring the "
+                         "'cache_batch_idx' option"
+                      << std::endl;
+            use_cache_batch_idx = false;
+        }
     }
+#endif
 
-    // the input tensor layout for kvcache is same as batch mode
-    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
     const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
-    if(use_kvcache && mode != mode_enum::batch)
-    {
-        std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
-        mode = mode_enum::batch;
-    }
 
     auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
         decode_seqlen(mode,
@@ -380,7 +393,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                       arg_parser.get_str("s_k"),
                       arg_parser.get_str("s_kpad"),
                       /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
-                      use_kvcache);
+                      need_append_kvcache);
 
     // compute kvcache seqlen_k (before appending knew/vnew)
     auto cache_seqlen_ks = seqlen_ks;
     std::transform(cache_seqlen_ks.begin(),
@@ -416,25 +429,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return atoi(squant_str.c_str()) != 0 ? true : false;
     }();
 
-    float range_q = arg_parser.get_float("range_q");
-    float range_k = arg_parser.get_float("range_k");
-    float range_v = arg_parser.get_float("range_v");
-    float range_p = arg_parser.get_float("range_p");
-    float range_o = arg_parser.get_float("range_o");
-    float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max());
-
-    float scale_p = 1.f;
-    float scale_o = 1.f;
-
-    if(squant)
-    {
-        scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
-        scale_p = dtype_max / range_p;
-        // scale_o = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
-        scale_o = range_p * range_v / range_o / dtype_max;
-    }
-
     std::string vlayout = arg_parser.get_str("vlayout");
     bool lse = arg_parser.get_bool("lse");
@@ -454,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }
 
     bool s_randval = false;
-    if(p_drop > 0.0f && do_validation)
+    if(p_drop > 0.0f && do_validation != 0)
     {
         s_randval = true;
     }
@@ -487,7 +481,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     const auto seqstart_k_host = to_seqstarts(seqlen_ks);
     const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
 
-    using TypeConfig = FmhaFwdTypeConfig<DataType>;
+    using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
 
     using QDataType = typename TypeConfig::QDataType;
     using KDataType = typename TypeConfig::KDataType;
@@ -501,6 +495,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using OaccDataType = typename TypeConfig::OaccDataType;
     using ODataType = typename TypeConfig::ODataType;
 
+    float range_q = arg_parser.get_float("range_q");
+    float range_k = arg_parser.get_float("range_k");
+    float range_v = arg_parser.get_float("range_v");
+    float range_p = arg_parser.get_float("range_p");
+    float range_o = arg_parser.get_float("range_o");
+
+    float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
+    float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
+    float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
+    float p_dtype_max = v_dtype_max; // assume p and v are the same type
+    float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
+
+    float scale_p = 1.f;
+    float scale_o = 1.f;
+
+    if(squant)
+    {
+        scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
+        scale_p = p_dtype_max / range_p;
+        scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
+    }
+
     // accumulation numbers for performance evaluation
     std::size_t flop = 0, num_byte = 0;
     auto max_seqlen_q =
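For concreteness, a small worked example of the static-quant scales introduced above: the q/k dequant factors are folded into the softmax scale, the softmax output p is requantized to the fp8 range, and the output scale undoes the p and v quantization before quantizing o. This is a sketch only — `dtype_max = 240` assumes the e4m3 fnuz fp8 format, and the `range_*` values are illustrative, not the example's actual defaults:

```
#include <cstdio>

int main()
{
    float dtype_max = 240.f; // assumed e4m3 fnuz max; q/k/v/p/o all fp8 here
    float range_q = 16.f, range_k = 16.f, range_v = 16.f;
    float range_p = 1.f, range_o = 16.f;
    float scale_s = 1.f; // softmax scale before quantization factors

    // fold the q and k dequantization factors into the softmax scale
    scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
    // requantize p (softmax output, in [0, range_p]) to the full fp8 range
    float scale_p = dtype_max / range_p;
    // dequantize p and v, then quantize the output o
    float scale_o = (dtype_max / range_o) * (range_p / dtype_max) * (range_v / dtype_max);

    std::printf("scale_s=%g scale_p=%g scale_o=%g\n", scale_s, scale_p, scale_o);
}
```

Note that when all dtype maxima are equal this reduces to the deleted single-`dtype_max` formula `range_p * range_v / range_o / dtype_max`; the per-tensor maxima only matter for mixed-precision configs.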
@@ -697,14 +713,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
     else if(init_method == "ufq" || init_method == "uf:q" ||
             init_method == "3") // suitable for fp8 quantization
     {
-        ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
-        ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
-        ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
-        ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
-        ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);
+        ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, seed}(q_host);
+        ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(k_host);
+        ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(knew_host);
+        ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(v_host);
+        ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(vnew_host);
 
         // bias_fp8 = qscale_bias * bias_fp32
-        float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
+        float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
         // Assume bias is in [-1.f, 1.f] in original fp32
         ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
     }
@@ -741,8 +757,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
-    ck_tile::DeviceMem seqlen_k_buf(
-        use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
+    ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
+                                            0 <= seqlen_kpads[0]
+                                        ? seqlen_ks.size() * sizeof(int32_t)
+                                        : 0);
     ck_tile::DeviceMem cache_seqlen_k_buf(
         need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
     ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
@@ -763,7 +781,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     seqstart_q.ToDevice(seqstart_q_host.data());
     seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
                                             : seqstart_k_with_padding_host.data());
-    seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
+    seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
+                              ? seqlen_ks.data()
+                              : nullptr);
     cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
     rotary_cos_buf.ToDevice(rotary_cos_host.data());
     rotary_sin_buf.ToDevice(rotary_sin_host.data());
@@ -976,8 +996,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
                 (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
             args.seqstart_k_ptr =
                 (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
-            args.seqlen_k_ptr =
-                (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
+            args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
+                                     ? seqlen_k_buf.GetDeviceBuffer()
+                                     : nullptr);
             args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
             args.max_seqlen_q = max_seqlen_q;
@@ -1029,6 +1050,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                 (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
             args.batch_stride_block_table = batch_stride_block_table;
             args.page_block_size = page_block_size;
+            args.is_gappy = false; // use 'false' for flash-attention integration
             args.cache_batch_idx =
                 (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
@@ -1100,25 +1122,76 @@ bool run(const ck_tile::ArgParser& arg_parser)
               << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
               << " GB/s" << std::flush;
 
-    if(!do_validation)
+    if(do_validation == 0)
     {
         std::cout << std::flush << std::endl;
        return true;
     }
 
+    if(do_validation == 2)
+    {
+        // NOTE: use gpu to do validation
+        ck_tile::naive_attention_fwd_traits naive_t;
+        naive_t.q_type = data_type;
+        naive_t.k_type = data_type;
+        naive_t.v_type = data_type;
+        naive_t.o_type = data_type;
+        naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
+        naive_t.variation = 0; // TODO?
+        naive_t.quant_algo = 0;
+
+        ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
+
+        ck_tile::naive_attention_fwd_args naive_a;
+        naive_a.q_ptr = q_buf.GetDeviceBuffer();
+        naive_a.k_ptr = k_buf.GetDeviceBuffer();
+        naive_a.v_ptr = v_buf.GetDeviceBuffer();
+        naive_a.o_ptr = o_naive_buf.GetDeviceBuffer();
+        naive_a.scale_s = scale_s;
+        naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer
+        naive_a.page_table_ptr =
+            nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn)
+        naive_a.hdim = hdim_q;
+        naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different
+        naive_a.batch_q = batch;
+        naive_a.batch_kv = batch;
+        naive_a.batch_ratio_kv = 1; // batch_q / batch_kv
+        naive_a.seqlen_q = seqlen_qs[0];
+        naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
+        naive_a.nhead_q = nhead;
+        naive_a.nhead_kv = nhead_k;
+        naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
+        naive_a.page_size = 0; // if paged, the seqlen-kv for each block
+
+        ck_tile::stream_config naive_s{};
+        naive_attention_fwd(naive_t, naive_a, naive_s);
+
+        auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
+        o_buf.FromDevice(o_host.data()); // TODO: ugly
+        auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
+        bool pass_ = ck_tile::check_err(
+            o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
+        std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl;
+        return pass_;
+    }
+
     o_buf.FromDevice(o_host.data());
     lse_buf.FromDevice(lse_host.data());
     randval_buf.FromDevice(randval_host.data());
 
     auto p_compute_element_func = [&]() {
-        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
+        if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
             return ck_tile::scales{scale_p};
         else
             return ck_tile::identity{};
     }();
 
     auto oacc_element_func = [&]() {
-        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
+        if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
             return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
                                      ck_tile::scales{scale_o});
         else
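The experimental GPU validation path added in the hunk above checks the optimized kernel against ck_tile's naive attention kernel. Conceptually, the reference it stands in for is just unfused softmax(QK^T * scale_s)V per batch and head. A minimal host-side sketch of that math for a single head, without masking, bias, or dropout (names and layout here are illustrative, not the library's implementation):

```
#include <cmath>
#include <vector>

// Unfused single-head attention reference: o = softmax(q @ k^T * scale) @ v.
// q: [s_q, d], k: [s_k, d], v: [s_k, d_v], o: [s_q, d_v], all row-major.
void naive_attention_ref(const std::vector<float>& q, const std::vector<float>& k,
                         const std::vector<float>& v, std::vector<float>& o,
                         int s_q, int s_k, int d, int d_v, float scale)
{
    std::vector<float> row(s_k);
    for(int i = 0; i < s_q; ++i)
    {
        float m = -INFINITY;
        for(int j = 0; j < s_k; ++j)
        {
            float acc = 0.f;
            for(int t = 0; t < d; ++t)
                acc += q[i * d + t] * k[j * d + t]; // q_i . k_j
            row[j] = acc * scale;
            m = std::fmax(m, row[j]);
        }
        float sum = 0.f;
        for(int j = 0; j < s_k; ++j)
        {
            row[j] = std::exp(row[j] - m); // subtract the row max for stability
            sum += row[j];
        }
        for(int t = 0; t < d_v; ++t)
        {
            float acc = 0.f;
            for(int j = 0; j < s_k; ++j)
                acc += row[j] * v[j * d_v + t];
            o[i * d_v + t] = acc / sum; // normalized softmax weights times V
        }
    }
}
```

Running the reference on the GPU instead of the CPU trades bit-exact host arithmetic for much faster validation on long sequences, which is presumably why the option is labeled experimental.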
@@ -1168,7 +1241,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                 {
                     decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
                     auto [rotary_cos_slice, rotary_sin_slice] =
                         slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
 
                     ck_tile::reference_batched_rotary_position_embedding(
@@ -1184,13 +1257,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
                     k_host_ref.ForEach([&](auto& self, auto i) {
                         self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
                     });
                 } else {
                     k_host_ref.ForEach([&](auto& self, auto i) {
                         self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]);
                     });
                 }
             } else
 #endif
             {
                 if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
                 else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
@@ -1211,7 +1284,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                 {
                     knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
                     auto [rotary_cos_slice, rotary_sin_slice] =
                         slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
 
                     ck_tile::reference_batched_rotary_position_embedding(
@@ -1233,19 +1306,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
             if(0 < page_block_size) {
                 if(is_v_rowmajor) {
                     if(i_perm) {
                         v_host_ref.ForEach([&](auto& self, auto i) {
                             self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
                         });
                     } else {
                         v_host_ref.ForEach([&](auto& self, auto i) {
                             self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]);
                         });
                     }
                 }
                 else
                 {
                     if(i_perm) {
                         v_host_ref.ForEach([&](auto& self, auto i) {
                             self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size);
                         });
                     } else {
@@ -1440,7 +1513,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
             else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
             // clang-format on
 
-            auto [rtol, atol] = get_elimit<DataType>(init_method);
+            auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
             bool cur_pass = ck_tile::check_err(
                 o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
             pass &= cur_pass;
@@ -1497,15 +1570,15 @@ int main(int argc, char* argv[])
     const std::string data_type = arg_parser.get_str("prec");
     if(data_type == "fp16")
     {
-        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+        return run<FmhaFwdFp16>(arg_parser) ? 0 : -2;
     }
     else if(data_type == "bf16")
     {
-        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
+        return run<FmhaFwdBf16>(arg_parser) ? 0 : -2;
     }
     else if(data_type == "fp8")
     {
-        return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
+        return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
     }
 
     return -3;
...
@@ -16,11 +16,35 @@
 #include <utility>
 #include <variant>
 
+struct FmhaFwdFp16
+{
+};
+
+struct FmhaFwdBf16
+{
+};
+
+struct FmhaFwdFp8
+{
+};
+
+struct FmhaFwdBf8
+{
+};
+
+struct FmhaFwdFp8Fp16
+{
+};
+
+struct FmhaFwdFp8Bf16
+{
+};
+
 template <typename DataType>
 struct FmhaFwdTypeConfig;
 
 template <>
-struct FmhaFwdTypeConfig<ck_tile::half_t>
+struct FmhaFwdTypeConfig<FmhaFwdFp16>
 {
     using QDataType = ck_tile::half_t;
     using KDataType = ck_tile::half_t;
@@ -36,7 +60,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
 };
 
 template <>
-struct FmhaFwdTypeConfig<ck_tile::bf16_t>
+struct FmhaFwdTypeConfig<FmhaFwdBf16>
 {
     using QDataType = ck_tile::bf16_t;
     using KDataType = ck_tile::bf16_t;
@@ -52,7 +76,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
 };
 
 template <>
-struct FmhaFwdTypeConfig<ck_tile::fp8_t>
+struct FmhaFwdTypeConfig<FmhaFwdFp8>
 {
     using QDataType = ck_tile::fp8_t;
     using KDataType = ck_tile::fp8_t;
@@ -68,7 +92,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
 };
 
 template <>
-struct FmhaFwdTypeConfig<ck_tile::bf8_t>
+struct FmhaFwdTypeConfig<FmhaFwdBf8>
 {
     using QDataType = ck_tile::bf8_t;
     using KDataType = ck_tile::bf8_t;
@@ -165,6 +189,8 @@ struct fmha_fwd_splitkv_args
     void* block_table_ptr;
     ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
     ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
+    bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
+                   // nullptr.
 
     const void* cache_batch_idx;
@@ -173,9 +199,21 @@ struct fmha_fwd_splitkv_args
    //             seqlen_k = kargs.seqlen_k
    // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
-   // kvcache mode (use same kernel as batch mode):
+   //                        or kargs.seqlen_k_ptr[b]
+   //
+   // batch mode (kvcache):
    //             seqlen_q = kargs.seqlen_q
    //             seqlen_k = kargs.seqlen_k_ptr[b]
+   // group mode (kvcache):
+   //             seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
+   //
+   //             when is_gappy=true:
+   //               seqlen_k = kargs.seqlen_k_ptr[b]
+   //               seqstart_k_ptr[b] now stores the local offset of each batch
+   //
+   //             when is_gappy=false:
+   //               seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
+   //                        or kargs.seqlen_k_ptr[b]
    const void* seqstart_q_ptr;
    const void* seqstart_k_ptr;
    const void* seqlen_k_ptr;
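To make the two `is_gappy` conventions in the comment above concrete, here is a small sketch of how a host launcher might resolve the per-batch K extent — this is one reading of the comment, with illustrative names, not code from the diff:

```
#include <cstdint>

// Per-batch seqlen_k / K-offset resolution following the comment above.
// With is_gappy=false, seqstart_k holds batch+1 prefix sums (packed layout);
// with is_gappy=true it holds one local offset per batch and seqlen_k_ptr is required.
void resolve_k_extent(const int32_t* seqstart_k, const int32_t* seqlen_k,
                      bool is_gappy, int b,
                      int32_t& out_seqlen_k, int32_t& out_k_offset)
{
    if(is_gappy)
    {
        out_seqlen_k = seqlen_k[b];   // explicit length per batch
        out_k_offset = seqstart_k[b]; // seqstart_k stores a local offset, not a prefix sum
    }
    else
    {
        out_k_offset = seqstart_k[b]; // packed prefix-sum layout
        out_seqlen_k = seqlen_k != nullptr ? seqlen_k[b]
                                           : seqstart_k[b + 1] - seqstart_k[b];
    }
}
```

The gappy form lets caches leave unused slack between batches (as paged/kv-cache layouts do) while still addressing each batch's keys directly.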
@@ -251,7 +289,7 @@ struct fmha_fwd_appendkv_args
     ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
     ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
 
-    const void* cache_batch_idx;
+    const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
 
     ck_tile::index_t stride_q;
     ck_tile::index_t stride_k;
@@ -278,92 +316,102 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
     // create group mode kernel arguments
     if constexpr(FmhaKernel::kIsGroupMode)
     {
-        return FmhaKernel::MakeKargs(args.q_ptr,
+        return FmhaKernel::MakeKargsImpl(args.q_ptr,
                                          args.k_ptr,
                                          args.v_ptr,
                                          args.bias_ptr,
                                          args.rand_val_ptr,
                                          args.lse_ptr,
                                          args.o_ptr,
                                          args.seqstart_q_ptr,
                                          args.seqstart_k_ptr,
                                          args.seqlen_k_ptr,
                                          args.hdim_q,
                                          args.hdim_v,
                                          args.nhead_q,
                                          args.nhead_q / args.nhead_k,
                                          args.scale_s,
                                          args.scale_p,
                                          args.scale_o,
                                          args.stride_q,
                                          args.stride_k,
                                          args.stride_v,
                                          args.stride_bias,
                                          args.stride_randval,
                                          args.stride_o,
                                          args.nhead_stride_q,
                                          args.nhead_stride_k,
                                          args.nhead_stride_v,
                                          args.nhead_stride_bias,
                                          args.nhead_stride_randval,
                                          args.nhead_stride_lse,
                                          args.nhead_stride_o,
                                          args.window_size_left,
                                          args.window_size_right,
                                          args.mask_type,
                                          args.p_drop,
                                          args.s_randval,
                                          args.drop_seed_offset);
     }
     else
     { // create batch mode kernel arguments
-        return FmhaKernel::MakeKargs(args.q_ptr,
+        return FmhaKernel::MakeKargsImpl(args.q_ptr,
                                          args.k_ptr,
                                          args.v_ptr,
                                          args.bias_ptr,
                                          args.rand_val_ptr,
                                          args.lse_ptr,
                                          args.o_ptr,
                                          args.seqlen_q,
                                          args.seqlen_k,
                                          args.hdim_q,
                                          args.hdim_v,
                                          args.nhead_q,
                                          args.nhead_q / args.nhead_k,
                                          args.scale_s,
                                          args.scale_p,
                                          args.scale_o,
                                          args.stride_q,
                                          args.stride_k,
                                          args.stride_v,
                                          args.stride_bias,
                                          args.stride_randval,
                                          args.stride_o,
                                          args.nhead_stride_q,
                                          args.nhead_stride_k,
                                          args.nhead_stride_v,
                                          args.nhead_stride_bias,
                                          args.nhead_stride_randval,
                                          args.nhead_stride_lse,
                                          args.nhead_stride_o,
                                          args.batch_stride_q,
                                          args.batch_stride_k,
                                          args.batch_stride_v,
                                          args.batch_stride_bias,
                                          args.batch_stride_randval,
                                          args.batch_stride_lse,
                                          args.batch_stride_o,
                                          args.window_size_left,
                                          args.window_size_right,
                                          args.mask_type,
                                          args.p_drop,
                                          args.s_randval,
                                          args.drop_seed_offset);
     }
 }();
 
-    dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
-    return ck_tile::make_tuple(kargs, grids);
+    if constexpr(FmhaKernel::kIsGroupMode)
+    {
+        dim3 grids = FmhaKernel::GridSize(
+            args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
+        return ck_tile::make_tuple(kargs, grids);
+    }
+    else
+    {
+        dim3 grids =
+            FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
+        return ck_tile::make_tuple(kargs, grids);
+    }
 }
 template <typename Kernel>
@@ -389,6 +437,10 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
                                          args.nhead_q,
                                          args.nhead_q / args.nhead_k,
                                          args.num_splits,
+                                         args.block_table_ptr,
+                                         args.batch_stride_block_table,
+                                         args.page_block_size,
+                                         args.is_gappy,
                                          args.scale_s,
                                          args.scale_p,
                                          args.stride_q,
@@ -458,8 +510,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
         }
     }();
 
-    dim3 grids =
-        Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits);
+    dim3 grids = Kernel::GridSize(
+        args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits);
     return ck_tile::make_tuple(kargs, grids);
 }
@@ -667,7 +719,6 @@ std::string fmha_fwd_splitkv_get_name_();
 template <ck_tile::index_t HDim_,
           typename DataType_,
           bool kIsGroupMode_,
-          ck_tile::index_t kM0_,
           ck_tile::index_t kN1_,
           bool kStoreLse_,
           bool kDoFp8StaticQuant_,
@@ -678,7 +729,6 @@ struct fmha_fwd_splitkv_combine_traits_
     static constexpr ck_tile::index_t HDim = HDim_;
     using DataType = ck_tile::remove_cvref_t<DataType_>;
     static constexpr bool kIsGroupMode = kIsGroupMode_;
-    static constexpr ck_tile::index_t kM0 = kM0_;
     static constexpr ck_tile::index_t kN1 = kN1_;
     static constexpr bool kStoreLse = kStoreLse_;
     static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
...
@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode,
               std::string k_val,
               std::string k_pad_val,
               ck_tile::index_t seqlen_k_min = 0,
-              bool use_kvcache = false,
+              bool need_append_kvcache = false,
               std::optional<unsigned> seed = std::nullopt)
 {
 #define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode,
     const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
     std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
 
-    if(1 < batch && use_kvcache)
+    if(1 < batch && need_append_kvcache)
     {
         // to keep the original s_k value, we always use seqlen_k_max in first batch
         randints(std::next(seqlen_ks.begin()),
...
@@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
 set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
 
 # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
-list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
+list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)
 
 target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
...
@@ -59,7 +59,7 @@ args:
          -kname     print kernel name or not (default:1)
          -prec_i    input precision (default:fp16)
          -prec_o    output precision, set auto will be the same as input (default:auto)
-         -prec_sx   output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
+         -prec_sm   output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
          -prec_sy   output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto)
          -fadd      fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
          -fquant    fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
@@ -69,7 +69,7 @@ args:
 ```
 
 ## limitations
-Note that `fquant=2`, `fadd=2`, and `prec_sx/prec_sy` other than `fp32` are not generated by default, though our kernel template supports this (TBD: add a flag in generate.py to generate those instances on demand). Besides, the `N>8192` case will by default use the two-pass pipeline, where `-fquant=1/2` is not supported yet. If you need `N>8192` together with `fused+residual+store`, you can use this example together with `12_smoothquant`, constructing layernorm+residual and smoothquant as 2 kernels for this purpose.
+Note that `fquant=2`, `fadd=2`, and `prec_sm/prec_sy` other than `fp32` are not generated by default, though our kernel template supports this (TBD: add a flag in generate.py to generate those instances on demand). Besides, the `N>8192` case will by default use the two-pass pipeline, where `-fquant=1/2` is not supported yet. If you need `N>8192` together with `fused+residual+store`, you can use this example together with `12_smoothquant`, constructing layernorm+residual and smoothquant as 2 kernels for this purpose.
 
 ```
 # some case
...
 # SPDX-License-Identifier: MIT
-# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 # generate kernel instances to speed up compilation
 import argparse
@@ -23,6 +23,10 @@ def get_if_str(idx, total, lase_else = True):
     else:
         return 'else if'
 
+XBIAS_ENUM_STR_MAP = [
+    'no',
+    'xbias']    # pre-norm add bias
+
 FUSED_ADD_ENUM_STR_MAP = [
     'no',
     'pras',     # pre-norm
@@ -35,7 +39,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
 DATA_TYPE_MAP = {'fp32' : 'float',
                  'fp16' : 'ck_tile::fp16_t',
                  'bf16' : 'ck_tile::bf16_t',
-                 'int8' : 'ck_tile::int8_t'}
+                 'int8' : 'ck_tile::int8_t',
+                 'fp8'  : 'ck_tile::fp8_t'}
 
 def BOOL_MAP(b_) -> str:
     if b_:
@@ -48,7 +53,7 @@ class layernorm_fwd_codegen:
 // this is used to pattern-match internl kernel implementation, not to instantiate kernel
 template <typename XDataType_,
           typename YDataType_,
-          typename XScaleDataType_,
+          typename SmoothScaleDataType_,
           typename YScaleDataType_,
           ck_tile::index_t Repeat_M_,         // each thread repeat along M
           ck_tile::index_t Repeat_N_,         // each thread repeat along N
@@ -58,14 +63,16 @@ template <typename XDataType_,
           bool kPadN_,
           bool kSaveMeanInvStd_,
           bool kFastFDiv_,
+          bool kWelford_,
           bool kTwoPass_,
+          ck_tile::index_t kXbias_ = 0,
           ck_tile::index_t kFusedAdd_ = 0,
           ck_tile::index_t kFusedQuant_ = 0>
 struct layernorm2d_fwd_traits_
 {
     using XDataType = ck_tile::remove_cvref_t<XDataType_>;
     using YDataType = ck_tile::remove_cvref_t<YDataType_>;
-    using XScaleDataType = ck_tile::remove_cvref_t<XScaleDataType_>;
+    using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
     using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
 
     static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
@@ -120,14 +127,16 @@ struct layernorm2d_fwd_traits_
     static constexpr bool kPadN = kPadN_;
     static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
     static constexpr bool kFastFDiv = kFastFDiv_;
+    static constexpr bool kWelford = kWelford_;
     static constexpr bool kTwoPass = kTwoPass_;
+    static constexpr ck_tile::index_t kXbias = kXbias_;
     static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
     static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
 };
 
 template <typename XDataType_,
           typename YDataType_,
-          typename XScaleDataType_,
+          typename SmoothScaleDataType_,
           typename YScaleDataType_,
           ck_tile::index_t Repeat_M_,         // each thread repeat along M
           ck_tile::index_t Repeat_N_,         // each thread repeat along N
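The new `kWelford_` switch presumably toggles Welford-style online accumulation of the row mean and variance in the layernorm pipeline, as opposed to a plain sum / sum-of-squares reduction. For reference, a minimal scalar sketch of Welford's update (illustrative only, not the kernel's vectorized implementation):

```
#include <cstdio>

// Welford's online algorithm: single pass, numerically stable mean/variance.
struct Welford
{
    int count = 0;
    float mean = 0.f;
    float m2 = 0.f; // running sum of squared deviations from the current mean

    void push(float x)
    {
        ++count;
        float delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean); // second factor uses the *updated* mean
    }
    float variance() const { return count > 0 ? m2 / count : 0.f; }
};

int main()
{
    Welford w;
    const float xs[] = {1.f, 2.f, 3.f, 4.f};
    for(float x : xs)
        w.push(x);
    std::printf("mean=%f var=%f\n", w.mean, w.variance()); // mean=2.5 var=1.25
}
```

Unlike the naive sum-of-squares formula, this never subtracts two large nearly-equal quantities, which matters when the reduction runs in low precision.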
@@ -137,12 +146,14 @@ template <typename XDataType_,
           bool kPadN_,
           bool kSaveMeanInvStd_,
           bool kFastFDiv_,
+          bool kWelford_,
           bool kTwoPass_,
+          int kXbias_,
           int kFusedAdd_,
           int kFusedQuant_>
 using traits_ = layernorm2d_fwd_traits_<XDataType_,
                                         YDataType_,
-                                        XScaleDataType_,
+                                        SmoothScaleDataType_,
                                         YScaleDataType_,
                                         Repeat_M_,
                                         Repeat_N_,
@@ -152,13 +163,15 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
                                         kPadN_,
                                         kSaveMeanInvStd_,
                                         kFastFDiv_,
+                                        kWelford_,
                                         kTwoPass_,
+                                        kXbias_,
                                         kFusedAdd_,
                                         kFusedQuant_>;
"""

    API_COMMON_HEADER = """
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #include <ck_tile/core.hpp>
 #include "layernorm2d_fwd.hpp"
{{ {{
using XDataType = typename Traits_::XDataType; using XDataType = typename Traits_::XDataType;
using YDataType = typename Traits_::YDataType; using YDataType = typename Traits_::YDataType;
using XScaleDataType = typename Traits_::XScaleDataType; using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
using YScaleDataType = typename Traits_::YScaleDataType; using YScaleDataType = typename Traits_::YScaleDataType;
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType; using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN, using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveMeanInvStd, Traits_::kSaveMeanInvStd,
Traits_::kFastFDiv, Traits_::kFastFDiv,
Traits_::kWelford,
Traits_::kTwoPass, Traits_::kTwoPass,
static_cast<ck_tile::Layernorm2dXBiasEnum>(Traits_::kXbias),
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd), static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>; static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XBiasDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::BetaDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::MeanDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::InvStdDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::MeanDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XScaleDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvStdDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YScaleDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
typename Traits_::Shape, typename Traits_::Shape,
PipelineTraits>; PipelineTraits>;
@@ -204,12 +220,13 @@ float layernorm2d_fwd_(const S& s, A a)
     using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
     using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
 
-    using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
+    using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, true>;
     using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
 
     static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
-    using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, XScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
-        ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
+    static constexpr bool UseRawStore = sizeof(YDataType) == 4;
+    using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
+        ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>;
 
     using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
...@@ -233,7 +250,7 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -233,7 +250,7 @@ float layernorm2d_fwd_(const S& s, A a)
API_BASE = """ API_BASE = """
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp" #include "layernorm2d_fwd.hpp"
...@@ -269,12 +286,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -269,12 +286,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
INSTANCE_BASE = """ INSTANCE_BASE = """
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_api_common.hpp" #include "layernorm2d_fwd_api_common.hpp"
// clang-format off // clang-format off
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep // prec_i prec_o prec_sm prec_sy rm rn tm tn vn pd mv rpcf welford 2p xbias add sweep
{F_instance_def} {F_instance_def}
// clang-format on // clang-format on
...@@ -284,6 +301,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -284,6 +301,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
self.working_path = working_path self.working_path = working_path
self.kernel_filter = kernel_filter self.kernel_filter = kernel_filter
class k_xbias_enum(IntEnum):
F_NO_XBIAS = 0
F_ADD_XBIAS = 1
class k_fused_add_enum(IntEnum): class k_fused_add_enum(IntEnum):
F_NO_ADD = 0 F_NO_ADD = 0
F_PRE_ADD = 1 F_PRE_ADD = 1
...@@ -299,6 +320,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -299,6 +320,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
F_kPadN : bool F_kPadN : bool
F_kSaveMeanInvStd : bool F_kSaveMeanInvStd : bool
F_kTwoPass : bool F_kTwoPass : bool
F_kXbias : Any #: layernorm_fwd_codegen.k_xbias_enum
F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fused_add_enum F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fused_add_enum
F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum
...@@ -315,6 +337,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -315,6 +337,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
@dataclass @dataclass
class k_problem: class k_problem:
F_XDataType : str F_XDataType : str
F_XBiasDataType : str
F_GammaDataType : str F_GammaDataType : str
F_BetaDataType : str F_BetaDataType : str
F_ComputeDataType : str F_ComputeDataType : str
...@@ -352,7 +375,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -352,7 +375,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
class h_traits: class h_traits:
F_XDataType : str F_XDataType : str
F_YDataType : str F_YDataType : str
F_XScaleDataType : str F_SmoothScaleDataType : str
F_YScaleDataType : str F_YScaleDataType : str
F_Repeat_M : int F_Repeat_M : int
F_Repeat_N : int F_Repeat_N : int
...@@ -362,15 +385,17 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -362,15 +385,17 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
F_kPadN : bool F_kPadN : bool
F_kSaveMeanInvStd_ : bool F_kSaveMeanInvStd_ : bool
F_kFastFDiv_ : bool F_kFastFDiv_ : bool
F_kWelford_ : bool
F_kTwoPass_ : bool F_kTwoPass_ : bool
F_kXbias : int
F_kFusedAdd : int F_kFusedAdd : int
F_kFusedQuant : int F_kFusedQuant : int
@property @property
def trait_name(self) ->str: def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}' t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
return t_ return t_
# string when calling this kernel # string when calling this kernel
...@@ -388,6 +413,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -388,6 +413,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
class h_instance: class h_instance:
F_DataTypePair : str F_DataTypePair : str
F_N : str F_N : str
F_xbias : int
F_add : int F_add : int
F_sweep : int F_sweep : int
instance_list : List[Any] # List[h_traits] instance_list : List[Any] # List[h_traits]
...@@ -397,6 +423,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -397,6 +423,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
prec_i, prec_o = self.F_DataTypePair.split(',') prec_i, prec_o = self.F_DataTypePair.split(',')
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}' nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}'
if self.F_xbias != 0:
nnn = nnn + '_' + XBIAS_ENUM_STR_MAP[self.F_xbias]
if self.F_add != 0: if self.F_add != 0:
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
if self.F_sweep != 0: if self.F_sweep != 0:
...@@ -422,11 +450,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -422,11 +450,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
def name_common_header(self) -> str: def name_common_header(self) -> str:
return 'layernorm2d_fwd_api_common' return 'layernorm2d_fwd_api_common'
@property def content_api(self, args) -> str:
def content_api(self) -> str:
# 1 sort based on dtype # 1 sort based on dtype
t_dtype_dict = dict() t_dtype_dict = dict()
blobs = self.get_blobs() blobs = self.get_blobs(args)
for blob in blobs: for blob in blobs:
if blob.F_DataTypePair not in t_dtype_dict: if blob.F_DataTypePair not in t_dtype_dict:
t_dtype_dict[blob.F_DataTypePair] = {} t_dtype_dict[blob.F_DataTypePair] = {}
...@@ -451,19 +478,19 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -451,19 +478,19 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
if ins.F_kFusedQuant == 0: if ins.F_kFusedQuant == 0:
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
elif ins.F_kFusedQuant == 1: elif ins.F_kFusedQuant == 1:
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType) f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType)
elif ins.F_kFusedQuant == 2: elif ins.F_kFusedQuant == 2:
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( _cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, f_vec_n = ins.F_Vector_N, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd,
f_sweep_cond = _sweep_cond) f_sweep_cond = _sweep_cond)
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name) F_VEC_COND = _cond, F_instance_func=ins.call_name)
#inner_str = inner_str + vec_str #inner_str = inner_str + vec_str
n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' n_cnd = f'(a.n <= {n_})' if isinstance(n_, int) else ''
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
prec_i, prec_o = dtype_.split(',') prec_i, prec_o = dtype_.split(',')
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
...@@ -474,77 +501,80 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -474,77 +501,80 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
def content_common_header(self) -> str: def content_common_header(self) -> str:
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
def get_blobs(self): def get_blobs(self, args):
h_traits = layernorm_fwd_codegen.h_traits h_traits = layernorm_fwd_codegen.h_traits
h_instance = layernorm_fwd_codegen.h_instance h_instance = layernorm_fwd_codegen.h_instance
dynamic_quant_out_dtype = ['int8'] dynamic_quant_out_dtype = ['int8', 'fp8']
# some predefined support range # some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict # (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list = [('fp32,fp32')] scale_list = [('fp32,fp32')]
dtype_list = [('fp16,fp16'), ('bf16,bf16'), dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out ('fp16,int8'), ('bf16,int8'),
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 or fp8 out
types_8bit = ('int8', 'fp8')
types_16bit = ('int16', 'fp16', 'bf16')
#fused_add_list = [0, 1, 2] #fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
xbias_list = [0, 1]
fused_add_list = [0, 1] fused_add_list = [0, 1]
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
# rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep
# rm rn tm tn vn pd mv fdiv 2p add sweep h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0),
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)], '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0),
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)], '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)], '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)], '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)], '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0),
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)], '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)], '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)], '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0),
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)], '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)], '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)], '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)], 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0, 0),
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]}
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]}
total_blob = list() total_blob = list()
for hs_key in h_trait_dict: for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key] hs = h_trait_dict[hs_key]
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list):
prec_i, prec_o = dtype.split(',') prec_i, prec_o = dtype.split(',')
scale_x, scale_y = scale_type.split(',') scale_sm, scale_y = scale_type.split(',')
if prec_o in dynamic_quant_out_dtype and fused_quant != 1: if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
continue # skip non dynamic quant case continue # skip non dynamic quant case
if fused_quant == 1 and hs_key == 'big': if fused_quant == 1 and hs_key == 'big':
...@@ -554,20 +584,32 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -554,20 +584,32 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
h_ = copy.copy(chs_) # copy the base instance out h_ = copy.copy(chs_) # copy the base instance out
h_.F_XDataType = prec_i h_.F_XDataType = prec_i
h_.F_YDataType = prec_o h_.F_YDataType = prec_o
h_.F_XScaleDataType = scale_y h_.F_SmoothScaleDataType = scale_sm
h_.F_YScaleDataType = scale_x h_.F_YScaleDataType = scale_y
h_.F_kXbias = xbias
h_.F_kFusedAdd = fused_add h_.F_kFusedAdd = fused_add
h_.F_kFusedQuant = fused_quant h_.F_kFusedQuant = fused_quant
# disable the welford update for 8-bit types and for 16-bit small-N cases
if not h_.F_kTwoPass_:
# disable 16-bit welford when --disable_16b_welford is set
if args.disable_16b_welford and prec_i in types_16bit:
h_.F_kWelford_ = False
# disable 8-bit welford by default
elif prec_i in types_8bit or prec_o in types_8bit:
h_.F_kWelford_ = False
# disable welford for 16-bit types at small N
elif prec_i in types_16bit and hs_key == '64':
h_.F_kWelford_ = False
current_hs.append(h_) # + "\n" current_hs.append(h_) # + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str = 'big' if hs_key == 'big' else current_n current_n_str = 'big' if hs_key == 'big' else current_n
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) total_blob.append(h_instance(dtype, current_n_str, xbias, fused_add, fused_quant, current_hs))
return total_blob return total_blob
def list_blobs(self) -> None: def list_blobs(self, args) -> None:
w_p = Path(self.working_path) w_p = Path(self.working_path)
list_p = w_p / 'layernorm2d_fwd_blobs.txt' list_p = w_p / 'layernorm2d_fwd_blobs.txt'
blobs = self.get_blobs() blobs = self.get_blobs(args)
with list_p.open('w') as list_f: with list_p.open('w') as list_f:
# api related file # api related file
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
...@@ -576,11 +618,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -576,11 +618,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
for b in blobs: for b in blobs:
list_f.write(str(w_p / (b.name + ".cpp")) + "\n") list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
def gen_blobs(self) -> None: def gen_blobs(self, args) -> None:
w_p = Path(self.working_path) w_p = Path(self.working_path)
(w_p / (self.name_api + ".cpp")).write_text(self.content_api) w_str = self.content_api(args)
(w_p / (self.name_api + ".cpp")).write_text(w_str)
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
blobs = self.get_blobs() blobs = self.get_blobs(args)
for b in blobs: for b in blobs:
(w_p / (b.name + ".cpp")).write_text(b.content) (w_p / (b.name + ".cpp")).write_text(b.content)
...@@ -588,14 +631,14 @@ def list_blobs(args): ...@@ -588,14 +631,14 @@ def list_blobs(args):
api_list = args.api.split(',') api_list = args.api.split(',')
for api in api_list: for api in api_list:
if api == 'fwd': if api == 'fwd':
layernorm_fwd_codegen(args.working_path, args.filter).list_blobs() layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args)
def gen_blobs(args): def gen_blobs(args):
api_list = args.api.split(',') api_list = args.api.split(',')
for api in api_list: for api in api_list:
if api == 'fwd': if api == 'fwd':
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs() layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -663,6 +706,13 @@ if __name__ == "__main__": ...@@ -663,6 +706,13 @@ if __name__ == "__main__":
help="codegen receipt." help="codegen receipt."
) )
parser.add_argument(
"--disable_16b_welford",
type=int,
default=0,
required=False,
help="set to 1 to disable the Welford update for 16-bit data types when n > 64"
)
args = parser.parse_args() args = parser.parse_args()
# print(f'{args.list_blobs}-{args.gen_blobs}') # print(f'{args.list_blobs}-{args.gen_blobs}')
......
...@@ -20,6 +20,14 @@ auto get_elimit<ck_tile::bf16_t>() ...@@ -20,6 +20,14 @@ auto get_elimit<ck_tile::bf16_t>()
return ck_tile::make_tuple(rtol, atol); return ck_tile::make_tuple(rtol, atol);
} }
template <>
auto get_elimit<ck_tile::int8_t>()
{
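// int8 outputs are quantized, so one integer step of absolute error (atol = 1.0)
// can come from rounding alone; rtol is relaxed for the same reason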
double rtol = 1e-2;
double atol = 1.0;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
...@@ -35,12 +43,13 @@ auto create_args(int argc, char* argv[]) ...@@ -35,12 +43,13 @@ auto create_args(int argc, char* argv[])
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec_i", "fp16", "input precision") .insert("prec_i", "fp16", "input precision")
.insert("prec_o", "auto", "output precision, set auto will be the same as input") .insert("prec_o", "auto", "output precision, set auto will be the same as input")
.insert("prec_sx", .insert("prec_sm",
"auto", "auto",
"output quant scale type, set auto will use fp32. used when fquant=1") "output quant scale type, set auto will use fp32. used when fquant=1")
.insert("prec_sy", .insert("prec_sy",
"auto", "auto",
"output quant scale type, set auto will use fp32. used when fquant=1 or 2") "output quant scale type, set auto will use fp32. used when fquant=1 or 2")
.insert("xbias", "0", "add bias, 0:no add, 1:add bias before fadd")
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("warmup", "5", "cold iter") .insert("warmup", "5", "cold iter")
...@@ -52,7 +61,7 @@ auto create_args(int argc, char* argv[]) ...@@ -52,7 +61,7 @@ auto create_args(int argc, char* argv[])
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename XScaleDataType, typename SmoothScaleDataType,
typename YScaleDataType, typename YScaleDataType,
bool SaveMeanVar> bool SaveMeanVar>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
...@@ -74,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -74,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
float epsilon = arg_parser.get_float("e"); float epsilon = arg_parser.get_float("e");
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sx = arg_parser.get_str("prec_sx"); std::string prec_sm = arg_parser.get_str("prec_sm");
std::string prec_sy = arg_parser.get_str("prec_sy"); std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto") if(prec_o == "auto")
{ {
prec_o = prec_i; prec_o = prec_i;
} }
if(prec_sx == "auto") if(prec_sm == "auto")
{ {
prec_sx = "fp32"; prec_sm = "fp32";
} }
if(prec_sy == "auto") if(prec_sy == "auto")
{ {
...@@ -93,20 +102,25 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -93,20 +102,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
int xbias = arg_parser.get_int("xbias");
int fused_add = arg_parser.get_int("fadd"); int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant"); int fused_quant = arg_parser.get_int("fquant");
if(fused_quant == 1 && prec_o != "int8") if(fused_quant == 1 && prec_o != "int8" && prec_o != "fp8")
{ {
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl; std::cout
<< "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
<< std::endl;
return false; return false;
} }
assert(x_stride >= n); assert(x_stride >= n);
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>; using TypeConfig =
LayerNormTypeConfig<InDataType, OutDataType, SmoothScaleDataType, YScaleDataType>;
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType; using YDataType = typename TypeConfig::YDataType;
using XBiasDataType = typename TypeConfig::XBiasDataType;
using GammaDataType = typename TypeConfig::GammaDataType; using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType; using BetaDataType = typename TypeConfig::BetaDataType;
using XResidualDataType = XDataType; using XResidualDataType = XDataType;
...@@ -121,6 +135,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -121,6 +135,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XBiasDataType> x_bias_host({n});
ck_tile::HostTensor<GammaDataType> gamma_host({n}); ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<BetaDataType> beta_host({n}); ck_tile::HostTensor<BetaDataType> beta_host({n});
...@@ -135,30 +150,33 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -135,30 +150,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m}); ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m}); ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
ck_tile::HostTensor<XScaleDataType> x_scale_host({n}); ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host({n});
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n}); ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host_dev({n});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host); ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host); ck_tile::FillUniformDistribution<SmoothScaleDataType>{-1.f, 1.f}(sm_scale_host);
ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host); ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_bias_buf(x_bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
x_bias_buf.ToDevice(x_bias_host.data());
gamma_buf.ToDevice(gamma_host.data()); gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data()); beta_buf.ToDevice(beta_host.data());
x_residual_buf.ToDevice(x_residual_host.data()); x_residual_buf.ToDevice(x_residual_host.data());
x_scale_buf.ToDevice(x_scale_host.data()); sm_scale_buf.ToDevice(sm_scale_host.data());
auto prec_str = [&]() { auto prec_str = [&]() {
auto base_str = prec_i; auto base_str = prec_i;
...@@ -179,11 +197,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -179,11 +197,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< ", yr_stride:" << yr_stride << std::flush; << ", yr_stride:" << yr_stride << std::flush;
layernorm2d_fwd_traits traits{ layernorm2d_fwd_traits traits{
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; prec_i, prec_o, prec_sm, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr,
x_bias_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(),
...@@ -210,8 +229,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -210,8 +229,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false; return false;
} }
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XBiasDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; sizeof(GammaDataType) * n + sizeof(BetaDataType) * n +
sizeof(YDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
...@@ -221,6 +241,22 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -221,6 +241,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
// reference // reference
if(xbias != 0)
{
// add bias before fadd
int M = x_host.mDesc.get_lengths()[0];
int N = x_host.mDesc.get_lengths()[1];
for(int idx_m = 0; idx_m < M; ++idx_m)
{
for(int idx_n = 0; idx_n < N; ++idx_n)
{
x_host(idx_m, idx_n) = ck_tile::type_convert<XDataType>(
ck_tile::type_convert<ComputeDataType>(x_host(idx_m, idx_n)) +
ck_tile::type_convert<ComputeDataType>(x_bias_host(idx_n)));
}
}
}
if(fused_add != 0) if(fused_add != 0)
{ {
// fused pre_add/pre_add_store // fused pre_add/pre_add_store
...@@ -254,8 +290,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -254,8 +290,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(int n_ = 0; n_ < N_; n_++) for(int n_ = 0; n_ < N_; n_++)
{ {
// input smooth outlier // input smooth outlier
acc_(m_, n_) = acc_(m_, n_) = acc_(m_, n_) *
acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_)); ck_tile::type_convert<ComputeDataType>(sm_scale_host(n_));
} }
} }
ComputeDataType absmax = static_cast<ComputeDataType>(0); ComputeDataType absmax = static_cast<ComputeDataType>(0);
...@@ -265,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -265,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
absmax = a > absmax ? a : absmax; absmax = a > absmax ? a : absmax;
} }
// printf("cpu:absmax:%f\n", absmax); // printf("cpu:absmax:%f\n", absmax);
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0); constexpr ComputeDataType kMaxY =
std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
: std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
: 0.0;
ComputeDataType y_scale = absmax / kMaxY;
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale); y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++) for(int n_ = 0; n_ < N_; n_++)
{ {
...@@ -308,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -308,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_residual_buf.FromDevice(y_residual_host_dev.data()); y_residual_buf.FromDevice(y_residual_host_dev.data());
} }
auto [rtol, atol] = get_elimit<InDataType>(); auto [rtol, atol] = get_elimit<OutDataType>();
if(x_stride == n) if(x_stride == n)
{ {
...@@ -377,16 +417,16 @@ int main(int argc, char* argv[]) ...@@ -377,16 +417,16 @@ int main(int argc, char* argv[])
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sx = arg_parser.get_str("prec_sx"); std::string prec_sm = arg_parser.get_str("prec_sm");
std::string prec_sy = arg_parser.get_str("prec_sy"); std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto") if(prec_o == "auto")
{ {
prec_o = prec_i; prec_o = prec_i;
} }
if(prec_sx == "auto") if(prec_sm == "auto")
{ {
prec_sx = "fp32"; prec_sm = "fp32";
} }
if(prec_sy == "auto") if(prec_sy == "auto")
{ {
...@@ -395,37 +435,47 @@ int main(int argc, char* argv[]) ...@@ -395,37 +435,47 @@ int main(int argc, char* argv[])
int save_mv = arg_parser.get_int("save_mv"); int save_mv = arg_parser.get_int("save_mv");
// no dynamic quant case // no dynamic quant case
if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv) if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_mv)
{ {
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2; return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv) !save_mv)
{ {
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2; return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
save_mv) save_mv)
{ {
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv) !save_mv)
{ {
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, false>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, false>(arg_parser) ? 0 : -2;
} }
// dynamic quant case, only in inference // dynamic quant case, only in inference
else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv) !save_mv)
{ {
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2; return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv) !save_mv)
{ {
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
return -3; return -3;
} }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -8,35 +8,40 @@ ...@@ -8,35 +8,40 @@
#include "ck_tile/ops/layernorm2d.hpp" #include "ck_tile/ops/layernorm2d.hpp"
#include <string> #include <string>
template <typename InType, typename OutType, typename XScaleDataType_, typename YScaleDataType_> template <typename InType,
typename OutType,
typename SmoothScaleDataType_,
typename YScaleDataType_>
struct LayerNormTypeConfig; struct LayerNormTypeConfig;
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_> template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleDataType_> struct LayerNormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using YDataType = OutType; using YDataType = OutType;
using GammaDataType = ck_tile::half_t; using XBiasDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t; using GammaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t; using BetaDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t; using MeanDataType = ck_tile::half_t;
using ComputeDataType = float; using InvStdDataType = ck_tile::half_t;
using XScaleDataType = XScaleDataType_; using ComputeDataType = float;
using YScaleDataType = YScaleDataType_; using SmoothScaleDataType = SmoothScaleDataType_;
using YScaleDataType = YScaleDataType_;
}; };
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_> template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleDataType_> struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using YDataType = OutType; using YDataType = OutType;
using GammaDataType = ck_tile::bf16_t; using XBiasDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t; using GammaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t; using BetaDataType = ck_tile::bf16_t;
using InvStdDataType = ck_tile::bf16_t; using MeanDataType = ck_tile::bf16_t;
using ComputeDataType = float; using InvStdDataType = ck_tile::bf16_t;
using XScaleDataType = XScaleDataType_; using ComputeDataType = float;
using YScaleDataType = YScaleDataType_; using SmoothScaleDataType = SmoothScaleDataType_;
using YScaleDataType = YScaleDataType_;
}; };
// runtime args // runtime args
...@@ -50,13 +55,14 @@ struct layernorm2d_fwd_traits ...@@ -50,13 +55,14 @@ struct layernorm2d_fwd_traits
std::string prec_i; // input precision std::string prec_i; // input precision
std::string prec_o; // output precision std::string prec_o; // output precision
// if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set // if fused_quant == 1, prec_sm/prec_sy must be set to proper strings; otherwise they
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise // may be arbitrary (the check is skipped). if fused_quant == 2, only prec_sy must be
// can set arbitrary(will skip check) // set; otherwise it may be arbitrary (the check is skipped)
std::string prec_sx; // x-scale, used for [1*N] input smooth quant std::string prec_sm; // smooth-scale, used for [1*N] input smooth quant
std::string prec_sy; // y-scale, used for [M*1] output for next layer std::string prec_sy; // y-scale, used for [M*1] output for next layer
bool save_mean_var; // bool save_mean_var; //
int xbias; // 0:no-bias, 1:add bias
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
}; };
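// Illustrative only: a minimal sketch of filling these traits for the fused
// smooth-dynamic-quant path (fused_quant == 1), mirroring how the example
// runner constructs them; the concrete values below are assumptions, not a
// prescribed configuration.
//   layernorm2d_fwd_traits t{/*prec_i=*/"fp16", /*prec_o=*/"int8",
//                            /*prec_sm=*/"fp32", /*prec_sy=*/"fp32",
//                            /*save_mean_var=*/false,
//                            /*xbias=*/0, /*fused_add=*/0, /*fused_quant=*/1};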
......
#!/bin/sh #!/bin/sh
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)" EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8"; do for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=1 -prec_o=fp8"; do
for pr_i in "fp16" "bf16" ; do for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
...@@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 ...@@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done done
done done
......
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
...@@ -11,9 +11,9 @@ sh ../script/cmake-ck-dev.sh ../ <arch> ...@@ -11,9 +11,9 @@ sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation # The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation # The universal gemm pipeline on the gemm calculation
make tile_example_gemm_mem_pipeline -j make tile_example_gemm_universal -j
``` ```
This will result in an executable `build/bin/tile_example_gemm_basic` This will result in the executables `build/bin/tile_example_gemm_basic` and `build/bin/tile_example_gemm_universal`
## example ## example
``` ```
...@@ -22,6 +22,9 @@ args: ...@@ -22,6 +22,9 @@ args:
-m m dimension (default:1024) -m m dimension (default:1024)
-n n dimension (default:2048) -n n dimension (default:2048)
-k k dimension (default:64) -k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: R)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0) -stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0) -stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0) -stride_c Tensor C stride (default:0)
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
...@@ -9,29 +9,29 @@ ...@@ -9,29 +9,29 @@
#include <string> #include <string>
#include <tuple> #include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "gemm_basic.hpp" #include "gemm_basic.hpp"
template <typename ALayout, typename BLayout, typename CLayout> template <typename ADataType,
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false; constexpr bool kPadM = false;
constexpr bool kPadN = false; constexpr bool kPadN = false;
constexpr bool kPadK = false; constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
// This part comes from the Codegen // This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t N_Warp = 2;
...@@ -39,59 +39,47 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -39,59 +39,47 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 16;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape = using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>, ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using CodegenGemmPipeline = using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>; ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKernelArgs(args);
args.p_b,
args.p_c, const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
...@@ -108,4 +96,46 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -108,4 +96,46 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
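// Illustrative invocation; the binary name comes from the CMake target above,
// while the build path and flag values are assumptions (flags as declared in
// create_args):
//   ./build/bin/tile_example_gemm_universal -m=3840 -n=4096 -k=2048 \
//       -a_layout=R -b_layout=C -c_layout=R -prec=fp16 -split_k=1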
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
...@@ -8,6 +8,27 @@ ...@@ -8,6 +8,27 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
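// A minimal sketch of overriding the pipeline at build time; the macro names are
// taken from the definitions above, while the hipcc invocation itself is an
// assumption about the build setup:
//   hipcc -DCK_TILE_PIPELINE_DEFAULT=CK_TILE_PIPELINE_MEMORY universal_gemm.cpp
// Left undefined, CK_TILE_PIPELINE_DEFAULT falls back to CK_TILE_PIPELINE_COMPUTE,
// which selects GemmPipelineAgBgCrCompV3 with the Intrawave scheduler.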
template <typename DataType> template <typename DataType>
struct GemmBasicTypeConfig; struct GemmBasicTypeConfig;
...@@ -22,6 +43,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t> ...@@ -22,6 +43,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM. // ToDo: Add more bias config to support different categories of GEMM.
}; };
template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
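// [Editor's note] A quick compile-time sanity check (a sketch; assumes the
// specializations above are in scope): the 8-bit input configs widen the
// output to half_t, while every config accumulates in float.
#include <type_traits>

static_assert(std::is_same_v<GemmBasicTypeConfig<ck_tile::fp8_t>::CDataType, ck_tile::half_t>,
              "fp8 inputs emit a half_t C tensor in this example");
static_assert(std::is_same_v<GemmBasicTypeConfig<ck_tile::bf16_t>::AccDataType, float>,
              "accumulation is always performed in float");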
template <typename T>
struct DataTypeTraits;
@@ -43,37 +91,32 @@ struct DataTypeTraits<ck_tile::half_t>
static constexpr const char* name = "fp16"; static constexpr const char* name = "fp16";
}; };
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};

template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};

template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("b", "1", "batch size") arg_parser.insert("m", "3840", "m dimension")
.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension") .insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default") .insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride") .insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride") .insert("stride_b", "0", "Tensor B stride")
@@ -82,11 +125,12 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include <algorithm>   // std::max, std::max_element
#include <type_traits> // std::conditional_t, std::is_same_v

template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
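// [Editor's note] is_row_major returns a ck_tile::bool_constant rather than a
// plain bool, so its result can steer compile-time branches. A small sanity
// check (a sketch; assumes ck_tile::bool_constant exposes a ::value member
// like std::bool_constant):
static_assert(decltype(is_row_major(ck_tile::tensor_layout::gemm::RowMajor{}))::value,
              "RowMajor should report as row-major");
static_assert(!decltype(is_row_major(ck_tile::tensor_layout::gemm::ColumnMajor{}))::value,
              "ColumnMajor should not report as row-major");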
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
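// [Editor's note] calculate_rtol_atol keeps the looser of two bounds: one for
// accumulating K/kbatch products inside each split, and one for folding kbatch
// partial results together in CDataType. A hypothetical call (a sketch; all
// numeric values below are made up for illustration):
inline void rtol_atol_usage_sketch()
{
    const auto thresholds = calculate_rtol_atol<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t>(
        /*K=*/2048, /*kbatch=*/4, /*max_accumulated_value=*/150.f);
    const auto rtol = thresholds.at(ck_tile::number<0>{});
    const auto atol = thresholds.at(ck_tile::number<1>{});
    (void)rtol; // larger K/kbatch or a larger max value loosens both thresholds
    (void)atol;
}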
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
@@ -16,11 +45,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
@@ -28,26 +57,31 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
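// [Editor's note] Unit check: flop / 1.E9 is GFLOPs, and dividing by ave_time
// in milliseconds gives GFLOP/ms, which equals TFLOP/s; likewise
// num_byte / 1.E6 / ms is MB/ms, i.e. GB/s.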
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name
<< " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl; << std::endl;
return ave_time;
}
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
@@ -58,6 +92,11 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
@@ -66,55 +105,22 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));

ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
// TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
@@ -127,7 +133,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType, BDataType, AccDataType, CDataType,
ALayout, BLayout, CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
@@ -136,7 +143,7 @@ int run_gemm_example_with_layouts(int argc,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
@@ -146,72 +153,84 @@ int run_gemm_example_with_layouts(int argc,
if(arg_parser.get_int("v") == 1) if(arg_parser.get_int("v") == 1)
{ {
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
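// [Editor's note] reference_gemm_gpu below consumes raw device pointers, so
// the DeviceMem contents are first staged into the freshly allocated d_A/d_B
// buffers (device-to-device copies), and d_C is copied back into the
// DeviceMem wrapper afterwards.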
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
M * K * sizeof(ADataType),
hipMemcpyDeviceToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
N * K * sizeof(BDataType),
hipMemcpyDeviceToDevice));
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
M * N * sizeof(CDataType),
hipMemcpyDeviceToDevice));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
#!/bin/sh
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do
$EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done
#!/bin/sh
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do
$EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done
\ No newline at end of file
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done