Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5a1271d8
Unverified
Commit
5a1271d8
authored
Nov 12, 2025
by
xuebwang-amd
Committed by
GitHub
Nov 11, 2025
Browse files
[Quantization] fix attention quantization of gpt_oss model (#27334)
Signed-off-by:
xuebwang-amd
<
xuebwang@amd.com
>
parent
05576df8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
101 additions
and
4 deletions
+101
-4
tests/models/quantization/test_gpt_oss_attn_quantization.py
tests/models/quantization/test_gpt_oss_attn_quantization.py
+80
-0
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+13
-2
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+8
-2
No files found.
tests/models/quantization/test_gpt_oss_attn_quantization.py
0 → 100644
View file @
5a1271d8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test attention quantization of gpt-oss model.
The qkv_proj and o_proj in self_attention can be either quantized or excluded.
Run `pytest tests/models/quantization/test_gpt_oss_attn_quantization.py`.
"""
import
importlib
import
importlib.metadata
from
dataclasses
import
dataclass
import
huggingface_hub
import
lm_eval
import
pytest
from
packaging
import
version
MODEL_NAMES
=
[
"amd/gpt-oss-20b-customized-attention-quantization"
]
QUARK_MXFP4_AVAILABLE
=
importlib
.
util
.
find_spec
(
"quark"
)
is
not
None
and
version
.
parse
(
importlib
.
metadata
.
version
(
"amd-quark"
)
)
>=
version
.
parse
(
"0.8.99"
)
def
has_huggingface_access
(
repo
):
try
:
huggingface_hub
.
list_repo_refs
(
repo
)
return
True
except
huggingface_hub
.
errors
.
RepositoryNotFoundError
:
return
False
HF_HUB_AMD_ORG_ACCESS
=
all
(
[
has_huggingface_access
(
model_name
)
for
model_name
in
MODEL_NAMES
]
)
@
dataclass
class
ModelCase
:
model_id
:
str
tp
:
int
@
dataclass
class
EvaluationConfig
:
model_name
:
str
def
get_model_args
(
self
)
->
str
:
return
(
f
"pretrained=
{
self
.
model_name
}
,"
"tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=False"
)
EXPECTED_ACCURACIES
=
{
"arc_challenge"
:
0.20
}
@
pytest
.
mark
.
skipif
(
not
QUARK_MXFP4_AVAILABLE
,
reason
=
"amd-quark>=0.9 is not available"
)
@
pytest
.
mark
.
skipif
(
not
HF_HUB_AMD_ORG_ACCESS
,
reason
=
"Read access to huggingface.co/amd is required for this test."
,
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
MODEL_NAMES
)
@
pytest
.
mark
.
parametrize
(
"task_name, expected_accuracy"
,
EXPECTED_ACCURACIES
.
items
())
def
test_gpt_oss_attention_quantization
(
model_name
:
str
,
task_name
:
str
,
expected_accuracy
:
float
):
measured_accuracy
=
lm_eval
.
simple_evaluate
(
model
=
"vllm"
,
model_args
=
EvaluationConfig
(
model_name
).
get_model_args
(),
tasks
=
task_name
,
batch_size
=
"auto"
,
)[
"results"
][
task_name
][
"acc,none"
]
rtol
=
0.05
assert
(
measured_accuracy
-
rtol
<
expected_accuracy
and
measured_accuracy
+
rtol
>
expected_accuracy
),
f
"Expected:
{
expected_accuracy
}
| Measured:
{
measured_accuracy
}
"
vllm/model_executor/layers/quantization/mxfp4.py
View file @
5a1271d8
...
@@ -190,14 +190,25 @@ class Mxfp4Config(QuantizationConfig):
...
@@ -190,14 +190,25 @@ class Mxfp4Config(QuantizationConfig):
fused_mapping
=
self
.
packed_modules_mapping
,
fused_mapping
=
self
.
packed_modules_mapping
,
):
):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
raise
NotImplementedError
(
"Mxfp4 linear layer is not implemented"
)
# TODO: Add support for MXFP4 Linear Method.
# MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation
# if you are interested in enabling MXFP4 here.
logger
.
warning_once
(
"MXFP4 linear layer is not implemented - falling back to "
"UnquantizedLinearMethod."
)
return
UnquantizedLinearMethod
()
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
if
current_platform
.
is_xpu
():
if
current_platform
.
is_xpu
():
return
IpexMxfp4MoEMethod
(
layer
.
moe_config
)
return
IpexMxfp4MoEMethod
(
layer
.
moe_config
)
else
:
else
:
return
Mxfp4MoEMethod
(
layer
.
moe_config
)
return
Mxfp4MoEMethod
(
layer
.
moe_config
)
elif
isinstance
(
layer
,
Attention
):
elif
isinstance
(
layer
,
Attention
):
raise
NotImplementedError
(
"Mxfp4 attention layer is not implemented"
)
# TODO: Add support for MXFP4 Attention.
logger
.
warning_once
(
"MXFP4 attention layer is not implemented. "
"Skipping quantization for this layer."
)
return
None
return
None
...
...
vllm/model_executor/models/gpt_oss.py
View file @
5a1271d8
...
@@ -198,6 +198,7 @@ class TransformerBlock(torch.nn.Module):
...
@@ -198,6 +198,7 @@ class TransformerBlock(torch.nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
quant_config
:
QuantizationConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -207,7 +208,10 @@ class TransformerBlock(torch.nn.Module):
...
@@ -207,7 +208,10 @@ class TransformerBlock(torch.nn.Module):
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
attn
=
OAIAttention
(
self
.
attn
=
OAIAttention
(
config
,
prefix
=
f
"
{
prefix
}
.attn"
,
cache_config
=
cache_config
config
,
prefix
=
f
"
{
prefix
}
.attn"
,
quant_config
=
quant_config
,
cache_config
=
cache_config
,
)
)
self
.
mlp
=
MLPBlock
(
vllm_config
,
self
.
layer_idx
,
prefix
=
f
"
{
prefix
}
.mlp"
)
self
.
mlp
=
MLPBlock
(
vllm_config
,
self
.
layer_idx
,
prefix
=
f
"
{
prefix
}
.mlp"
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
...
@@ -243,6 +247,7 @@ class GptOssModel(nn.Module):
...
@@ -243,6 +247,7 @@ class GptOssModel(nn.Module):
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
quant_config
=
vllm_config
.
quant_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
config
.
hidden_size
=
self
.
config
.
hidden_size
self
.
config
.
hidden_size
=
self
.
config
.
hidden_size
self
.
embedding
=
VocabParallelEmbedding
(
self
.
embedding
=
VocabParallelEmbedding
(
...
@@ -254,6 +259,7 @@ class GptOssModel(nn.Module):
...
@@ -254,6 +259,7 @@ class GptOssModel(nn.Module):
lambda
prefix
:
TransformerBlock
(
lambda
prefix
:
TransformerBlock
(
vllm_config
,
vllm_config
,
prefix
=
prefix
,
prefix
=
prefix
,
quant_config
=
self
.
quant_config
,
),
),
prefix
=
f
"
{
prefix
}
.layers"
,
prefix
=
f
"
{
prefix
}
.layers"
,
)
)
...
@@ -645,7 +651,7 @@ class GptOssModel(nn.Module):
...
@@ -645,7 +651,7 @@ class GptOssModel(nn.Module):
class
GptOssForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsEagle3
,
SupportsLoRA
):
class
GptOssForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsEagle3
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
]}
packed_modules_mapping
=
{
"qkv
_proj
"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
]}
hf_to_vllm_mapper
=
WeightsMapper
(
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
orig_to_new_substr
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment