Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8cc26acd
Unverified
Commit
8cc26acd
authored
Jan 18, 2026
by
Isotr0py
Committed by
GitHub
Jan 17, 2026
Browse files
[Performance] Improve Triton prefill attention kernel's performance (#32403)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
4a6af881
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
48 deletions
+33
-48
tests/models/language/pooling/test_token_classification.py
tests/models/language/pooling/test_token_classification.py
+2
-2
vllm/utils/math_utils.py
vllm/utils/math_utils.py
+4
-0
vllm/v1/attention/ops/triton_prefill_attention.py
vllm/v1/attention/ops/triton_prefill_attention.py
+27
-46
No files found.
tests/models/language/pooling/test_token_classification.py
View file @
8cc26acd
...
@@ -46,7 +46,7 @@ def test_bert_models(
...
@@ -46,7 +46,7 @@ def test_bert_models(
for
hf_output
,
vllm_output
in
zip
(
hf_outputs
,
vllm_outputs
):
for
hf_output
,
vllm_output
in
zip
(
hf_outputs
,
vllm_outputs
):
hf_output
=
hf_output
.
detach
().
clone
().
cpu
().
float
()
hf_output
=
hf_output
.
detach
().
clone
().
cpu
().
float
()
vllm_output
=
vllm_output
.
detach
().
clone
().
cpu
().
float
()
vllm_output
=
vllm_output
.
detach
().
clone
().
cpu
().
float
()
torch
.
testing
.
assert_close
(
hf_output
,
vllm_output
,
atol
=
1
.2e-2
,
rtol
=
1e-3
)
torch
.
testing
.
assert_close
(
hf_output
,
vllm_output
,
atol
=
3
.2e-2
,
rtol
=
1e-3
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"disham993/electrical-ner-ModernBERT-base"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"disham993/electrical-ner-ModernBERT-base"
])
...
@@ -86,7 +86,7 @@ def test_modernbert_models(
...
@@ -86,7 +86,7 @@ def test_modernbert_models(
for
hf_output
,
vllm_output
in
zip
(
hf_outputs
,
vllm_outputs
):
for
hf_output
,
vllm_output
in
zip
(
hf_outputs
,
vllm_outputs
):
hf_output
=
hf_output
.
detach
().
clone
().
cpu
().
float
()
hf_output
=
hf_output
.
detach
().
clone
().
cpu
().
float
()
vllm_output
=
vllm_output
.
detach
().
clone
().
cpu
().
float
()
vllm_output
=
vllm_output
.
detach
().
clone
().
cpu
().
float
()
torch
.
testing
.
assert_close
(
hf_output
,
vllm_output
,
atol
=
1
.2e-2
,
rtol
=
1e-3
)
torch
.
testing
.
assert_close
(
hf_output
,
vllm_output
,
atol
=
3
.2e-2
,
rtol
=
1e-3
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"bd2lcco/Qwen3-0.6B-finetuned"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"bd2lcco/Qwen3-0.6B-finetuned"
])
...
...
vllm/utils/math_utils.py
View file @
8cc26acd
...
@@ -2,6 +2,10 @@
...
@@ -2,6 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Math utility functions for vLLM."""
"""Math utility functions for vLLM."""
# Approximate value of 1/ln(2), used for log/exp base conversion
# Best FP32 approximation: 1.4426950216 (hex 0x3FB8AA3B)
RCP_LN2
=
1.4426950216
def
cdiv
(
a
:
int
,
b
:
int
)
->
int
:
def
cdiv
(
a
:
int
,
b
:
int
)
->
int
:
"""Ceiling division."""
"""Ceiling division."""
...
...
vllm/v1/attention/ops/triton_prefill_attention.py
View file @
8cc26acd
...
@@ -30,6 +30,7 @@ import torch
...
@@ -30,6 +30,7 @@ import torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
RCP_LN2
@
triton
.
jit
@
triton
.
jit
...
@@ -110,15 +111,7 @@ def _fwd_kernel(
...
@@ -110,15 +111,7 @@ def _fwd_kernel(
end_n_limit
=
block_mask
*
end_n
end_n_limit
=
block_mask
*
end_n
for
start_n
in
range
(
start_n_limit
,
end_n_limit
,
BLOCK_N
):
for
start_n
in
range
(
start_n_limit
,
end_n_limit
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
# -- prepare attention mask ----
# -- compute qk ----
k
=
tl
.
load
(
k_ptrs
+
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_kbs
,
mask
=
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_seq_len
)
&
(
mask_d
[:,
None
]),
other
=
0.0
,
)
# Apply attention mask (causal + bidirectional sliding window)
# Position indices in the sequence
# Position indices in the sequence
pos_q
=
offs_m
[:,
None
]
# Query positions [BLOCK_M, 1]
pos_q
=
offs_m
[:,
None
]
# Query positions [BLOCK_M, 1]
pos_k
=
start_n
+
offs_n
[
None
,
:]
# Key positions [1, BLOCK_N]
pos_k
=
start_n
+
offs_n
[
None
,
:]
# Key positions [1, BLOCK_N]
...
@@ -141,53 +134,38 @@ def _fwd_kernel(
...
@@ -141,53 +134,38 @@ def _fwd_kernel(
if
sliding_mask_k
is
not
None
:
if
sliding_mask_k
is
not
None
:
mask
&=
sliding_mask_k
mask
&=
sliding_mask_k
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
qk
+=
tl
.
where
(
mask
,
0
,
float
(
"-inf"
))
# -- compute qk ----
q
k
+
=
tl
.
dot
(
q
,
k
)
k
=
tl
.
load
(
qk
*=
sm_scale
k_ptrs
+
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_kbs
,
mask
=
(
pos_k
<
cur_batch_seq_len
)
&
(
mask_d
[:,
None
]),
# -- compute m_ij, p, l_ij
other
=
0.0
,
m_ij
=
tl
.
max
(
qk
,
1
)
)
# For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_ij to 0 to avoid NaN
qk
=
tl
.
dot
(
q
,
k
)
m_ij_valid_mask
=
m_ij
>
float
(
"-inf"
)
qk
=
tl
.
where
(
mask
,
qk
*
sm_scale
,
-
1.0e8
)
m_ij
_masked
=
tl
.
where
(
m_ij_valid_mask
,
m_ij
,
0.0
)
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
)
)
# -- compute p and l_ij --
qk
-=
m_ij
[:,
None
]
p
=
tl
.
exp
(
qk
-
m_ij_masked
[:,
None
]
)
p
=
tl
.
math
.
exp
2
(
qk
)
l_ij
=
tl
.
sum
(
p
,
1
)
l_ij
=
tl
.
sum
(
p
,
1
)
# -- update m_i and l_i
# -- update m_i and l_i
m_i_new
=
tl
.
maximum
(
m_i
,
m_ij
)
alpha
=
tl
.
math
.
exp2
(
m_i
-
m_ij
)
m_i_new_mask
=
m_i_new
>
float
(
"-inf"
)
l_i
=
l_i
*
alpha
+
l_ij
alpha
=
tl
.
exp
(
m_i
-
m_i_new
)
beta
=
tl
.
exp
(
m_ij
-
m_i_new
)
# mask alpha and beta for sliding window
alpha
=
tl
.
where
(
m_i_new_mask
,
alpha
,
1.0
)
beta
=
tl
.
where
(
m_i_new_mask
,
beta
,
0.0
)
l_i_new
=
alpha
*
l_i
+
beta
*
l_ij
# -- update output accumulator --
# -- update output accumulator --
# scale p
acc
=
acc
*
alpha
[:,
None
]
# For sliding window there's a chance that l_i_new is 0 due to masking of
# the entire row. We need to set l_i_new to 1 to avoid division by zero
l_i_new_mask
=
(
l_i_new
!=
0.0
)
&
(
m_i_new_mask
>
float
(
"-inf"
))
l_i_new_safe
=
tl
.
where
(
l_i_new_mask
,
l_i_new
,
1.0
)
p_scale
=
beta
/
l_i_new_safe
p
=
p
*
p_scale
[:,
None
]
# scale acc
acc_scale
=
l_i
/
l_i_new_safe
*
alpha
acc
=
acc
*
acc_scale
[:,
None
]
# update acc
# update acc
v
=
tl
.
load
(
v
=
tl
.
load
(
v_ptrs
+
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_vbs
,
v_ptrs
+
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_vbs
,
mask
=
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_seq_len
)
&
(
mask_d
[
None
,
:]),
mask
=
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_seq_len
)
&
(
mask_d
[
None
,
:]),
other
=
0.0
,
other
=
0.0
,
)
)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
acc
+
=
tl
.
dot
(
p
,
v
)
acc
=
tl
.
dot
(
p
,
v
,
acc
)
# update m_i
and l_i
# update m_i
l
_i
=
l
_i
_new
m
_i
=
m
_i
j
m_i
=
m_i_new
# initialize pointers to output
acc
=
acc
/
l_i
[:,
None
]
off_o
=
(
off_o
=
(
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_obs
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_obs
+
cur_head
*
stride_oh
+
cur_head
*
stride_oh
...
@@ -234,6 +212,9 @@ def context_attention_fwd(
...
@@ -234,6 +212,9 @@ def context_attention_fwd(
Lq
,
Lk
,
_
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
Lq
,
Lk
,
_
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
sm_scale
=
1.0
/
(
Lq
**
0.5
)
if
softmax_scale
is
None
else
softmax_scale
sm_scale
=
1.0
/
(
Lq
**
0.5
)
if
softmax_scale
is
None
else
softmax_scale
# rescale with 1/ln(2) for triton exp2
sm_scale
*=
RCP_LN2
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
k
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
k
.
shape
[
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment