Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9f9f3796
Commit
9f9f3796
authored
Aug 06, 2024
by
zhangshao
Browse files
恢复对bf16的支持
parent
f38bd872
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
40 deletions
+9
-40
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+3
-34
csrc/attention/attention_utils.cuh
csrc/attention/attention_utils.cuh
+6
-6
No files found.
csrc/attention/attention_kernels.cu
View file @
9f9f3796
...
...
@@ -68,40 +68,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
=
1
,
bool
odd_nheads
=
false
,
int
PARTITION_SIZE
=
0
,
std
::
enable_if_t
<!
std
::
is_same
<
scalar_t
,
uint16_t
>
::
value
,
int
>
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_kv_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
=
1
,
bool
odd_nheads
=
false
,
int
PARTITION_SIZE
=
0
,
std
::
enable_if_t
<
std
::
is_same
<
scalar_t
,
uint16_t
>
::
value
,
int
>
=
0
>
// Zero means no partitioning.
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
...
...
@@ -133,6 +100,7 @@ __device__ void paged_attention_kernel(
// No work to do. Terminate the thread block.
return
;
}
if
constexpr
(
sizeof
(
scalar_t
)
==
2
){
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
...
...
@@ -723,6 +691,7 @@ __device__ void paged_attention_kernel(
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
...
...
csrc/attention/attention_utils.cuh
View file @
9f9f3796
...
...
@@ -80,8 +80,8 @@ inline __device__ void v_pk_fma_f16x8(float& a,const uint4 & b,const uint4 & c
}
// Q*K^T operation. fp16
//
template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
template
<
int
THREAD_GROUP_SIZE
,
typename
Vec
,
int
N
>
template
<
int
THREAD_GROUP_SIZE
,
typename
Vec
,
int
N
,
typename
scalar_t
,
std
::
enable_if_t
<
std
::
is_same
<
scalar_t
,
uint16_t
>
::
value
,
int
>
=
0
>
//
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline
__device__
float
qk_dot_
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
float
qk
=
0
;
...
...
@@ -114,9 +114,9 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
}
// Q*K^T operation. //bf16
//
template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
template
<
int
THREAD_GROUP_SIZE
,
typename
Vec
,
int
N
>
inline
__device__
float
qk_dot_
vpack_
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
template
<
int
THREAD_GROUP_SIZE
,
typename
Vec
,
int
N
,
typename
scalar_t
,
std
::
enable_if_t
<!
std
::
is_same
<
scalar_t
,
uint16_t
>
::
value
,
int
>
=
0
>
//
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline
__device__
float
qk_dot_
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
using
A_vec
=
typename
FloatVec
<
Vec
>::
Type
;
A_vec
qk_vec
=
mul
<
A_vec
,
Vec
,
Vec
>
(
q
[
0
],
k
[
0
]);
...
...
@@ -138,7 +138,7 @@ template <typename T, int THREAD_GROUP_SIZE>
struct
Qk_dot
{
template
<
typename
Vec
,
int
N
>
static
inline
__device__
float
dot
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
return
qk_dot_
<
THREAD_GROUP_SIZE
>
(
q
,
k
);
return
qk_dot_
<
THREAD_GROUP_SIZE
,
Vec
,
N
,
T
>
(
q
,
k
);
}
// template <typename Vec, int N>
// static inline __device__ float qk_dot_vpack(const Vec (&q)[N], const Vec (&k)[N]) {
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment