Commit 994989c0 authored by valarLip

add draft int4 naive pa

parent a8c5bd9b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>

namespace ck_tile {

// two signed 4-bit ints packed into one byte
struct int4x2_t
{
    uint8_t raw;
    CK_TILE_HOST_DEVICE constexpr int4x2_t() : raw{uint8_t{}} {}
    // CK_TILE_HOST_DEVICE constexpr int4x2_t(uint8_t init)
    //     : raw{((init & 0x0f) << 4) | (init & 0x0f)}
    // {
    // }
};

CK_TILE_HOST_DEVICE
constexpr fp32x2_t int4x2_to_floatx2(const int4x2_t& x)
{
    auto x_u8 = x.raw;
    // naive implementation: unpack both nibbles, sign-extend from 4 bits
    float x_h = ((x_u8 & 0xf0) >> 4);
    if(x_h >= 8)
    {
        x_h -= 16;
    }
    float x_l = (x_u8 & 0x0f);
    if(x_l >= 8)
    {
        x_l -= 16;
    }
    return {x_h, x_l}; // .x <- high nibble, .y <- low nibble
}

CK_TILE_HOST_DEVICE
constexpr int4x2_t floatx2_to_int4x2(const fp32x2_t& x)
{
    // naive implementation: .x goes to the high nibble, .y to the low nibble
    int4x2_t res;
    auto x_h = static_cast<int8_t>(x.x);
    auto x_l = static_cast<int8_t>(x.y);
    res.raw  = ((x_h & 0x0f) << 4) | (x_l & 0x0f);
    return res;
}
} // namespace ck_tile
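
The new header (presumably ck_tile/core/numeric/int4.hpp, given the include added in the next hunk) packs two signed 4-bit values into one byte: .x maps to the high nibble, .y to the low nibble, and a nibble in 8..15 is sign-extended by subtracting 16. A minimal standalone sketch of that convention, with no ck_tile dependency and illustrative names (not part of the commit):

#include <cassert>
#include <cstdint>
#include <utility>

// pack two signed 4-bit values (each in [-8, 7]) into one byte:
// 'hi' goes to the high nibble, 'lo' to the low nibble
static uint8_t pack_int4x2(int hi, int lo)
{
    return static_cast<uint8_t>(((hi & 0x0f) << 4) | (lo & 0x0f));
}

// unpack a byte back into two signed 4-bit values;
// a nibble >= 8 represents a negative number, so subtract 16 to sign-extend
static std::pair<int, int> unpack_int4x2(uint8_t raw)
{
    int hi = (raw & 0xf0) >> 4;
    int lo = raw & 0x0f;
    if(hi >= 8) hi -= 16;
    if(lo >= 8) lo -= 16;
    return {hi, lo};
}

int main()
{
    for(int hi = -8; hi <= 7; hi++)
        for(int lo = -8; lo <= 7; lo++)
        {
            auto [h, l] = unpack_int4x2(pack_int4x2(hi, lo));
            assert(h == hi && l == lo); // the round trip is lossless over the int4 range
        }
    return 0;
}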
@@ -11,6 +11,7 @@
 #include "ck_tile/core/numeric/bfloat16.hpp"
 #include "ck_tile/core/numeric/float8.hpp"
 #include "ck_tile/core/numeric/int8.hpp"
+#include "ck_tile/core/numeric/int4.hpp"

 namespace ck_tile {
@@ -64,6 +65,9 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
 CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
 CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
+CK_TILE_TYPE_CONVERT(fp32x2_t, floatx2, int4x2_t, int4x2)
+CK_TILE_TYPE_CONVERT(int4x2_t, int4x2, fp32x2_t, floatx2)

 #undef CK_TILE_TYPE_CONVERT
 #endif
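
Assuming CK_TILE_TYPE_CONVERT wires these entries into the usual type_convert<Dst>(src) overloads (the same call pattern the kernel below already uses for int8/fp8), converting a packed pair would look roughly like the sketch below; the umbrella header and function name are illustrative, not part of the commit.

#include "ck_tile/core.hpp" // assumed umbrella header that pulls in type_convert and int4x2_t

void int4_convert_example()
{
    using namespace ck_tile;
    int4x2_t packed{};                              // two 4-bit values in one byte
    fp32x2_t pair = type_convert<fp32x2_t>(packed); // high nibble -> .x, low nibble -> .y
    int4x2_t back = type_convert<int4x2_t>(pair);   // narrow the float pair back to packed int4
    (void)back;
}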
...
@@ -42,6 +42,8 @@ enum class naive_attention_quant_algo
     // FP8/INT8 quant for KVCache, per-token quant
     // [num_tokens, nhead, hdim] -> [nhead, num_tokens]
     KV_8BIT_PERTOKEN = 2,
+    // same as the 8-bit per-token quant, but 4-bit
+    KV_4BIT_PERTOKEN = 3,
 };

 // TODO: for simplicity, this will be used as host/device arg
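
KV_4BIT_PERTOKEN follows the same recipe as KV_8BIT_PERTOKEN: each (token, head) row of K/V carries one float scale derived from its absolute maximum, and the elements are quantized symmetrically, here into the signed 4-bit range and stored two per byte. A standalone host-side sketch of how such a row might be prepared (illustrative names, not the library's actual quantization routine; the nibble ordering along hdim is an assumption):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// per-token symmetric quantization of one K/V row of length hdim into packed int4,
// returning the per-token dequant scale (same idea as the 8-bit per-token path,
// but the integer range is [-8, 7] and two values share one byte)
static float quantize_row_int4(const std::vector<float>& row, std::vector<uint8_t>& packed)
{
    float absmax = 0.f;
    for(float x : row)
        absmax = std::max(absmax, std::fabs(x));
    float scale = absmax / 7.f;   // 7 = largest positive int4 value
    if(scale == 0.f) scale = 1.f; // avoid div-by-zero for an all-zero row

    packed.assign((row.size() + 1) / 2, 0);
    for(std::size_t i = 0; i < row.size(); i++)
    {
        int q       = static_cast<int>(std::lround(row[i] / scale));
        q           = std::clamp(q, -8, 7);
        uint8_t nib = static_cast<uint8_t>(q & 0x0f);
        // even index -> high nibble, odd index -> low nibble (one possible layout;
        // how the kernel orders nibbles along hdim is not shown in this commit)
        packed[i / 2] |= (i % 2 == 0) ? static_cast<uint8_t>(nib << 4) : nib;
    }
    return scale; // dequant: x ~= q * scale
}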
@@ -100,7 +102,8 @@ template <typename QType,
           typename KType,
           typename VType,
           typename OType,
-          typename AccType,
+          typename AccType_I, // i.e. input of mfma
+          typename AccType_O, // i.e. result of mfma
           typename KVScaleType,
           naive_attention_layout_enum QLayout,
           naive_attention_layout_enum KLayout,
@@ -111,20 +114,15 @@ template <typename QType,
           typename Traits>
 struct naive_attention_fwd_kernel
 {
-    static constexpr bool is_kvcache_i8 =
-        std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
-    static constexpr bool is_kvcache_fp8 =
-        std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
-    static constexpr int v_per_token_quant_group_size = 64;
     // TODO: hardcode
     using SoftmaxType = float;      // always using float to do softmax compute
     using QuantComputeType = float; // used for quant/dequant scale compute
-    using QCompute = KType; // src A of gemm1, same type as K
-    using PType    = VType; // src A of gemm2, same type as V
-    using OAccType = float; // always float, in case int8 FA
+    using QCompute = AccType_I; // src A of gemm1, may differ from K (e.g. for int4 we currently use i8 GEMM)
+    using PType    = AccType_I; // src A of gemm2, may differ from V (e.g. for int4 we currently use i8 GEMM)
+    using TailType = float;     // always float, in case of int8 FA
+    static constexpr int gemm1_vec_size = min(16 / sizeof(QCompute), 16 / sizeof(KType));
+    using q_vec_type = ext_vector_t<QCompute, gemm1_vec_size>;
     using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>;
     static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size;
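
gemm1_vec_size caps the per-thread load of gemm-1 at 16 bytes (a dwordx4) for both the quantized compute type and the raw K type. A small compile-time restatement of that arithmetic (standalone sketch; the int4 case assumes KType is the packed int4x2_t, which is not shown in these hunks):

#include <cstddef>

// illustrative recomputation of gemm1_vec_size = min(16 / sizeof(QCompute), 16 / sizeof(KType))
constexpr std::size_t gemm1_vec_size_for(std::size_t q_compute_bytes, std::size_t k_bytes)
{
    std::size_t a = 16 / q_compute_bytes;
    std::size_t b = 16 / k_bytes;
    return a < b ? a : b;
}
static_assert(gemm1_vec_size_for(2, 2) == 8, "fp16 QCompute, fp16 K   -> 8 elements per load");
static_assert(gemm1_vec_size_for(1, 2) == 8, "int8 QCompute, fp16 K   -> 8 elements per load");
static_assert(gemm1_vec_size_for(1, 1) == 16, "int8 QCompute, int8 K   -> 16 elements per load");
static_assert(gemm1_vec_size_for(1, 1) == 16, "int8 QCompute, int4x2 K -> 16 packed pairs per load");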
@@ -167,6 +165,12 @@ struct naive_attention_fwd_kernel
         __device__ void init(int i_b, int i_h) { base_ptr = get_base(i_b, i_h); }
         __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
         __device__ void store(T value, int i_s, int i_d) { base_ptr[get_offset(i_s, i_d)] = value; }
+        template <int vec_size>
+        __device__ ext_vector_t<T, vec_size> load_vector(int i_s, int i_d)
+        {
+            return reinterpret_cast<ext_vector_t<T, vec_size>*>(base_ptr +
+                                                                get_offset(i_s, i_d * vec_size))[0];
+        }
     };

     template <typename T, naive_attention_layout_enum Layout>
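
The new load_vector helper turns vec_size consecutive elements into a single wide load by reinterpreting the base pointer as a pointer to an extended vector type. The same pattern in isolation (plain Clang/HIP C++; floatx4 and the names below are illustrative stand-ins for ck_tile's ext_vector_t):

#include <cstdio>

// a 4-wide float vector; ck_tile's ext_vector_t is built on the same Clang extension
typedef float floatx4 __attribute__((ext_vector_type(4)));

// load 4 contiguous floats starting at element index i * 4 with one vector access
static floatx4 load_vector4(const float* base, int i)
{
    return reinterpret_cast<const floatx4*>(base + i * 4)[0];
}

int main()
{
    alignas(16) float data[8] = {0, 1, 2, 3, 4, 5, 6, 7}; // keep the vector loads aligned
    floatx4 v = load_vector4(data, 1);                    // second group of four -> {4, 5, 6, 7}
    std::printf("%g %g %g %g\n", v.x, v.y, v.z, v.w);
    return 0;
}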
@@ -225,6 +229,12 @@ struct naive_attention_fwd_kernel
         __device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
         __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
         __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
+        template <int vec_size>
+        __device__ ext_vector_t<T, vec_size> load_vector(int i_s, int i_d)
+        {
+            return reinterpret_cast<ext_vector_t<T, vec_size>*>(base_ptr +
+                                                                get_offset(i_s, i_d * vec_size))[0];
+        }
     };

     template <typename T, naive_attention_layout_enum Layout>
@@ -416,8 +426,8 @@ struct naive_attention_fwd_kernel
         SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
         SoftmaxType l{0};
-        // AccType o_acc = {0};
-        OAccType o_acc = {0};
+        // AccType_O o_acc = {0};
+        TailType o_acc = {0};
         int sk_loops = (seqlen_kv + wg_size - 1) / wg_size;
         QuantComputeType q_dequant_scale = .0f;
@@ -428,21 +438,21 @@ struct naive_attention_fwd_kernel
         if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
         {
-            // AccType is i32 now, seqlen_q = 1, hdim up to 256
-            AccType q = 0;
-            AccType k_s = 0;
+            // AccType_O is i32 now, seqlen_q = 1, hdim up to 256
+            AccType_O q = 0;
+            AccType_O k_s = 0;
             if(static_cast<int>(threadIdx.x) < args.hdim)
             {
-                q = type_convert<AccType>(q_addr.load(0, threadIdx.x));
-                k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
+                q = type_convert<AccType_O>(q_addr.load(0, threadIdx.x));
+                k_s = type_convert<AccType_O>(kscale_addr.load(i_hk, threadIdx.x, 0));
             }
             // 1) we apply the k scale to q
-            AccType q_forwarded = q * k_s;
+            AccType_O q_forwarded = q * k_s;
             // 2) apply smooth-quant
             // find absmax
-            AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
-            qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
+            AccType_O qf_max = wave_reduce(q_forwarded, f_absmax_f32);
+            qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType_O*>(smem));
             // per-token scale
             q_dequant_scale = type_convert<QuantComputeType>(qf_max) / scale_max<QCompute>::value;
@@ -493,6 +503,40 @@ struct naive_attention_fwd_kernel
                 // 2) per-token scale q_dequant_scale, to be mul after 1st gemm
             }
         }
+        else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_4BIT_PERTOKEN)
+        {
+            // currently the same as KV_8BIT_PERTOKEN, since we use 8-bit mfma
+            if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
+            {
+                // dynamic quant of q here
+                float q = 0;
+                if(static_cast<int>(threadIdx.x) < args.hdim)
+                {
+                    q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
+                }
+                // apply smooth-quant
+                // find absmax
+                float q_max = wave_reduce(q, f_absmax_f32);
+                q_max       = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));
+                // per-token scale
+                q_dequant_scale =
+                    type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;
+                // divide by scale
+                q = q / q_dequant_scale;
+                QCompute quantized_q = type_convert<QCompute>(q);
+                __syncthreads();
+                reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
+                __syncthreads();
+                // after the above process, we have 2 pieces of data:
+                // 1) quantized q data stored in smem (no need to reload from global)
+                // 2) per-token scale q_dequant_scale, to be multiplied after the 1st gemm
+            }
+        }

         for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
         {
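
The KV_4BIT_PERTOKEN branch added above is the per-token smooth-quant of q restated: reduce the absolute maximum over hdim, derive one dequant scale per token, divide, and keep the narrowed values in shared memory for gemm-1. A host-side reference of the same arithmetic (standalone sketch with illustrative names, not part of the commit; the device code uses wave/cross-wave reductions and LDS instead of this serial loop, and 127 is assumed for scale_max of the int8 compute type):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// quantize one q row (length hdim) to int8 and return the per-token dequant scale;
// this mirrors q_dequant_scale = absmax / scale_max<QCompute>::value with QCompute = int8
static float quantize_q_row_int8(const std::vector<float>& q, std::vector<int8_t>& q_i8)
{
    float absmax = 0.f;
    for(float x : q)
        absmax = std::max(absmax, std::fabs(x));
    float dequant_scale = absmax / 127.f; // 127 assumed as scale_max for int8
    if(dequant_scale == 0.f) dequant_scale = 1.f;

    q_i8.resize(q.size());
    for(std::size_t i = 0; i < q.size(); i++)
        q_i8[i] = static_cast<int8_t>(
            std::clamp<long>(std::lround(q[i] / dequant_scale), -127, 127));
    return dequant_scale; // multiplied back onto the gemm-1 result later
}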
@@ -501,23 +545,33 @@ struct naive_attention_fwd_kernel
             SoftmaxType s_softmax = -numeric<SoftmaxType>::infinity();
             if(i_sk < seqlen_kv)
             {
-                AccType s_acc{0}; // clear for every loop
-                for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
+                AccType_O s_acc{0}; // clear for every loop
+                int gemm_1_loop = args.hdim / gemm1_vec_size;
+                for(auto i_loop = 0; i_loop < gemm_1_loop; i_loop++)
                 {
                     auto q = [&]() {
                         if constexpr(Traits::quant_algo ==
                                          naive_attention_quant_algo::KV_8BIT_PERHEAD ||
                                      Traits::quant_algo ==
-                                         naive_attention_quant_algo::KV_8BIT_PERTOKEN)
+                                         naive_attention_quant_algo::KV_8BIT_PERTOKEN ||
+                                     Traits::quant_algo ==
+                                         naive_attention_quant_algo::KV_4BIT_PERTOKEN)
                         {
-                            return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
+                            return reinterpret_cast<q_vec_type*>(smem_quant_q)[i_loop];
                         }
                         else
-                            return q_addr.load(i_sq, i_dq); // q will have duplicate load
-                    }();
-                    auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
-                    s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
+                        {
+                            return q_addr.template load_vector<gemm1_vec_size>(i_sq, i_loop);
+                        }
+                    }();
+                    auto k = [&]() {
+                        return k_addr.template load_vector<gemm1_vec_size>(i_sk, i_loop);
+                    }();
+
+                    for(int i = 0; i < gemm1_vec_size; i++)
+                    {
+                        s_acc += type_convert<AccType_O>(q[i]) * type_convert<AccType_O>(k[i]);
+                    }
                 }
                 // scale
                 s_softmax = type_convert<SoftmaxType>(s_acc);
@@ -528,7 +582,9 @@ struct naive_attention_fwd_kernel
                     s_softmax *= q_dequant_scale; // post scale the per-token factor
                 }
                 else if constexpr(Traits::quant_algo ==
-                                      naive_attention_quant_algo::KV_8BIT_PERTOKEN)
+                                      naive_attention_quant_algo::KV_8BIT_PERTOKEN ||
+                                  Traits::quant_algo ==
+                                      naive_attention_quant_algo::KV_4BIT_PERTOKEN)
                 {
                     SoftmaxType k_per_token_scale =
                         type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
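
Once both operands are integer, the gemm-1 result is brought back to the float domain by multiplying the two per-token scales onto it, which is what q_dequant_scale and the k_per_token_scale loaded here are for. A tiny numeric sketch of that identity (standalone, made-up values, not part of the commit):

#include <cstdio>

int main()
{
    // "real" values and their per-token dequant scales (as if produced by absmax / qmax)
    float q_real[2] = {0.50f, -0.25f}, k_real[2] = {0.40f, 0.60f};
    float q_scale = 0.005f, k_scale = 0.01f;

    // quantized integer operands and their integer dot product
    int q_i[2] = {100, -50}; // round(q_real / q_scale)
    int k_i[2] = {40, 60};   // round(k_real / k_scale)
    int s_int  = q_i[0] * k_i[0] + q_i[1] * k_i[1];

    // dequant: s ~= s_int * q_scale * k_scale, compare against the float dot product
    float s_deq = s_int * q_scale * k_scale;
    float s_ref = q_real[0] * k_real[0] + q_real[1] * k_real[1];
    std::printf("dequantized %f vs reference %f\n", s_deq, s_ref); // both 0.05
    return 0;
}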
@@ -556,10 +612,10 @@ struct naive_attention_fwd_kernel
             // l, pre-scale o_acc
             SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
             l = tmp * l + row_sum;
-            o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
+            o_acc = type_convert<TailType>(type_convert<SoftmaxType>(o_acc) * tmp);

-            // prepare the p_compute into smem, to let every thread read same p_compute and do
-            // 2nd gemm
+            // prepare the p_compute into smem, to let every thread read same p_compute and
+            // do 2nd gemm
             if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
             {
                 QuantComputeType v_s = 0;
@@ -631,7 +687,7 @@ struct naive_attention_fwd_kernel
             // gemm-2, simple loop over vector by vector
             constexpr int gemm_2_loop = wg_size / p_vec_elem;
             {
-                AccType o_acc_local = {0};
+                AccType_O o_acc_local = {0};
                 int sk_start = i_loop1 * wg_size; // we start from the first seqlen_kv element
                 for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
                 {
@@ -648,29 +704,29 @@ struct naive_attention_fwd_kernel
                             v = v_addr.load(i_sv, i_dv);
                         }
-                        AccType v_compute = [&]() { return type_convert<AccType>(v); }();
-                        o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
+                        AccType_O v_compute = [&]() { return type_convert<AccType_O>(v); }();
+                        o_acc_local += type_convert<AccType_O>(p_vec[i_j]) * v_compute;
                     }
                 }
-                OAccType post_scale_o_acc_local = [&]() {
+                TailType post_scale_o_acc_local = [&]() {
                     if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
                     {
                         // apply pr scale to local acc
-                        return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
+                        return type_convert<TailType>(type_convert<QuantComputeType>(o_acc_local) *
                                                       p_dequant_scale);
                     }
                     else if constexpr(Traits::quant_algo ==
                                       naive_attention_quant_algo::KV_8BIT_PERTOKEN)
                     {
                         // apply pr scale to local acc
-                        return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
+                        return type_convert<TailType>(type_convert<QuantComputeType>(o_acc_local) *
                                                       p_dequant_scale);
                     }
                     else
                     {
-                        return type_convert<OAccType>(o_acc_local);
+                        return type_convert<TailType>(o_acc_local);
                     }
                 }();
                 o_acc += post_scale_o_acc_local;
@@ -680,7 +736,7 @@ struct naive_attention_fwd_kernel
         // post scale o_acc
         {
            SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking
-           o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
+           o_acc = type_convert<TailType>(type_convert<SoftmaxType>(o_acc) * tmp);
         }

         // store O
@@ -698,7 +754,8 @@
                 k_type_, \
                 v_type_, \
                 o_type_, \
-                acc_type_, \
+                acc_type_i_, \
+                acc_type_o_, \
                 kvscale_type_, \
                 q_layout_, \
                 k_layout_, \
@@ -764,7 +821,8 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
         using k_type_ = fp16_t;
         using v_type_ = fp16_t;
         using o_type_ = fp16_t;
-        using acc_type_ = float;
+        using acc_type_i_ = fp16_t;
+        using acc_type_o_ = float;
         using kvscale_type_ = float;
         constexpr int quant_algo_ = 0;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
@@ -776,7 +834,8 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
         using k_type_ = bf16_t;
        using v_type_ = bf16_t;
         using o_type_ = bf16_t;
-        using acc_type_ = float;
+        using acc_type_i_ = fp16_t;
+        using acc_type_o_ = float;
         using kvscale_type_ = float;
         constexpr int quant_algo_ = 0;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
@@ -788,7 +847,8 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
         using k_type_ = fp8_t;
         using v_type_ = fp8_t;
         using o_type_ = bf16_t;
-        using acc_type_ = float; // NOTE!
+        using acc_type_i_ = fp8_t;
+        using acc_type_o_ = float;
         using kvscale_type_ = float;
         constexpr int quant_algo_ = 2;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
@@ -800,7 +860,8 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
         using k_type_ = fp8_t;
         using v_type_ = fp8_t;
         using o_type_ = fp16_t;
-        using acc_type_ = float; // NOTE!
+        using acc_type_i_ = fp8_t;
+        using acc_type_o_ = float;
         using kvscale_type_ = float;
         constexpr int quant_algo_ = 2;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
@@ -812,7 +873,8 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
         using k_type_ = int8_t;
         using v_type_ = int8_t;
         using o_type_ = bf16_t;
-        using acc_type_ = int32_t; // NOTE!
+        using acc_type_i_ = int8_t;
+        using acc_type_o_ = int32_t;
         using kvscale_type_ = float;
         constexpr int quant_algo_ = 2;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
...