Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
05d96d79
Commit
05d96d79
authored
Mar 26, 2026
by
Vadim Gimpelson
Committed by
khluu
Mar 26, 2026
Browse files
merge
Signed-off-by:
khluu
<
khluu000@gmail.com
>
parent
ccbc5ac4
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
73 additions
and
10 deletions
+73
-10
tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-DEP2.yaml
tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-DEP2.yaml
+2
-1
tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-FP8-DEP2.yaml
tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-FP8-DEP2.yaml
+2
-1
tests/evals/gsm8k/configs/Qwen3.5-397B-A17B-NVFP4-DEP2.yaml
tests/evals/gsm8k/configs/Qwen3.5-397B-A17B-NVFP4-DEP2.yaml
+9
-0
tests/evals/gsm8k/configs/models-qwen35-blackwell.txt
tests/evals/gsm8k/configs/models-qwen35-blackwell.txt
+1
-0
tests/evals/gsm8k/test_gsm8k_correctness.py
tests/evals/gsm8k/test_gsm8k_correctness.py
+4
-6
vllm/config/vllm.py
vllm/config/vllm.py
+21
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+10
-1
vllm/model_executor/layers/quantization/input_quant_fp8.py
vllm/model_executor/layers/quantization/input_quant_fp8.py
+1
-0
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+5
-1
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+18
-0
No files found.
tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-DEP2.yaml
View file @
05d96d79
model_name
:
"
Qwen/Qwen3.5-35B-A3B"
accuracy_threshold
:
0.86
accuracy_threshold
:
0.84
tolerance
:
0.03
num_questions
:
1319
num_fewshot
:
5
server_args
:
>-
...
...
tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-FP8-DEP2.yaml
View file @
05d96d79
model_name
:
"
Qwen/Qwen3.5-35B-A3B-FP8"
accuracy_threshold
:
0.86
accuracy_threshold
:
0.79
tolerance
:
0.03
num_questions
:
1319
num_fewshot
:
5
server_args
:
>-
...
...
tests/evals/gsm8k/configs/Qwen3.5-397B-A17B-NVFP4-DEP2.yaml
0 → 100644
View file @
05d96d79
model_name
:
"
nvidia/Qwen3.5-397B-A17B-NVFP4"
accuracy_threshold
:
0.88
tolerance
:
0.03
num_questions
:
1319
num_fewshot
:
5
server_args
:
>-
--max-model-len 4096
--data-parallel-size 2
--enable-expert-parallel
tests/evals/gsm8k/configs/models-qwen35-blackwell.txt
View file @
05d96d79
Qwen3.5-35B-A3B-DEP2.yaml
Qwen3.5-35B-A3B-FP8-DEP2.yaml
Qwen3.5-397B-A17B-NVFP4-DEP2.yaml
\ No newline at end of file
tests/evals/gsm8k/test_gsm8k_correctness.py
View file @
05d96d79
...
...
@@ -19,8 +19,6 @@ from vllm.platforms import current_platform
from
.gsm8k_eval
import
evaluate_gsm8k
TOL
=
0.08
# Absolute tolerance for accuracy comparison
def
run_gsm8k_eval
(
eval_config
:
dict
,
server_url
:
str
)
->
dict
:
"""Run GSM8K evaluation using our isolated script."""
...
...
@@ -99,20 +97,20 @@ def test_gsm8k_correctness(config_filename):
measured_metric
=
results
[
"accuracy"
]
expected_metric
=
eval_config
[
"accuracy_threshold"
]
tol
=
eval_config
.
get
(
"tolerance"
,
0.08
)
print
(
f
"GSM8K Results for
{
eval_config
[
'model_name'
]
}
:"
)
print
(
f
" Measured metric:
{
measured_metric
:.
4
f
}
"
)
print
(
f
" Expected metric:
{
expected_metric
:.
4
f
}
"
)
print
(
f
" Tolerance:
{
TOL
:.
4
f
}
"
)
print
(
f
" Tolerance:
{
tol
:.
4
f
}
"
)
print
(
f
" Questions:
{
results
[
'num_questions'
]
}
"
)
print
(
f
" Invalid rate:
{
results
[
'invalid_rate'
]:.
3
f
}
"
)
print
(
f
" Latency:
{
results
[
'latency'
]:.
1
f
}
s"
)
print
(
f
" QPS:
{
results
[
'questions_per_second'
]:.
1
f
}
"
)
# Verify metric is within tolerance
assert
measured_metric
>=
expected_metric
-
TOL
,
(
assert
measured_metric
>=
expected_metric
-
tol
,
(
f
"GSM8K metric too low:
{
measured_metric
:.
4
f
}
< "
f
"
{
expected_metric
:.
4
f
}
-
{
TOL
:.
4
f
}
=
{
expected_metric
-
TOL
:.
4
f
}
"
f
"
{
expected_metric
:.
4
f
}
-
{
tol
:.
4
f
}
=
{
expected_metric
-
tol
:.
4
f
}
"
)
print
(
f
"✅ GSM8K test passed for
{
eval_config
[
'model_name'
]
}
"
)
vllm/config/vllm.py
View file @
05d96d79
...
...
@@ -682,6 +682,27 @@ class VllmConfig:
self
.
model_config
,
self
.
load_config
)
if
(
self
.
quant_config
is
not
None
and
self
.
model_config
is
not
None
and
hasattr
(
self
.
quant_config
,
"use_deep_gemm"
)
and
self
.
quant_config
.
use_deep_gemm
is
None
):
from
vllm.utils.deep_gemm
import
should_auto_disable_deep_gemm
model_type
=
getattr
(
self
.
model_config
.
hf_text_config
,
"model_type"
,
None
)
if
should_auto_disable_deep_gemm
(
model_type
):
self
.
quant_config
.
use_deep_gemm
=
False
logger
.
warning_once
(
"Auto-disabled DeepGemm for model_type=%s on Blackwell. "
"DeepGemm E8M0 scale format causes accuracy degradation "
"for this architecture. Falling back to CUTLASS. "
"To disable DeepGemm globally, set VLLM_USE_DEEP_GEMM=0."
,
model_type
,
)
from
vllm.v1.executor.abstract
import
Executor
executor_backend
=
self
.
parallel_config
.
distributed_executor_backend
executor_supports_async_sched
=
executor_backend
in
(
"mp"
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
05d96d79
...
...
@@ -135,6 +135,7 @@ class Fp8Config(QuantizationConfig):
f
"
{
activation_scheme
}
activation scheme."
)
self
.
weight_block_size
=
weight_block_size
self
.
use_deep_gemm
:
bool
|
None
=
None
@
classmethod
def
get_name
(
cls
)
->
QuantizationMethods
:
...
...
@@ -291,7 +292,10 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
use_marlin
=
False
self
.
use_aiter_and_is_supported
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
self
.
use_deep_gemm
=
is_deep_gemm_supported
()
if
self
.
quant_config
.
use_deep_gemm
is
not
None
:
self
.
use_deep_gemm
=
self
.
quant_config
.
use_deep_gemm
else
:
self
.
use_deep_gemm
=
is_deep_gemm_supported
()
self
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
self
.
block_quant
=
self
.
weight_block_size
is
not
None
...
...
@@ -305,6 +309,7 @@ class Fp8LinearMethod(LinearMethodBase):
act_quant_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
]),
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
,
use_deep_gemm
=
self
.
use_deep_gemm
,
)
else
:
# Use per-token quantization for better perf if dynamic and cutlass
...
...
@@ -432,6 +437,7 @@ class Fp8LinearMethod(LinearMethodBase):
else
:
layer
.
input_scale
=
None
<<<<<<<
HEAD
if
self
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
,
size_k_first
,
input_dtype
=
self
.
marlin_input_dtype
...
...
@@ -441,6 +447,9 @@ class Fp8LinearMethod(LinearMethodBase):
return
if
self
.
block_quant
:
=======
if
self
.
block_quant
and
self
.
use_deep_gemm
:
>>>>>>>
52069012
f
([
Bugfix
]
Fix
DeepGemm
E8M0
accuracy
degradation
for
Qwen3
.
5
FP8
on
Blackwell
(
#38083))
maybe_post_process_fp8_weight_block
(
layer
)
def
apply
(
...
...
vllm/model_executor/layers/quantization/input_quant_fp8.py
View file @
05d96d79
...
...
@@ -91,6 +91,7 @@ class QuantFP8(CustomOp):
if
(
self
.
is_group_quant
and
self
.
use_ue8m0
and
self
.
use_deep_gemm_supported
and
(
DeepGemmQuantScaleFMT
.
from_oracle
()
==
DeepGemmQuantScaleFMT
.
UE8M0
)
):
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
05d96d79
...
...
@@ -356,10 +356,14 @@ class W8A8BlockFp8LinearOp:
act_quant_group_shape
:
GroupShape
,
cutlass_block_fp8_supported
:
bool
=
CUTLASS_BLOCK_FP8_SUPPORTED
,
use_aiter_and_is_supported
:
bool
=
False
,
use_deep_gemm
:
bool
|
None
=
None
,
):
self
.
weight_group_shape
=
weight_group_shape
self
.
act_quant_group_shape
=
act_quant_group_shape
self
.
is_deep_gemm_supported
=
is_deep_gemm_supported
()
if
use_deep_gemm
is
not
None
:
self
.
is_deep_gemm_supported
=
use_deep_gemm
else
:
self
.
is_deep_gemm_supported
=
is_deep_gemm_supported
()
self
.
is_hopper
=
current_platform
.
is_device_capability
(
90
)
self
.
use_deep_gemm_e8m0
=
is_deep_gemm_e8m0_used
()
self
.
is_flashinfer_supported
=
is_flashinfer_fp8_blockscale_gemm_supported
()
...
...
vllm/utils/deep_gemm.py
View file @
05d96d79
...
...
@@ -23,6 +23,24 @@ from vllm.platforms import current_platform
from
vllm.utils.import_utils
import
has_deep_gemm
from
vllm.utils.math_utils
import
cdiv
_DEEPGEMM_BLACKWELL_EXCLUDED_MODEL_TYPES
:
set
[
str
]
=
{
"qwen3_5_text"
,
"qwen3_5_moe_text"
,
}
def
should_auto_disable_deep_gemm
(
model_type
:
str
|
None
)
->
bool
:
"""Check if DeepGemm should be auto-disabled for this model on Blackwell.
Returns True if the model is known to have accuracy degradation with
DeepGemm's E8M0 scale format on Blackwell GPUs (SM100+).
"""
if
model_type
is
None
:
return
False
if
not
current_platform
.
is_device_capability_family
(
100
):
return
False
return
model_type
in
_DEEPGEMM_BLACKWELL_EXCLUDED_MODEL_TYPES
class
DeepGemmQuantScaleFMT
(
Enum
):
# Float32 scales in Float32 tensor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment