gaoqiong / composable_kernel_ROCM · Commits

Commit 994989c0
authored Feb 10, 2025 by valarLip
parent a8c5bd9b

add draft int4 naive pa
Showing 3 changed files with 186 additions and 59 deletions:

    include/ck_tile/core/numeric/int4.hpp            +61    -0
    include/ck_tile/core/numeric/type_convert.hpp     +4    -0
    include/ck_tile/ref/naive_attention.hpp          +121   -59
include/ck_tile/core/numeric/int4.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>

#pragma once

namespace ck_tile {

// 8 bit int4
struct int4x2_t
{
    uint8_t raw;

    CK_TILE_HOST_DEVICE constexpr int4x2_t() : raw{uint8_t{}} {}
    // CK_TILE_HOST_DEVICE constexpr int4x2_t(uint8_t init) : raw{((init & 0x0f) << 4) | (init & 0x0f)}
    // {
    // }
};

CK_TILE_HOST_DEVICE constexpr fp32x2_t int4x2_to_floatx2(const int4x2_t& x)
{
    auto x_u8 = x.raw;
    // naive implement
    float x_h = ((x_u8 & 0xf0) >> 4);
    if(x_h >= 8)
    {
        x_h -= 16;
    }
    float x_l = (x_u8 & 0x0f);
    if(x_l >= 8)
    {
        x_l -= 16;
    }
    return {x_h, x_l};
}

CK_TILE_HOST_DEVICE constexpr int4x2_t floatx2_to_int4x2(const fp32x2_t& x)
{
    // naive implement
    int4x2_t res;
    auto x_l = static_cast<int8_t>(x.x);
    auto x_h = static_cast<int8_t>(x.y);
    res.raw  = (x_l << 4) | (x_h & 0x0F);
    return res;
}

} // namespace ck_tile
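For reference, the packing convention above is: the high nibble of raw holds the first value and the low nibble the second, each a signed 4-bit integer in [-8, 7]. Note that in floatx2_to_int4x2 the variable named x_l (taken from x.x) actually lands in the high nibble, which still round-trips consistently with int4x2_to_floatx2. A standalone host-only sketch of that round trip (plain uint8_t/float, no ck_tile types; not part of this commit):

    #include <cassert>
    #include <cstdint>
    #include <utility>

    // Unpack one byte into (high, low) signed 4-bit values, mirroring int4x2_to_floatx2.
    std::pair<float, float> unpack_int4x2(uint8_t raw)
    {
        float hi = static_cast<float>((raw & 0xf0) >> 4);
        if(hi >= 8) hi -= 16; // sign-extend the high nibble
        float lo = static_cast<float>(raw & 0x0f);
        if(lo >= 8) lo -= 16; // sign-extend the low nibble
        return {hi, lo};
    }

    // Pack (first, second) back into one byte, mirroring floatx2_to_int4x2:
    // first -> high nibble, second -> low nibble. The nibble is masked before
    // shifting to keep the sketch well-defined for negative values.
    uint8_t pack_int4x2(float first, float second)
    {
        const int hi = static_cast<int8_t>(first);
        const int lo = static_cast<int8_t>(second);
        return static_cast<uint8_t>(((hi & 0x0f) << 4) | (lo & 0x0f));
    }

    int main()
    {
        for(int hi = -8; hi <= 7; ++hi)
            for(int lo = -8; lo <= 7; ++lo)
            {
                const uint8_t raw = pack_int4x2(static_cast<float>(hi), static_cast<float>(lo));
                const auto [h, l] = unpack_int4x2(raw);
                assert(h == hi && l == lo); // round trip preserves both values
            }
        return 0;
    }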
include/ck_tile/core/numeric/type_convert.hpp

@@ -11,6 +11,7 @@
 #include "ck_tile/core/numeric/bfloat16.hpp"
 #include "ck_tile/core/numeric/float8.hpp"
 #include "ck_tile/core/numeric/int8.hpp"
+#include "ck_tile/core/numeric/int4.hpp"
 namespace ck_tile {

@@ -64,6 +65,9 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
 CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
 CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
+CK_TILE_TYPE_CONVERT(fp32x2_t, floatx2, int4x2_t, int4x2)
+CK_TILE_TYPE_CONVERT(int4x2_t, int4x2, fp32x2_t, floatx2)
 #undef CK_TILE_TYPE_CONVERT
 #endif
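The two new registrations presumably wire these conversions into ck_tile's type_convert the same way the existing float <-> int8 entries do (the CK_TILE_TYPE_CONVERT macro body is outside this hunk, so the exact expansion is an assumption here). A minimal usage sketch under that assumption, using the names from int4.hpp above:

    #include "ck_tile/core/numeric/type_convert.hpp"

    // Assumed behaviour: type_convert<fp32x2_t>(int4x2_t) and type_convert<int4x2_t>(fp32x2_t)
    // dispatch to int4x2_to_floatx2 / floatx2_to_int4x2 registered above.
    ck_tile::fp32x2_t dequantize_pair(ck_tile::int4x2_t packed)
    {
        // returns {high nibble, low nibble} as signed floats in [-8, 7]
        return ck_tile::type_convert<ck_tile::fp32x2_t>(packed);
    }

    ck_tile::int4x2_t requantize_pair(const ck_tile::fp32x2_t& pair)
    {
        // .x goes to the high nibble, .y to the low nibble
        return ck_tile::type_convert<ck_tile::int4x2_t>(pair);
    }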
include/ck_tile/ref/naive_attention.hpp

@@ -42,6 +42,8 @@ enum class naive_attention_quant_algo
     // FP8/INT8 quant for KVCache, per-token quant
     // [num_tokens, nhead, hdim] -> [nhead, num_tokens]
     KV_8BIT_PERTOKEN = 2,
+    // same as 8bit per token quant but 4 bit
+    KV_4BIT_PERTOKEN = 3,
 };
 // TODO: for simplicity, this will be used as host/device arg

@@ -100,7 +102,8 @@ template <typename QType,
           typename KType,
           typename VType,
           typename OType,
-          typename AccType,
+          typename AccType_I, // i.e. input of mfma
+          typename AccType_O, // i.e. results of mfma
           typename KVScaleType,
           naive_attention_layout_enum QLayout,
           naive_attention_layout_enum KLayout,
@@ -111,20 +114,15 @@ template <typename QType,
           typename Traits>
 struct naive_attention_fwd_kernel
 {
-    static constexpr bool is_kvcache_i8  = std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
-    static constexpr bool is_kvcache_fp8 = std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
-    static constexpr int v_per_token_quant_group_size = 64; // TODO: hardcode
     using SoftmaxType      = float; // always using float to do softmax compute
     using QuantComputeType = float; // used for quant/dequant scale compute
-    using QCompute = KType;     // src A of gemm1, same type as K
+    using QCompute = AccType_I; // src A of gemm1, may differ from K, e.g. for int4 we use i8 GEMM now
-    using PType    = VType;     // src A of gemm2, same type as V
+    using PType    = AccType_I; // src A of gemm2, may differ from V, e.g. for int4 we use i8 GEMM now
-    using OAccType = float; // always float, in case int8 FA
+    using TailType = float; // always float, in case int8 FA
+    static constexpr int gemm1_vec_size = min(16 / sizeof(QCompute), 16 / sizeof(KType));
+    using q_vec_type = ext_vector_t<QCompute, gemm1_vec_size>;
     using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>;
     static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size;
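The new gemm1_vec_size picks the widest per-thread vector that keeps both the quantized Q operand and K within a single 16-byte load. A standalone sketch of that arithmetic (std::min and stand-in element types here; the kernel uses ck_tile's min and its own QCompute/KType):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    // Same arithmetic as gemm1_vec_size = min(16 / sizeof(QCompute), 16 / sizeof(KType)).
    template <typename QCompute, typename KType>
    constexpr std::size_t gemm1_vec_size_of()
    {
        return std::min(16 / sizeof(QCompute), 16 / sizeof(KType));
    }

    // 1-byte elements on both sides (the i8 GEMM path): 16 elements per 16-byte load.
    static_assert(gemm1_vec_size_of<int8_t, int8_t>() == 16, "i8 path");
    // 2-byte elements (int16_t stands in for fp16_t): 8 elements per load.
    static_assert(gemm1_vec_size_of<int16_t, int16_t>() == 8, "16-bit path");
    // Mixed widths take the smaller vector so both loads stay within 16 bytes.
    static_assert(gemm1_vec_size_of<int8_t, int16_t>() == 8, "mixed path");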
@@ -167,6 +165,12 @@ struct naive_attention_fwd_kernel
     __device__ void init(int i_b, int i_h) { base_ptr = get_base(i_b, i_h); }
     __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
     __device__ void store(T value, int i_s, int i_d) { base_ptr[get_offset(i_s, i_d)] = value; }
+    template <int vec_size>
+    __device__ ext_vector_t<T, vec_size> load_vector(int i_s, int i_d)
+    {
+        return reinterpret_cast<ext_vector_t<T, vec_size>*>(base_ptr + get_offset(i_s, i_d * vec_size))[0];
+    }
 };
 template <typename T, naive_attention_layout_enum Layout>
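For orientation, load_vector<vec_size>(i_s, i_d) above returns the vec_size contiguous elements starting at element index i_d * vec_size of row i_s, fetched as one vector instead of vec_size scalar load() calls (the reinterpret_cast relies on that dimension being contiguous for the given Layout). A standalone reference of the same indexing, using a simple row-major stand-in for get_offset:

    #include <array>
    #include <cstring>

    // Reference semantics for load_vector<N>(i_s, i_d) on a row-major view:
    // it returns base[i_s*stride + i_d*N .. i_s*stride + i_d*N + N-1], i.e. the
    // same data N scalar loads would return, in a single access. The kernel
    // reinterprets the span as an ext_vector_t<T, N> rather than copying.
    template <typename T, int N>
    std::array<T, N> load_vector_ref(const T* base, int stride, int i_s, int i_d)
    {
        std::array<T, N> v{};
        std::memcpy(v.data(), base + i_s * stride + i_d * N, sizeof(T) * N);
        return v;
    }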
@@ -225,6 +229,12 @@ struct naive_attention_fwd_kernel
     __device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
     __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
     __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
+    template <int vec_size>
+    __device__ ext_vector_t<T, vec_size> load_vector(int i_s, int i_d)
+    {
+        return reinterpret_cast<ext_vector_t<T, vec_size>*>(base_ptr + get_offset(i_s, i_d * vec_size))[0];
+    }
 };
 template <typename T, naive_attention_layout_enum Layout>

@@ -416,8 +426,8 @@ struct naive_attention_fwd_kernel
         SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
         SoftmaxType l{0};
-        // AccType o_acc = {0};
-        OAccType o_acc = {0};
+        // AccType_O o_acc = {0};
+        TailType o_acc = {0};
         int sk_loops = (seqlen_kv + wg_size - 1) / wg_size;
         QuantComputeType q_dequant_scale = .0f;

@@ -428,21 +438,21 @@ struct naive_attention_fwd_kernel
         if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
         {
-            // AccType is i32 now, seqlen_q = 1, hdim up to 256
-            AccType q   = 0;
-            AccType k_s = 0;
+            // AccType_O is i32 now, seqlen_q = 1, hdim up to 256
+            AccType_O q   = 0;
+            AccType_O k_s = 0;
             if(static_cast<int>(threadIdx.x) < args.hdim)
             {
-                q   = type_convert<AccType>(q_addr.load(0, threadIdx.x));
-                k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
+                q   = type_convert<AccType_O>(q_addr.load(0, threadIdx.x));
+                k_s = type_convert<AccType_O>(kscale_addr.load(i_hk, threadIdx.x, 0));
             }
             // 1) we apply the k scale to q
-            AccType q_forwarded = q * k_s;
+            AccType_O q_forwarded = q * k_s;
             // 2) apply smooth-quant
             // find absmax
-            AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
-            qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
+            AccType_O qf_max = wave_reduce(q_forwarded, f_absmax_f32);
+            qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType_O*>(smem));
             // per-token scale
             q_dequant_scale = type_convert<QuantComputeType>(qf_max) / scale_max<QCompute>::value;
@@ -493,6 +503,40 @@ struct naive_attention_fwd_kernel
                 // 2) per-token scale q_dequant_scale, to be mul after 1st gemm
             }
         }
+        else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_4BIT_PERTOKEN)
+        {
+            // currently same as KV_8BIT_PERTOKEN, as we use 8bit mfma
+            if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
+            {
+                // dynamic quant q here
+                float q = 0;
+                if(static_cast<int>(threadIdx.x) < args.hdim)
+                {
+                    q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
+                }
+                // apply smooth-quant
+                // find absmax
+                float q_max = wave_reduce(q, f_absmax_f32);
+                q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));
+                // per-token scale
+                q_dequant_scale = type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;
+                // divide by scale
+                q = q / q_dequant_scale;
+                QCompute quantized_q = type_convert<QCompute>(q);
+                __syncthreads();
+                reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
+                __syncthreads();
+                // after above process, we have 2 data
+                // 1) fp8 q data stored in smem(no need to reload from global)
+                // 2) per-token scale q_dequant_scale, to be mul after 1st gemm
+            }
+        }
         for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
         {
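The block above quantizes each query row on the fly: find the absolute maximum over the row, derive a per-token scale against the target type's representable maximum, divide by that scale, and convert. A host-side reference of the same math (plain float to int8, with 127 standing in for scale_max<QCompute>::value; the kernel does the reduction with wave_reduce/cross_wave_reduce and writes the quantized row to shared memory):

    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Reference per-token dynamic quantization, mirroring the new KV_4BIT_PERTOKEN
    // Q path (which currently quantizes Q to the 8-bit QCompute type for i8 mfma).
    // Returns the per-token dequant scale; `out` receives the quantized row.
    float quantize_row_per_token(const std::vector<float>& row, std::vector<int8_t>& out)
    {
        float absmax = 0.f;
        for(float v : row)
            absmax = std::fmax(absmax, std::fabs(v)); // "find absmax"

        const float qmax          = 127.f;                              // stand-in for scale_max<QCompute>::value
        const float dequant_scale = absmax > 0.f ? absmax / qmax : 1.f; // guard added only for this standalone sketch

        out.resize(row.size());
        for(std::size_t i = 0; i < row.size(); ++i)
            out[i] = static_cast<int8_t>(row[i] / dequant_scale); // truncating here; ck_tile's type_convert may round

        return dequant_scale; // multiplied back in after the first GEMM
    }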
@@ -501,23 +545,33 @@ struct naive_attention_fwd_kernel
             SoftmaxType s_softmax = -numeric<SoftmaxType>::infinity();
             if(i_sk < seqlen_kv)
             {
-                AccType s_acc{0}; // clear for every loop
-                for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
+                AccType_O s_acc{0}; // clear for every loop
+                int gemm_1_loop = args.hdim / gemm1_vec_size;
+                for(auto i_loop = 0; i_loop < gemm_1_loop; i_loop++)
                 {
                     auto q = [&]() {
                         if constexpr(Traits::quant_algo ==
                                          naive_attention_quant_algo::KV_8BIT_PERHEAD ||
                                      Traits::quant_algo ==
-                                         naive_attention_quant_algo::KV_8BIT_PERTOKEN)
+                                         naive_attention_quant_algo::KV_8BIT_PERTOKEN ||
+                                     Traits::quant_algo == naive_attention_quant_algo::KV_4BIT_PERTOKEN)
                         {
-                            return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
+                            return reinterpret_cast<q_vec_type*>(smem_quant_q)[i_loop];
                         }
                         else
-                            return q_addr.load(i_sq, i_dq); // q will have duplicate load
+                        {
+                            return q_addr.template load_vector<gemm1_vec_size>(i_sq, i_loop);
+                        }
                     }();
-                    auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
-                    s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
+                    auto k = [&]() {
+                        return k_addr.template load_vector<gemm1_vec_size>(i_sk, i_loop);
+                    }();
+                    for(int i = 0; i < gemm1_vec_size; i++)
+                    {
+                        s_acc += type_convert<AccType_O>(q[i]) * type_convert<AccType_O>(k[i]);
+                    }
                 }
                 // scale
                 s_softmax = type_convert<SoftmaxType>(s_acc);
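The rewritten gemm-1 inner loop walks hdim in chunks of gemm1_vec_size and accumulates the chunk products, which yields the same dot product as the old scalar loop provided hdim is a multiple of gemm1_vec_size (the new loop count args.hdim / gemm1_vec_size silently drops any remainder). A standalone sketch of that equivalence:

    #include <cassert>

    // Scalar reference: s = sum_i q[i] * k[i].
    float dot_scalar(const float* q, const float* k, int hdim)
    {
        float s = 0.f;
        for(int i = 0; i < hdim; ++i)
            s += q[i] * k[i];
        return s;
    }

    // Chunked version shaped like the new gemm-1 loop: vec_size elements per
    // iteration, accumulated in the same order, so the result is identical
    // when hdim % vec_size == 0 (the tail would be dropped otherwise).
    float dot_chunked(const float* q, const float* k, int hdim, int vec_size)
    {
        float s = 0.f;
        const int loops = hdim / vec_size;
        for(int i_loop = 0; i_loop < loops; ++i_loop)
            for(int i = 0; i < vec_size; ++i)
                s += q[i_loop * vec_size + i] * k[i_loop * vec_size + i];
        return s;
    }

    int main()
    {
        float q[16], k[16];
        for(int i = 0; i < 16; ++i) { q[i] = 0.5f * i; k[i] = 16.f - i; }
        assert(dot_scalar(q, k, 16) == dot_chunked(q, k, 16, 8)); // same summation order, same result
        return 0;
    }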
@@ -528,7 +582,9 @@ struct naive_attention_fwd_kernel
                 s_softmax *= q_dequant_scale; // post scale the per-token factor
             }
             else if constexpr(Traits::quant_algo ==
-                                  naive_attention_quant_algo::KV_8BIT_PERTOKEN)
+                                  naive_attention_quant_algo::KV_8BIT_PERTOKEN ||
+                              Traits::quant_algo == naive_attention_quant_algo::KV_4BIT_PERTOKEN)
             {
                 SoftmaxType k_per_token_scale =
                     type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));

@@ -556,10 +612,10 @@ struct naive_attention_fwd_kernel
             // l, pre-scall o_acc
             SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
             l               = tmp * l + row_sum;
-            o_acc           = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
+            o_acc           = type_convert<TailType>(type_convert<SoftmaxType>(o_acc) * tmp);
-            // prepare the p_compute into smem, to let every thread read same p_compute and do
-            // 2nd gemm
+            // prepare the p_compute into smem, to let every thread read same p_compute and
+            // do 2nd gemm
             if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
             {
                 QuantComputeType v_s = 0;

@@ -631,7 +687,7 @@ struct naive_attention_fwd_kernel
             // gemm-2, simple loop over vector by vector
             constexpr int gemm_2_loop = wg_size / p_vec_elem;
             {
-                AccType o_acc_local = {0};
+                AccType_O o_acc_local = {0};
                 int sk_start = i_loop1 * wg_size; // we start from the first seqlen_kv element
                 for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
                 {

@@ -648,29 +704,29 @@ struct naive_attention_fwd_kernel
                         v = v_addr.load(i_sv, i_dv);
                     }
-                    AccType v_compute = [&]() { return type_convert<AccType>(v); }();
-                    o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
+                    AccType_O v_compute = [&]() { return type_convert<AccType_O>(v); }();
+                    o_acc_local += type_convert<AccType_O>(p_vec[i_j]) * v_compute;
                     }
                 }
-                OAccType post_scale_o_acc_local = [&]() {
+                TailType post_scale_o_acc_local = [&]() {
                     if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
                     {
                         // apply pr scale to local acc
-                        return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
+                        return type_convert<TailType>(type_convert<QuantComputeType>(o_acc_local) *
                                                       p_dequant_scale);
                     }
                     else if constexpr(Traits::quant_algo ==
                                       naive_attention_quant_algo::KV_8BIT_PERTOKEN)
                     {
                         // apply pr scale to local acc
-                        return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
+                        return type_convert<TailType>(type_convert<QuantComputeType>(o_acc_local) *
                                                       p_dequant_scale);
                     }
                     else
                     {
-                        return type_convert<OAccType>(o_acc_local);
+                        return type_convert<TailType>(o_acc_local);
                     }
                 }();
                 o_acc += post_scale_o_acc_local;

@@ -680,7 +736,7 @@ struct naive_attention_fwd_kernel
         // post scale o_acc
         {
             SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking
-            o_acc           = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
+            o_acc           = type_convert<TailType>(type_convert<SoftmaxType>(o_acc) * tmp);
         }
         // store O

@@ -698,7 +754,8 @@
         k_type_,                                                      \
         v_type_,                                                      \
         o_type_,                                                      \
-        acc_type_,                                                    \
+        acc_type_i_,                                                  \
+        acc_type_o_,                                                  \
         kvscale_type_,                                                \
         q_layout_,                                                    \
         k_layout_,                                                    \

@@ -764,7 +821,8 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
         using k_type_       = fp16_t;
         using v_type_       = fp16_t;
         using o_type_       = fp16_t;
-        using acc_type_     = float;
+        using acc_type_i_   = fp16_t;
+        using acc_type_o_   = float;
         using kvscale_type_ = float;
         constexpr int quant_algo_ = 0;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();

@@ -776,7 +834,8 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
         using k_type_       = bf16_t;
         using v_type_       = bf16_t;
         using o_type_       = bf16_t;
-        using acc_type_     = float;
+        using acc_type_i_   = fp16_t;
+        using acc_type_o_   = float;
         using kvscale_type_ = float;
         constexpr int quant_algo_ = 0;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();

@@ -788,7 +847,8 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
         using k_type_       = fp8_t;
         using v_type_       = fp8_t;
         using o_type_       = bf16_t;
-        using acc_type_     = float; // NOTE!
+        using acc_type_i_   = fp8_t;
+        using acc_type_o_   = float;
         using kvscale_type_ = float;
         constexpr int quant_algo_ = 2;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();

@@ -800,7 +860,8 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
         using k_type_       = fp8_t;
         using v_type_       = fp8_t;
         using o_type_       = fp16_t;
-        using acc_type_     = float; // NOTE!
+        using acc_type_i_   = fp8_t;
+        using acc_type_o_   = float;
         using kvscale_type_ = float;
         constexpr int quant_algo_ = 2;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();

@@ -812,7 +873,8 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
         using k_type_       = int8_t;
         using v_type_       = int8_t;
         using o_type_       = bf16_t;
-        using acc_type_     = int32_t; // NOTE!
+        using acc_type_i_   = int8_t;
+        using acc_type_o_   = int32_t;
         using kvscale_type_ = float;
         constexpr int quant_algo_ = 2;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();