Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c6b92879
Unverified
Commit
c6b92879
authored
Aug 13, 2025
by
Michael Goin
Committed by
GitHub
Aug 12, 2025
Browse files
Force TRTLLM attention for gpt-oss on SM100 (#22678)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
b1361c72
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
9 deletions
+20
-9
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+1
-4
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+8
-0
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+7
-4
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+4
-1
No files found.
vllm/model_executor/models/gpt_oss.py
View file @
c6b92879
...
...
@@ -8,7 +8,6 @@ import torch.distributed as dist
from
torch
import
nn
from
transformers
import
GptOssConfig
from
vllm
import
envs
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
...
...
@@ -70,11 +69,9 @@ class OAIAttention(nn.Module):
tp_size
=
get_tensor_model_parallel_world_size
()
attention_sink_dtype
=
(
torch
.
float32
if
envs
.
VLLM_USE_TRTLLM_ATTENTION
else
torch
.
bfloat16
)
self
.
sinks
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
config
.
num_attention_heads
//
tp_size
,
dtype
=
attention_sink_dtype
,
dtype
=
torch
.
bfloat16
,
requires_grad
=
False
))
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
...
...
vllm/utils/flashinfer.py
View file @
c6b92879
...
...
@@ -154,6 +154,7 @@ def use_trtllm_attention(
num_qo_heads
:
Optional
[
int
],
num_kv_heads
:
Optional
[
int
],
attn_head_size
:
Optional
[
int
],
has_sinks
:
bool
=
False
,
)
->
bool
:
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
if
not
(
current_platform
.
is_device_capability
(
100
)
...
...
@@ -165,6 +166,13 @@ def use_trtllm_attention(
or
num_qo_heads
%
num_kv_heads
!=
0
):
return
False
# If sinks are being used, we must use TRTLLM attention as it's
# the only backend that supports them
if
has_sinks
:
logger
.
info_once
(
"Using TRTLLM attention (required for attention sinks)."
)
return
True
env_value
=
envs
.
VLLM_USE_TRTLLM_ATTENTION
if
env_value
is
not
None
:
logger
.
info_once
(
"VLLM_USE_TRTLLM_ATTENTION is set to %s"
,
env_value
)
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
c6b92879
...
...
@@ -523,14 +523,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_kv_heads
=
self
.
kv_cache_spec
.
num_kv_heads
head_dim
=
self
.
kv_cache_spec
.
head_size
# Check if any layer uses sinks (requires TRTLLM attention)
has_sinks
=
self
.
global_hyperparameters
.
has_sinks
# currently prefill trtllm attention does not support fp8 kv cache
prefill_use_trtllm
=
not
cache_dtype
.
startswith
(
"fp8"
)
\
and
use_trtllm_attention
(
num_prefill_tokens
,
max_seq_len
,
cache_dtype
,
num_qo_heads
,
num_kv_heads
,
head_dim
)
num_qo_heads
,
num_kv_heads
,
head_dim
,
has_sinks
)
decode_use_trtllm
=
use_trtllm_attention
(
num_decode_tokens
,
max_seq_len
,
cache_dtype
,
num_qo_heads
,
num_kv_heads
,
head_dim
)
num_qo_heads
,
num_kv_heads
,
head_dim
,
has_sinks
)
attn_metadata
=
FlashInferMetadata
(
num_actual_tokens
=
num_actual_tokens
,
...
...
@@ -642,9 +645,9 @@ class FlashInferImpl(AttentionImpl):
f
"heads in the layer. Expected
{
num_heads
}
, but got "
f
"
{
sinks
.
shape
[
0
]
}
."
)
# Cast sinks to float32 if needed (FlashInfer requirement)
if
sinks
.
dtype
!=
torch
.
float32
:
raise
ValueError
(
"Sinks must be of type float32, but got "
f
"
{
sinks
.
dtype
}
."
)
sinks
=
sinks
.
to
(
torch
.
float32
)
self
.
sinks
=
sinks
def
forward
(
...
...
vllm/v1/attention/backends/utils.py
View file @
c6b92879
...
...
@@ -285,6 +285,7 @@ class PerLayerParameters:
window_left
:
int
logits_soft_cap
:
Optional
[
float
]
sm_scale
:
float
has_sinks
:
bool
=
False
def
get_per_layer_parameters
(
...
...
@@ -307,9 +308,11 @@ def get_per_layer_parameters(
window_left
=
window_size
[
0
]
if
window_size
is
not
None
else
-
1
logits_soft_cap
=
getattr
(
impl
,
"logits_soft_cap"
,
None
)
sm_scale
=
impl
.
scale
has_sinks
=
getattr
(
impl
,
"sinks"
,
None
)
is
not
None
per_layer_params
[
key
]
=
PerLayerParameters
(
window_left
,
logits_soft_cap
,
sm_scale
)
logits_soft_cap
,
sm_scale
,
has_sinks
)
return
per_layer_params
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment