Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
007dd908
Unverified
Commit
007dd908
authored
Aug 12, 2025
by
Yongye Zhu
Committed by
GitHub
Aug 12, 2025
Browse files
[gpt-oss] Enable gpt-oss on ampere (#22714)
Signed-off-by:
Yongye Zhu
<
zyy1102000@gmail.com
>
parent
b8a9d0e4
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
26 additions
and
17 deletions
+26
-17
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
..._dummy_platform/vllm_add_dummy_platform/dummy_platform.py
+3
-2
vllm/attention/layer.py
vllm/attention/layer.py
+3
-1
vllm/attention/selector.py
vllm/attention/selector.py
+4
-1
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+1
-1
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+2
-2
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+5
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+2
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+2
-2
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+2
-2
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+2
-2
No files found.
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
View file @
007dd908
...
...
@@ -25,5 +25,6 @@ class DummyPlatform(Platform):
compilation_config
.
custom_ops
=
[
"all"
]
def
get_attn_backend_cls
(
self
,
backend_name
,
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
):
return
"vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"
# noqa E501
\ No newline at end of file
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
,
has_sink
):
return
"vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"
# noqa E501
vllm/attention/layer.py
View file @
007dd908
...
...
@@ -138,6 +138,7 @@ class Attention(nn.Module):
self
.
head_size
=
head_size
self
.
num_kv_heads
=
num_kv_heads
self
.
sliding_window
=
sliding_window
self
.
has_sink
=
extra_impl_args
.
get
(
"sinks"
)
is
not
None
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
if
quant_config
else
None
...
...
@@ -165,7 +166,8 @@ class Attention(nn.Module):
kv_cache_dtype
,
block_size
,
is_attention_free
,
use_mla
=
use_mla
)
use_mla
=
use_mla
,
has_sink
=
self
.
has_sink
)
else
:
self
.
attn_backend
=
attn_backend
...
...
vllm/attention/selector.py
View file @
007dd908
...
...
@@ -144,6 +144,7 @@ def get_attn_backend(
block_size
:
int
,
is_attention_free
:
bool
=
False
,
use_mla
:
bool
=
False
,
has_sink
:
bool
=
False
,
)
->
type
[
AttentionBackend
]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
...
...
@@ -158,6 +159,7 @@ def get_attn_backend(
is_attention_free
=
is_attention_free
,
use_v1
=
envs
.
VLLM_USE_V1
,
use_mla
=
use_mla
,
has_sink
=
has_sink
,
)
...
...
@@ -170,6 +172,7 @@ def _cached_get_attn_backend(
is_attention_free
:
bool
,
use_v1
:
bool
=
False
,
use_mla
:
bool
=
False
,
has_sink
:
bool
=
False
,
)
->
type
[
AttentionBackend
]:
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
...
...
@@ -201,7 +204,7 @@ def _cached_get_attn_backend(
# get device-specific attn_backend
attention_cls
=
current_platform
.
get_attn_backend_cls
(
selected_backend
,
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
)
use_mla
,
has_sink
)
if
not
attention_cls
:
raise
ValueError
(
f
"Invalid attention backend for
{
current_platform
.
device_name
}
"
)
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
007dd908
...
...
@@ -42,7 +42,7 @@ class Mxfp4Config(QuantizationConfig):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
9
0
return
8
0
@
classmethod
def
get_name
(
cls
)
->
QuantizationMethods
:
...
...
vllm/platforms/cpu.py
View file @
007dd908
...
...
@@ -91,8 +91,8 @@ class CpuPlatform(Platform):
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
:
_Backend
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
)
->
str
:
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
,
has_sink
:
bool
)
->
str
:
if
selected_backend
and
selected_backend
!=
_Backend
.
TORCH_SDPA
:
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
if
use_mla
:
...
...
vllm/platforms/cuda.py
View file @
007dd908
...
...
@@ -222,8 +222,8 @@ class CudaPlatformBase(Platform):
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
,
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
)
->
str
:
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
,
has_sink
)
->
str
:
if
use_mla
:
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
...
...
@@ -321,6 +321,9 @@ class CudaPlatformBase(Platform):
# FlashAttention is the default for SM 8.0+ GPUs
if
cls
.
has_device_capability
(
80
):
if
has_sink
:
logger
.
info_once
(
"Using Triton backend on V1 engine."
)
return
TRITON_ATTN_VLLM_V1
if
is_default_backend_supported
:
=
is_attn_backend_supported
(
FLASH_ATTN_V1
,
head_size
,
dtype
,
allow_import_error
=
False
):
...
...
vllm/platforms/interface.py
View file @
007dd908
...
...
@@ -196,8 +196,8 @@ class Platform:
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
:
_Backend
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
)
->
str
:
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
,
has_sink
:
bool
)
->
str
:
"""Get the attention backend class of a device."""
return
""
...
...
vllm/platforms/rocm.py
View file @
007dd908
...
...
@@ -188,8 +188,8 @@ class RocmPlatform(Platform):
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
,
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
)
->
str
:
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
,
has_sink
)
->
str
:
if
use_mla
:
from
vllm.attention.backends.rocm_aiter_mla
import
(
is_aiter_mla_enabled
)
...
...
vllm/platforms/tpu.py
View file @
007dd908
...
...
@@ -46,8 +46,8 @@ class TpuPlatform(Platform):
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
:
_Backend
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
)
->
str
:
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
,
has_sink
)
->
str
:
if
(
selected_backend
!=
_Backend
.
PALLAS
and
selected_backend
!=
_Backend
.
PALLAS_VLLM_V1
):
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
...
...
vllm/platforms/xpu.py
View file @
007dd908
...
...
@@ -35,8 +35,8 @@ class XPUPlatform(Platform):
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
:
_Backend
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
)
->
str
:
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
,
has_sink
:
bool
)
->
str
:
if
selected_backend
is
not
None
and
selected_backend
!=
_Backend
.
IPEX
:
logger
.
info
(
"Cannot use %s backend on XPU."
,
selected_backend
)
use_v1
=
envs
.
VLLM_USE_V1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment