Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2ddd2671
Commit
2ddd2671
authored
May 13, 2025
by
zhuwenwen
Browse files
Merge branch 'vllm-0.8.5-zhangshao' into 'v0.8.5.post1-dev'
提升bf16 pa精度 See merge request dcutoolkit/deeplearing/vllm!112
parents
09e372e7
98955c1f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
59 additions
and
38 deletions
+59
-38
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+27
-15
csrc/attention/attention_with_mask_kernels_opt_tc.cu
csrc/attention/attention_with_mask_kernels_opt_tc.cu
+32
-23
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
2ddd2671
...
...
@@ -237,7 +237,7 @@ __global__ void paged_attention_kernel_TC(
}
__syncthreads
();
extern
__shared__
char
shared_mem
[];
scalar_
t
*
logits
=
reinterpret_cast
<
scalar_
t
*>
(
shared_mem
);
floa
t
*
logits
=
reinterpret_cast
<
floa
t
*>
(
shared_mem
);
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
...
...
@@ -277,10 +277,10 @@ __global__ void paged_attention_kernel_TC(
}
const
bool
mask
=
(
token_idx
>=
seq_len
);
if
(
mask
){
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
0.
f
)
;
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
=
0.
f
;
}
else
{
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
qk_vec
[
i
]
)
;
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
=
qk_vec
[
i
];
qk_max
[
i
]
=
fmaxf
(
qk_max
[
i
],
qk_vec
[
i
]);
}
}
...
...
@@ -313,15 +313,15 @@ __global__ void paged_attention_kernel_TC(
}
qk_max_tmp
=
__shfl
(
qk_max_tmp
,
0
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
)
-
qk_max_tmp
);
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
val
)
;
float
val
=
__expf
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
-
qk_max_tmp
);
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
val
;
exp_sum
+=
val
;
}
exp_sum
=
block_sum
<
NUM_WARPS
>
(
s_logit
,
exp_sum
);
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
)
*
inv_sum
)
;
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
*
inv_sum
;
}
if
(
USE_PARTITIONING
&&
thread_idx
==
0
){
max_out
[
reuse_kv_idx
]
=
qk_max_tmp
;
...
...
@@ -349,7 +349,11 @@ __global__ void paged_attention_kernel_TC(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
4
*
q_boundary
){
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
rows
*
4
+
rowid
*
16
;
...
...
@@ -441,7 +445,11 @@ __global__ void paged_attention_kernel_TC(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
q_boundary
){
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
rows
*
4
+
rowid
*
16
;
...
...
@@ -542,7 +550,11 @@ __global__ void paged_attention_kernel_TC(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
q_boundary
){
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
rows
*
4
+
rowid
*
16
;
...
...
@@ -837,8 +849,8 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if
(
device_name
==
"gfx928"
){
if
(
batchsize
*
qheads
>
1024
&&
max_seq_len
>=
2000
){
max_num_partitions
=
1
;
if
(
max_seq_len
<
39
00
)
reusekv
=
8
;
else
if
(
max_seq_len
<
78
00
)
reusekv
=
4
;
if
(
max_seq_len
<
20
00
)
reusekv
=
8
;
else
if
(
max_seq_len
<
39
00
)
reusekv
=
4
;
else
{
PARTITION_SIZE
=
2048
;
reusekv
=
8
;
...
...
@@ -867,7 +879,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if
(
device_name
==
"gfx928"
){
if
(
batchsize
*
qheads
>
1024
&&
max_seq_len
>=
2000
){
max_num_partitions
=
1
;
if
(
max_seq_len
<
78
00
)
reusekv
=
4
;
if
(
max_seq_len
<
39
00
)
reusekv
=
4
;
else
{
PARTITION_SIZE
=
2048
;
reusekv
=
4
;
...
...
@@ -880,7 +892,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
max_seq_len
<=
1500
&&
(
qheads
>
4
&&
batchsize
>=
16
||
batchsize
>=
64
))
max_num_partitions
=
1
;
int
blocks
=
max_num_partitions
*
batchsize
*
qheads
;
if
(
blocks
>=
150
||
batchsize
>=
16
||
qheads
>=
8
&&
(
batchsize
>=
4
||
max_seq_len
>=
2000
))
reusekv
=
4
;
if
(
blocks
>=
150
||
batchsize
>=
16
||
qheads
>=
8
&&
(
batchsize
>=
4
||
(
max_seq_len
>=
2000
&&
max_seq_len
<
3900
)
))
reusekv
=
4
;
}
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
...
...
@@ -948,7 +960,7 @@ void paged_attention_v2_launcher_opt_tc(
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
4
;
if
(
max_num_partitions
==
1
)
PARTITION_SIZE
=
0
;
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
dim3
grid
;
...
...
@@ -1051,4 +1063,4 @@ void paged_attention_v1_opt_tc(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
#undef DIVIDE_ROUND_UP
csrc/attention/attention_with_mask_kernels_opt_tc.cu
View file @
2ddd2671
...
...
@@ -221,7 +221,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
__syncthreads
();
extern
__shared__
char
shared_mem
[];
scalar_
t
*
logits
=
reinterpret_cast
<
scalar_
t
*>
(
shared_mem
);
floa
t
*
logits
=
reinterpret_cast
<
floa
t
*>
(
shared_mem
);
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
...
...
@@ -268,10 +268,10 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
const
bool
mask
=
(
token_idx
>=
seq_len
);
if
(
mask
){
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
0.
f
)
;
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
=
0.
f
;
}
else
{
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
qk_vec
[
i
]
)
;
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
=
qk_vec
[
i
];
qk_max
[
i
]
=
fmaxf
(
qk_max
[
i
],
qk_vec
[
i
]);
}
}
...
...
@@ -304,15 +304,15 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
qk_max_tmp
=
__shfl
(
qk_max_tmp
,
0
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
)
-
qk_max_tmp
);
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
val
)
;
float
val
=
__expf
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
-
qk_max_tmp
);
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
val
;
exp_sum
+=
val
;
}
exp_sum
=
block_sum
<
NUM_WARPS
>
(
s_logit
,
exp_sum
);
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
)
*
inv_sum
)
;
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
*
inv_sum
;
}
if
(
USE_PARTITIONING
&&
thread_idx
==
0
){
max_out
[
reuse_kv_idx
]
=
qk_max_tmp
;
...
...
@@ -340,7 +340,11 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
4
*
q_boundary
){
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
rows
*
4
+
rowid
*
16
;
...
...
@@ -432,7 +436,11 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
q_boundary
){
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
rows
*
4
+
rowid
*
16
;
...
...
@@ -533,7 +541,11 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
q_boundary
){
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
rows
*
4
+
rowid
*
16
;
...
...
@@ -856,7 +868,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
4
;
if
(
max_num_partitions
==
1
)
PARTITION_SIZE
=
0
;
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
dim3
grid
;
...
...
@@ -883,13 +895,10 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
blocksparse_head_sliding_step,attn_masks, attn_masks_stride);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
if (is_block_sparse) { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
...
...
@@ -933,8 +942,8 @@ void paged_attention_v2_opt_tc_with_mask(
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
DISPATCH_BY_KV_CACHE_DTYPE
(
query
.
dtype
(),
kv_cache_dtype
,
CALL_V2_LAUNCHER_BLOCK_SIZE
)
}
...
...
@@ -958,10 +967,10 @@ void paged_attention_v1_opt_tc_with_mask(
const
int64_t
blocksparse_head_sliding_step
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
// [num_seqs, max_seq_len]
const
int64_t
attn_masks_stride
)
{
paged_attention_v2_opt_tc_with_mask
(
out
,
out
,
out
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
paged_attention_v2_opt_tc_with_mask
(
out
,
out
,
out
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
);
}
#undef WARP_SIZE
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment