Unverified Commit 6df5fe2a authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

[CK_TILE]naive attn support FP8 KVCache quant (#1747)



* quant

* fix bug

* simple smoothquant after softmax

* update kv-quant

* update stride

* fix fp8-pertoken-kvcache

* update int8/fp8 quant support

---------

Co-authored-by: so <a.com>
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
parent 4f62f6e9
...@@ -1131,15 +1131,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1131,15 +1131,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
// NOTE: use gpu to do validation // NOTE: use gpu to do validation
ck_tile::naive_attention_fwd_traits naive_t; ck_tile::naive_attention_fwd_traits naive_t;
naive_t.q_type = data_type; naive_t.q_type = data_type;
naive_t.k_type = data_type; naive_t.k_type = data_type;
naive_t.v_type = data_type; naive_t.v_type = data_type;
naive_t.o_type = data_type; naive_t.o_type = data_type;
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
naive_t.variation = 0; // TODO? naive_t.variation = 0; // TODO?
naive_t.quant_algo = 0;
ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
......
...@@ -13,13 +13,18 @@ namespace ck_tile { ...@@ -13,13 +13,18 @@ namespace ck_tile {
enum class naive_attention_layout_enum enum class naive_attention_layout_enum
{ {
BSHD, // [batch, seqlen, nhead, hdim] DEFAULT, // maybe this tensor is not used, set some irrelevant value
BHSD, // [batch, nhead, seqlen, hdim] BSHD, // [batch, seqlen, nhead, hdim]
BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed BHSD, // [batch, nhead, seqlen, hdim]
PHSD, // [pages, nhead, page_size, hdim] BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed
PHSD, // [pages, nhead, page_size, hdim]
// PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen // PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
PHDS, // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen PHDS, // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen
// scale layout used for dynamic dequant
SCALE_HS, // [nhead, tokens] or [nhead, tokens-per-group], nhe KVCache quant
SCALE_SH, // [tokens, nhead]
}; };
// will used to specialize kernel variation // will used to specialize kernel variation
...@@ -30,6 +35,15 @@ enum class naive_attention_variation_enum ...@@ -30,6 +35,15 @@ enum class naive_attention_variation_enum
DECODE_PAGED, // decode attn, where kv token from another buffer called kvcache DECODE_PAGED, // decode attn, where kv token from another buffer called kvcache
}; };
enum class naive_attention_quant_algo
{
NO = 0,
KV_8BIT_PERHEAD = 1,
// FP8/INT8 quant for KVCache, per-token quant
// [num_tokens, nhead, hdim] -> [nhead, num_tokens]
KV_8BIT_PERTOKEN = 2,
};
// TODO: for simplicity, this will be used as host/device arg // TODO: for simplicity, this will be used as host/device arg
struct naive_attention_fwd_args struct naive_attention_fwd_args
{ {
...@@ -40,7 +54,8 @@ struct naive_attention_fwd_args ...@@ -40,7 +54,8 @@ struct naive_attention_fwd_args
void* context_len_ptr; // [batch] used when seqlen kv come from a pointer(each element is a void* context_len_ptr; // [batch] used when seqlen kv come from a pointer(each element is a
// number, not cumsum) // number, not cumsum)
void* page_table_ptr; // [batch, max_pages_per_seq] seqlen_kv is in different block(paged attn) void* page_table_ptr; // [batch, max_pages_per_seq] seqlen_kv is in different block(paged attn)
void* kvscale_ptr; // [nhead, 2(kv), hdim] used for kvcache dequant void* kscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
void* vscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
float scale_s; float scale_s;
int hdim; int hdim;
int hdim_v; // could be cross-attn, where V and Q/K hdim are different int hdim_v; // could be cross-attn, where V and Q/K hdim are different
...@@ -54,6 +69,7 @@ struct naive_attention_fwd_args ...@@ -54,6 +69,7 @@ struct naive_attention_fwd_args
int nhead_ratio_kv; // nhead_q / nhead_kv int nhead_ratio_kv; // nhead_q / nhead_kv
int page_size; // if paged, the seqlen-kv per each block int page_size; // if paged, the seqlen-kv per each block
int max_pages_per_seq; int max_pages_per_seq;
int max_kv_tokens; // used as stride to access kv scale ptr
}; };
// this is trait for host API // this is trait for host API
...@@ -67,14 +83,16 @@ struct naive_attention_fwd_traits ...@@ -67,14 +83,16 @@ struct naive_attention_fwd_traits
std::string k_layout; std::string k_layout;
std::string v_layout; std::string v_layout;
std::string o_layout; std::string o_layout;
int variation; // sync with naive_attention_variation_enum int variation; // sync with naive_attention_variation_enum
int quant_algo; // sync with naive_attention_quant_algo
}; };
// this is trait for kernel template // this is trait for kernel template
template <naive_attention_variation_enum variation_> template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
struct naive_attention_fwd_kernel_traits struct naive_attention_fwd_kernel_traits
{ {
static constexpr naive_attention_variation_enum variation = variation_; static constexpr naive_attention_variation_enum variation = variation_;
static constexpr naive_attention_quant_algo quant_algo = quant_algo_;
}; };
// for simplicity, please do not use const-reference type for the template type // for simplicity, please do not use const-reference type for the template type
...@@ -83,28 +101,39 @@ template <typename QType, ...@@ -83,28 +101,39 @@ template <typename QType,
typename VType, typename VType,
typename OType, typename OType,
typename AccType, typename AccType,
typename KVScaleType,
naive_attention_layout_enum QLayout, naive_attention_layout_enum QLayout,
naive_attention_layout_enum KLayout, naive_attention_layout_enum KLayout,
naive_attention_layout_enum VLayout, naive_attention_layout_enum VLayout,
naive_attention_layout_enum OLayout, naive_attention_layout_enum OLayout,
naive_attention_layout_enum KScaleLayout,
naive_attention_layout_enum VScaleLayout,
typename Traits> typename Traits>
struct naive_attention_fwd_kernel struct naive_attention_fwd_kernel
{ {
static constexpr bool is_kvcache_i8 = static constexpr bool is_kvcache_i8 =
std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t> && sizeof(QType) != 1; std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
static constexpr bool is_kvcache_fp8 =
std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
// kvcache-i8 will have per head scale, we apply this scale to Q/P matrix instead of original static constexpr int v_per_token_quant_group_size = 64;
// K/V matrix. This can speed up conversion since Q/P usually is fp16/bf16/fp32
static constexpr bool is_kvcache_i8_forward_quant = is_kvcache_i8;
// TODO: hardcode // TODO: hardcode
using KVScaleType = float; using SoftmaxType = float; // always using float to do softmax compute
using SoftmaxType = float; using QuantComputeType = float; // used for quant/dequant scale compute
using PType = VType; // src A of gemm2, same type as V using QCompute = KType; // src A of gemm1, same type as K
using PType = VType; // src A of gemm2, same type as V
using OAccType = float; // always float, in case int8 FA
using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>; using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>;
static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size; static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size;
// clang-format off
template <typename T_> struct scale_max { static constexpr float value = 1; /* dummy code */ };
template <> struct scale_max<int8_t> { static constexpr float value = 127.0; };
template <> struct scale_max<fp8_t> { static constexpr float value = 240.0; };
// clang-format on
__host__ __device__ naive_attention_fwd_kernel() {} __host__ __device__ naive_attention_fwd_kernel() {}
template <typename T, naive_attention_layout_enum Layout> template <typename T, naive_attention_layout_enum Layout>
...@@ -198,24 +227,31 @@ struct naive_attention_fwd_kernel ...@@ -198,24 +227,31 @@ struct naive_attention_fwd_kernel
__device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {} __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
}; };
template <typename T> template <typename T, naive_attention_layout_enum Layout>
struct kvscale_addresser struct kvscale_addresser
{ {
int h, d; // nhead, hdim int s, h, d; // seqlen(tokens), nhead, hdim
T* base_ptr; T* base_ptr;
__device__ kvscale_addresser(int h_, int d_, void* p_) __device__ kvscale_addresser(int s_, int h_, int d_, void* p_)
: h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_)) : s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
{ {
} }
__device__ int get_offset(int i_h, int i_d, int i_kv /*0 or 1*/) __device__ int get_offset(int i_s, int i_h, int i_d)
{ {
if constexpr(Layout == naive_attention_layout_enum::SCALE_HS)
{
// [nhead, tokens]
(void)i_d;
return i_h * s + i_s;
}
else if constexpr(Layout == naive_attention_layout_enum::DEFAULT)
{
return 0;
}
// [h, 2, d] // [h, 2, d]
return i_h * 2 * d + i_kv * d + i_d; // return i_h * 2 * d + i_kv * d + i_d;
}
__device__ T load(int i_h, int i_d, int i_kv)
{
return base_ptr[get_offset(i_h, i_d, i_kv)];
} }
__device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
}; };
__device__ __host__ static constexpr int get_block_size() { return 256; } __device__ __host__ static constexpr int get_block_size() { return 256; }
...@@ -282,12 +318,13 @@ struct naive_attention_fwd_kernel ...@@ -282,12 +318,13 @@ struct naive_attention_fwd_kernel
__device__ void operator()(naive_attention_fwd_args args) __device__ void operator()(naive_attention_fwd_args args)
{ {
constexpr int wg_size = get_block_size(); constexpr int wg_size = get_block_size();
__shared__ char smem[wg_size * 4 * sizeof(float)]; // should enough __shared__ char smem[wg_size * 4 * sizeof(float)]; // should enough
int i_dv = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // second half, should enough
int i_sq = blockIdx.y; // index of seqlen_q int i_dv = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
int i_batch = blockIdx.z; // index of batch_q * nhead_q int i_sq = blockIdx.y; // index of seqlen_q
int i_bq = i_batch / args.nhead_q; // index of batch_q int i_batch = blockIdx.z; // index of batch_q * nhead_q
int i_hq = i_batch % args.nhead_q; // index of nhead_q int i_bq = i_batch / args.nhead_q; // index of batch_q
int i_hq = i_batch % args.nhead_q; // index of nhead_q
int i_bk = i_bq / args.batch_ratio_kv; int i_bk = i_bq / args.batch_ratio_kv;
int i_hk = i_hq / args.nhead_ratio_kv; int i_hk = i_hq / args.nhead_ratio_kv;
...@@ -360,9 +397,10 @@ struct naive_attention_fwd_kernel ...@@ -360,9 +397,10 @@ struct naive_attention_fwd_kernel
auto f_max = [](auto x_, auto y_) { return max(x_, y_); }; auto f_max = [](auto x_, auto y_) { return max(x_, y_); };
auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
auto f_absmax_f32 = [](float v_0_, float v_1_) { auto f_absmax_f32 = [](float v_0_, float v_1_) {
float rtn; // float rtn;
asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_)); // asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
return rtn; // return rtn;
return max(abs(v_0_), abs(v_1_));
}; };
int seqlen_kv = [&]() { int seqlen_kv = [&]() {
...@@ -378,45 +416,82 @@ struct naive_attention_fwd_kernel ...@@ -378,45 +416,82 @@ struct naive_attention_fwd_kernel
SoftmaxType row_max = -numeric<SoftmaxType>::infinity(); SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
SoftmaxType l{0}; SoftmaxType l{0};
AccType o_acc = {0}; // AccType o_acc = {0};
OAccType o_acc = {0};
int sk_loops = (seqlen_kv + wg_size - 1) / wg_size; int sk_loops = (seqlen_kv + wg_size - 1) / wg_size;
float qf_scale = .0f; QuantComputeType q_dequant_scale = .0f;
kvscale_addresser<KVScaleType> kvscale_addr{args.nhead_kv, args.hdim, args.kvscale_ptr}; kvscale_addresser<KVScaleType, KScaleLayout> kscale_addr{
args.max_kv_tokens, args.nhead_kv, args.hdim, args.kscale_ptr};
kvscale_addresser<KVScaleType, VScaleLayout> vscale_addr{
args.max_kv_tokens, args.nhead_kv, args.hdim_v, args.vscale_ptr};
if constexpr(is_kvcache_i8_forward_quant) if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{ {
// AccType is i32 now, seqlen_q = 1, hdim up to 256 // AccType is i32 now, seqlen_q = 1, hdim up to 256
float q = 0; AccType q = 0;
float k_s = 0; AccType k_s = 0;
if(static_cast<int>(threadIdx.x) < args.hdim) if(static_cast<int>(threadIdx.x) < args.hdim)
{ {
q = type_convert<float>(q_addr.load(0, threadIdx.x)); q = type_convert<AccType>(q_addr.load(0, threadIdx.x));
k_s = type_convert<float>(kvscale_addr.load(i_hk, threadIdx.x, 0)); k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
} }
// 1) we apply the k scale to q // 1) we apply the k scale to q
float q_forwarded = q * k_s; AccType q_forwarded = q * k_s;
// 2) apply smooth-quant // 2) apply smooth-quant
// find absmax // find absmax
float qf_max = wave_reduce(q_forwarded, f_absmax_f32); AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<float*>(smem)); qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
// per-token scale // per-token scale
qf_scale = qf_max / 127.0; q_dequant_scale = type_convert<QuantComputeType>(qf_max) / scale_max<QCompute>::value;
// devide by scale // devide by scale
q = q / qf_scale; q = q / q_dequant_scale;
// fp32->i8 // fp32->i8
int8_t quantized_q = static_cast<int8_t>(q); QCompute quantized_q = static_cast<QCompute>(q);
__syncthreads(); __syncthreads();
reinterpret_cast<int8_t*>(smem)[threadIdx.x] = quantized_q; reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
__syncthreads(); __syncthreads();
// after above process, we have 2 data // after above process, we have 2 data
// 1) int8 q data stored in smem(no need to reload) // 1) int8 q data stored in smem(no need to reload)
// 2) per-token scale qf_scale, to be mul after 1st gemm // 2) per-token scale q_dequant_scale, to be mul after 1st gemm
}
else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
{
// dyanmic quant q here
float q = 0;
if(static_cast<int>(threadIdx.x) < args.hdim)
{
q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
}
// apply smooth-quant
// find absmax
float q_max = wave_reduce(q, f_absmax_f32);
q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));
// per-token scale
q_dequant_scale =
type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;
// devide by scale
q = q / q_dequant_scale;
QCompute quantized_q = type_convert<QCompute>(q);
__syncthreads();
reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
__syncthreads();
// after above process, we have 2 data
// 1) fp8 q data stored in smem(no need to reload from global)
// 2) per-token scale q_dequant_scale, to be mul after 1st gemm
}
} }
for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++) for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
...@@ -429,33 +504,41 @@ struct naive_attention_fwd_kernel ...@@ -429,33 +504,41 @@ struct naive_attention_fwd_kernel
AccType s_acc{0}; // clear for every loop AccType s_acc{0}; // clear for every loop
for(auto i_dq = 0; i_dq < args.hdim; i_dq++) for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
{ {
if constexpr(is_kvcache_i8_forward_quant) auto q = [&]() {
{ if constexpr(Traits::quant_algo ==
int8_t q = reinterpret_cast<int8_t*>(smem)[i_dq]; naive_attention_quant_algo::KV_8BIT_PERHEAD ||
auto k = k_addr.load(i_sk, i_dq); Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
s_acc += type_convert<AccType>(q) * type_convert<AccType>(k); {
} return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
else }
{ else
auto q = q_addr.load(i_sq, i_dq); // q will have duplicate load return q_addr.load(i_sq, i_dq); // q will have duplicate load
auto k = k_addr.load(i_sk, i_dq); }();
auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
s_acc += type_convert<AccType>(q) * type_convert<AccType>(k); s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
}
} }
// scale // scale
s_softmax = type_convert<SoftmaxType>(s_acc); s_softmax = type_convert<SoftmaxType>(s_acc);
s_softmax *= s_softmax *=
type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>); type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
if constexpr(is_kvcache_i8_forward_quant) if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{
s_softmax *= q_dequant_scale; // post scale the per-token factor
}
else if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{ {
s_softmax *= qf_scale; // post scale the per-token factor SoftmaxType k_per_token_scale =
type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
s_softmax *= q_dequant_scale;
s_softmax *= k_per_token_scale;
} }
} }
// s->p // s->p
float pf_scale = 0.; // used for i8 quant QuantComputeType p_dequant_scale = 1.;
{ {
// softmax, find max // softmax, find max
SoftmaxType old_max = row_max; SoftmaxType old_max = row_max;
...@@ -473,41 +556,69 @@ struct naive_attention_fwd_kernel ...@@ -473,41 +556,69 @@ struct naive_attention_fwd_kernel
// l, pre-scall o_acc // l, pre-scall o_acc
SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max); SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
l = tmp * l + row_sum; l = tmp * l + row_sum;
o_acc = type_convert<AccType>(type_convert<SoftmaxType>(o_acc) * tmp); o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
// prepare the p_compute into smem, to let every thread read same p_compute and do // prepare the p_compute into smem, to let every thread read same p_compute and do
// 2nd gemm // 2nd gemm
if constexpr(is_kvcache_i8_forward_quant) if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{ {
float v_s = 0; QuantComputeType v_s = 0;
if(static_cast<int>(threadIdx.x) < args.hdim_v) if(static_cast<int>(threadIdx.x) < args.hdim_v)
{ {
v_s = type_convert<float>(kvscale_addr.load(i_hk, threadIdx.x, 1)); v_s =
type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
} }
// 1) we apply the v scale to p // 1) we apply the v scale to p
float p_forwarded = p_compute * v_s; QuantComputeType p_forwarded = p_compute * v_s;
// 2) apply smooth-quant // 2) apply smooth-quant
// find absmax // find absmax
float pf_max = wave_reduce(p_forwarded, f_absmax_f32); QuantComputeType pf_max = wave_reduce(p_forwarded, f_absmax_f32);
pf_max = pf_max = cross_wave_reduce(
cross_wave_reduce(pf_max, f_absmax_f32, reinterpret_cast<float*>(smem)); pf_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
// per-token scale // per-token scale
pf_scale = pf_max / 127.0; p_dequant_scale = pf_max / scale_max<PType>::value; // 127.0;
// devide by scale // devide by scale
p_compute = p_compute / pf_scale; p_compute = p_compute / p_dequant_scale;
// fp32->i8 // fp32->i8
int8_t quantized_p = static_cast<int8_t>(p_compute); PType quantized_p = static_cast<PType>(p_compute);
__syncthreads(); __syncthreads();
reinterpret_cast<int8_t*>(smem)[threadIdx.x] = quantized_p; reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
__syncthreads(); __syncthreads();
// after above process, we have 2 data // after above process, we have 2 data
// 1) int8 p data stored in smem(no need to reload) // 1) int8 p data stored in smem(no need to reload)
// 2) per-token scale pf_scale, to be mul after 2nd gemm // 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
}
else if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
// forward apply the v scale to p_compute, this is compute friendly
auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
p_compute *= v_scale;
// smooth-quant
// find absmax
QuantComputeType p_max = wave_reduce(p_compute, f_absmax_f32);
p_max = cross_wave_reduce(
p_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
// per-token scale
p_dequant_scale = p_max / scale_max<PType>::value; // 240.0;
// devide by scale
p_compute = p_compute / p_dequant_scale;
// fp32->i8
PType quantized_p = type_convert<PType>(p_compute);
__syncthreads();
reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
__syncthreads();
// after above process, we have 2 data
// 1) fp8_t p data stored in smem(no need to reload)
// 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
} }
else else
{ {
...@@ -531,29 +642,45 @@ struct naive_attention_fwd_kernel ...@@ -531,29 +642,45 @@ struct naive_attention_fwd_kernel
int sv_offset = i_loop2 * p_vec_elem + i_j; int sv_offset = i_loop2 * p_vec_elem + i_j;
int i_sv = sk_start + sv_offset; int i_sv = sk_start + sv_offset;
VType v = 0.f; VType v = 0;
if(i_dv < args.hdim_v && i_sv < seqlen_kv) if(i_dv < args.hdim_v && i_sv < seqlen_kv)
{ {
v = v_addr.load(i_sv, i_dv); v = v_addr.load(i_sv, i_dv);
} }
o_acc_local += type_convert<AccType>(p_vec[i_j]) * type_convert<AccType>(v); AccType v_compute = [&]() { return type_convert<AccType>(v); }();
o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
} }
} }
if constexpr(is_kvcache_i8_forward_quant)
{ OAccType post_scale_o_acc_local = [&]() {
// apply pr scale to local acc if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
o_acc_local = {
type_convert<AccType>(type_convert<float>(o_acc_local) * pf_scale); // apply pr scale to local acc
} return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
o_acc += o_acc_local; p_dequant_scale);
}
else if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
// apply pr scale to local acc
return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
p_dequant_scale);
}
else
{
return type_convert<OAccType>(o_acc_local);
}
}();
o_acc += post_scale_o_acc_local;
} }
} }
// post scale o_acc // post scale o_acc
{ {
SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking
o_acc = type_convert<AccType>(type_convert<SoftmaxType>(o_acc) * tmp); o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
} }
// store O // store O
...@@ -564,18 +691,21 @@ struct naive_attention_fwd_kernel ...@@ -564,18 +691,21 @@ struct naive_attention_fwd_kernel
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_() \ #define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_() \
{ \ { \
using ktraits_ = \ using ktraits_ = naive_attention_fwd_kernel_traits< \
naive_attention_fwd_kernel_traits<static_cast<naive_attention_variation_enum>( \ static_cast<naive_attention_variation_enum>(variation_), \
variation_)>; \ static_cast<naive_attention_quant_algo>(quant_algo_)>; \
using k_ = naive_attention_fwd_kernel<q_type_, \ using k_ = naive_attention_fwd_kernel<q_type_, \
k_type_, \ k_type_, \
v_type_, \ v_type_, \
o_type_, \ o_type_, \
acc_type_, \ acc_type_, \
kvscale_type_, \
q_layout_, \ q_layout_, \
k_layout_, \ k_layout_, \
v_layout_, \ v_layout_, \
o_layout_, \ o_layout_, \
k_scale_layout_, \
v_scale_layout_, \
ktraits_>; \ ktraits_>; \
dim3 grids = k_::get_grid_size(a); \ dim3 grids = k_::get_grid_size(a); \
r = ck_tile::launch_kernel(s, \ r = ck_tile::launch_kernel(s, \
...@@ -586,31 +716,37 @@ struct naive_attention_fwd_kernel ...@@ -586,31 +716,37 @@ struct naive_attention_fwd_kernel
if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \ if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
t.o_layout == "bshd") \ t.o_layout == "bshd") \
{ \ { \
constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \ constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \ constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \ constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \ constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \
constexpr int variation_ = 0; \ constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr int variation_ = 0; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \ CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
} \ } \
else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" && \ else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" && \
t.v_layout == "bhsd" && t.o_layout == "bhsd") \ t.v_layout == "bhsd" && t.o_layout == "bhsd") \
{ \ { \
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \ constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \ constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \ constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \ constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
constexpr int variation_ = 0; \ constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr int variation_ = 0; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \ CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
} \ } \
else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" && \ else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" && \
t.v_layout == "phds" && t.o_layout == "bhsd") \ t.v_layout == "phds" && t.o_layout == "bhsd") \
{ \ { \
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \ constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \ constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \
constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \ constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \ constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
constexpr int variation_ = 2; \ constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
constexpr int variation_ = 2; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \ CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
} }
...@@ -621,40 +757,64 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t, ...@@ -621,40 +757,64 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
{ {
float r = -1; float r = -1;
// TODO: do not explicitly create too much instance! // TODO: do not explicitly create too much instance!
if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16") if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16" &&
t.quant_algo == 0)
{
using q_type_ = fp16_t;
using k_type_ = fp16_t;
using v_type_ = fp16_t;
using o_type_ = fp16_t;
using acc_type_ = float;
using kvscale_type_ = float;
constexpr int quant_algo_ = 0;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16" &&
t.quant_algo == 0)
{ {
using q_type_ = fp16_t; using q_type_ = bf16_t;
using k_type_ = fp16_t; using k_type_ = bf16_t;
using v_type_ = fp16_t; using v_type_ = bf16_t;
using o_type_ = fp16_t; using o_type_ = bf16_t;
using acc_type_ = float; using acc_type_ = float;
using kvscale_type_ = float;
constexpr int quant_algo_ = 0;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_(); CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
} }
else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16") else if(t.q_type == "bf16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "bf16" &&
t.quant_algo == 2)
{ {
using q_type_ = bf16_t; using q_type_ = bf16_t;
using k_type_ = bf16_t; using k_type_ = fp8_t;
using v_type_ = bf16_t; using v_type_ = fp8_t;
using o_type_ = bf16_t; using o_type_ = bf16_t;
using acc_type_ = float; using acc_type_ = float; // NOTE!
using kvscale_type_ = float;
constexpr int quant_algo_ = 2;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_(); CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
} }
else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16") else if(t.q_type == "fp16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "fp16" &&
t.quant_algo == 2)
{ {
using q_type_ = bf16_t; using q_type_ = fp16_t;
using k_type_ = int8_t; using k_type_ = fp8_t;
using v_type_ = int8_t; using v_type_ = fp8_t;
using o_type_ = bf16_t; using o_type_ = fp16_t;
using acc_type_ = int32_t; // NOTE! using acc_type_ = float; // NOTE!
using kvscale_type_ = float;
constexpr int quant_algo_ = 2;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_(); CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
} }
else if(t.q_type == "fp16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "fp16") else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16" &&
t.quant_algo == 2)
{ {
using q_type_ = fp16_t; using q_type_ = bf16_t;
using k_type_ = int8_t; using k_type_ = int8_t;
using v_type_ = int8_t; using v_type_ = int8_t;
using o_type_ = fp16_t; using o_type_ = bf16_t;
using acc_type_ = int32_t; // NOTE! using acc_type_ = int32_t; // NOTE!
using kvscale_type_ = float;
constexpr int quant_algo_ = 2;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_(); CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
} }
return r; return r;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment