OpenDAS / vllm_cscc, commit 10ee1c64 (unverified)
Authored Dec 16, 2025 by Michael Goin; committed by GitHub, Dec 16, 2025

[CI] Generalize gsm8k test args and add Qwen3-Next MTP B200 test (#30723)

Signed-off-by: mgoin <mgoin64@gmail.com>

Parent: 66c3537e

Showing 14 changed files with 78 additions and 57 deletions (+78 -57)
Changed files:
.buildkite/test-pipeline.yaml  +2 -2
tests/evals/gsm8k/README.md  +9 -4
tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml  +1 -2
tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml  +1 -1
tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml  +1 -1
tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml  +1 -1
tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml  +1 -1
tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml  +1 -1
tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml  +1 -2
tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml  +12 -0
tests/evals/gsm8k/configs/models-blackwell.txt  +1 -0
tests/evals/gsm8k/conftest.py  +3 -5
tests/evals/gsm8k/test_gsm8k_correctness.py  +41 -29
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py  +3 -8
.buildkite/test-pipeline.yaml
@@ -654,7 +654,7 @@ steps:
   - vllm/model_executor/layers/quantization
   autorun_on_main: true
   commands:
-  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
+  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt

 - label: OpenAI API correctness # 22min
   timeout_in_minutes: 30
@@ -1064,7 +1064,7 @@ steps:
   - csrc/
   - vllm/model_executor/layers/quantization
   commands:
-  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
+  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt

 ##### 1 GPU test #####
 ##### multi gpus test #####
tests/evals/gsm8k/README.md
@@ -7,9 +7,8 @@ This directory contains a replacement for the lm-eval-harness GSM8K evaluation,
 ### Run tests with pytest (like buildkite)

 ```bash
-pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \
-    --config-list-file=configs/models-small.txt \
-    --tp-size=1
+pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
+    --config-list-file=configs/models-small.txt
 ```

 ### Run standalone evaluation script
@@ -31,5 +30,11 @@ model_name: "Qwen/Qwen2.5-1.5B-Instruct"
 accuracy_threshold: 0.54 # Minimum expected accuracy
 num_questions: 1319 # Number of questions (default: full test set)
 num_fewshot: 5 # Few-shot examples from train set
-max_model_len: 4096 # Model context length
+server_args: "--max-model-len 4096 --tensor-parallel-size 2" # Server arguments
+env: # Environment variables (optional)
+  VLLM_USE_FLASHINFER_MOE_FP4: "1"
 ```
+
+The `server_args` field accepts any arguments that can be passed to `vllm serve`.
+
+The `env` field accepts a dictionary of environment variables to set for the server process.
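
As a rough illustration of how the two fields described in the README could be consumed, the sketch below turns an already-loaded config mapping into a list of CLI arguments and an environment dictionary. The helper name `build_server_launch` is hypothetical, and overlaying `env` on a copy of the current process environment is an assumption; the commit itself only shows `shlex.split` on `server_args`, the appended `--trust-remote-code`, and `env` being passed through as `env_dict`.

```python
import os
import shlex


def build_server_launch(eval_config: dict) -> tuple[list[str], dict[str, str]]:
    """Hypothetical helper: derive CLI args and an environment from a config."""
    # server_args is one shell-style string; shlex.split honours quoting.
    args = shlex.split(eval_config.get("server_args") or "")
    # The updated test always appends this flag after the configured args.
    args.append("--trust-remote-code")
    # Assumption: env entries are overlaid on a copy of the current environment.
    env = dict(os.environ)
    env.update({key: str(value) for key, value in (eval_config.get("env") or {}).items()})
    return args, env


# Values taken from the README example in this diff.
config = {
    "model_name": "Qwen/Qwen2.5-1.5B-Instruct",
    "server_args": "--max-model-len 4096 --tensor-parallel-size 2",
    "env": {"VLLM_USE_FLASHINFER_MOE_FP4": "1"},
}
args, env = build_server_launch(config)
print(args)
```

Keeping everything in one `server_args` string means a config can carry any server flag without the test harness needing a dedicated option for it, which is what allowed `--tp-size` to be dropped.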
tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml
@@ -2,5 +2,4 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 accuracy_threshold: 0.72
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"
tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml
@@ -2,4 +2,4 @@ model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
 accuracy_threshold: 0.74
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"
\ No newline at end of file
tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml
@@ -2,4 +2,4 @@ model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
 accuracy_threshold: 0.31
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"
\ No newline at end of file
tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml
@@ -2,4 +2,4 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
 accuracy_threshold: 0.45
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"
tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
@@ -2,4 +2,4 @@ model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
 accuracy_threshold: 0.60
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"
\ No newline at end of file
tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml
@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-0.6B-FP8"
 accuracy_threshold: 0.375
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"
\ No newline at end of file
tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml
@@ -2,5 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-FP4"
 accuracy_threshold: 0.89
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"
tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml (new file, 0 → 100644)
+model_name: "nm-testing/Qwen3-Next-80B-A3B-Instruct-NVFP4"
+accuracy_threshold: 0.75
+num_questions: 1319
+num_fewshot: 5
+server_args: >-
+  --enforce-eager
+  --max-model-len 4096
+  --tensor-parallel-size 2
+  --enable-expert-parallel
+  --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}'
+env:
+  VLLM_USE_FLASHINFER_MOE_FP4: "1"
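
The new config puts a quoted JSON value inside `server_args`, which is presumably why the test switches to `shlex.split`. The sketch below shows the difference under one assumption: the folded YAML scalar (`>-`) has already been loaded, so it is written here as the single space-joined string it folds to rather than parsed with a YAML library.

```python
import json
import shlex

# What the folded scalar in the new config becomes after loading:
# its lines are joined by single spaces and the final newline is stripped.
server_args = (
    "--enforce-eager --max-model-len 4096 --tensor-parallel-size 2 "
    "--enable-expert-parallel "
    """--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}'"""
)

# shlex.split applies POSIX shell quoting rules: the single quotes are
# removed and the JSON text inside them survives as one token.
tokens = shlex.split(server_args)
spec_config = json.loads(tokens[-1])

# A plain str.split keeps the surrounding single quotes, so the last
# token is no longer valid JSON.
naive_last = server_args.split()[-1]

print(tokens[-2], spec_config)
print(naive_last)
```

If the JSON ever contained spaces, `str.split` would also break it into several tokens, while `shlex.split` would still return it whole.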
tests/evals/gsm8k/configs/models-blackwell.txt
@@ -3,3 +3,4 @@ Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
 Qwen1.5-MoE-W4A16-CT.yaml
 DeepSeek-V2-Lite-Instruct-FP8.yaml
 Qwen3-30B-A3B-NVFP4.yaml
+Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
tests/evals/gsm8k/conftest.py
@@ -11,14 +11,12 @@ def pytest_addoption(parser):
         default="configs/models-small.txt",
         help="File containing list of config files to test",
     )
-    parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size")


 def pytest_generate_tests(metafunc):
     """Generate test parameters from config files."""
     if "config_filename" in metafunc.fixturenames:
         config_list_file = metafunc.config.getoption("--config-list-file")
-        tp_size = metafunc.config.getoption("--tp-size")

         # Handle both relative and absolute paths
         config_list_path = Path(config_list_file)
@@ -55,9 +53,9 @@ def pytest_generate_tests(metafunc):
         # Generate test parameters
         if config_files:
             metafunc.parametrize(
-                ["config_filename", "tp_size"],
-                [(config_file, int(tp_size)) for config_file in config_files],
-                ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files],
+                "config_filename",
+                config_files,
+                ids=[config_file.stem for config_file in config_files],
             )
         else:
             print("No config files found, test will be skipped")
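
The part of `conftest.py` that reads the list file is elided between the two hunks, so the following is only a sketch of how a list such as `models-blackwell.txt` might be turned into parametrize inputs. `load_config_files` and `make_ids` are hypothetical names, and both the comment handling and the resolution of entries relative to the list file's directory are assumptions. Only the ID scheme (the bare file stem, with no `-tp<N>` suffix) comes from the diff.

```python
import tempfile
from pathlib import Path


def load_config_files(config_list_file: str) -> list[Path]:
    """Hypothetical loader: one config filename per line, '#' starts a comment."""
    list_path = Path(config_list_file)
    config_files = []
    for line in list_path.read_text(encoding="utf-8").splitlines():
        name = line.split("#", 1)[0].strip()
        if name:
            # Assumption: entries are relative to the list file's directory.
            config_files.append(list_path.parent / name)
    return config_files


def make_ids(config_files: list[Path]) -> list[str]:
    """Test IDs as in the updated parametrize call: the file stem only."""
    return [config_file.stem for config_file in config_files]


# Demonstration with a throwaway list file.
with tempfile.TemporaryDirectory() as tmp:
    list_file = Path(tmp) / "models-blackwell.txt"
    list_file.write_text(
        "Qwen3-30B-A3B-NVFP4.yaml\n\nQwen3-Next-80B-A3B-NVFP4-EP2.yaml\n",
        encoding="utf-8",
    )
    demo_ids = make_ids(load_config_files(str(list_file)))

print(demo_ids)
```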
tests/evals/gsm8k/test_gsm8k_correctness.py
@@ -5,30 +5,31 @@ GSM8K evaluation using vLLM server and isolated GSM8K script.
 Replacement for lm-eval-harness with better performance and control.

 Usage:
-pytest -s -v test_gsm8k_correctness.py \
-    --config-list-file=configs/models-small.txt \
-    --tp-size=1
+pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
+    --config-list-file=configs/models-small.txt
 """

+import shlex
+
 import yaml

 from tests.utils import RemoteOpenAIServer

 from .gsm8k_eval import evaluate_gsm8k

-RTOL = 0.08  # Relative tolerance for accuracy comparison
+TOL = 0.08  # Absolute tolerance for accuracy comparison


-def launch_gsm8k_eval(eval_config, server_url, tp_size):
-    """Launch GSM8K evaluation using our isolated script."""
+def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
+    """Run GSM8K evaluation using our isolated script."""
     # Extract host and port from server URL
     if "://" in server_url:
         server_url = server_url.split("://")[1]

     host_port = server_url.split("/")[0]  # Remove path if present
     if ":" in host_port:
-        host, port = host_port.split(":")
-        port = int(port)
+        host, p = host_port.split(":")
+        port = int(p)
     else:
         host = host_port
         port = 8000
@@ -48,46 +49,57 @@ def launch_gsm8k_eval(eval_config, server_url, tp_size):
     return results


-def test_gsm8k_correctness_param(config_filename, tp_size):
+def test_gsm8k_correctness(config_filename):
     """Test GSM8K correctness for a given model configuration."""
     eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))

-    # Server arguments
-    server_args = [
-        "--max-model-len",
-        str(eval_config.get("max_model_len", 4096)),
-        "--enforce-eager",
-        "--trust-remote-code",
-        "--tensor-parallel-size",
-        str(tp_size),
-    ]
+    # Parse server arguments from config (use shlex to handle quoted strings)
+    server_args_str = eval_config.get("server_args", "")
+    server_args = shlex.split(server_args_str) if server_args_str else []
+
+    # Add standard server arguments
+    server_args.extend(
+        [
+            "--trust-remote-code",
+        ]
+    )

     env_dict = eval_config.get("env", None)

+    print(f"Starting GSM8K evaluation for model: {eval_config['model_name']}")
+    print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
+    print(f"Number of questions: {eval_config['num_questions']}")
+    print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
+    print(f"Server args: {' '.join(server_args)}")
+
     # Launch server and run evaluation
     with RemoteOpenAIServer(
-        eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480
+        eval_config["model_name"],
+        server_args,
+        env_dict=env_dict,
+        max_wait_seconds=600,
     ) as remote_server:
         server_url = remote_server.url_for("v1")
+        print(f"Server started at: {server_url}")

-        results = launch_gsm8k_eval(eval_config, server_url, tp_size)
+        results = run_gsm8k_eval(eval_config, server_url)

-        # Check accuracy against threshold
-        measured_accuracy = results["accuracy"]
-        expected_accuracy = eval_config["accuracy_threshold"]
+        measured_metric = results["accuracy"]
+        expected_metric = eval_config["accuracy_threshold"]

         print(f"GSM8K Results for {eval_config['model_name']}:")
-        print(f"  Accuracy: {measured_accuracy:.3f}")
-        print(f"  Expected: {expected_accuracy:.3f}")
+        print(f"  Measured metric: {measured_metric:.4f}")
+        print(f"  Expected metric: {expected_metric:.4f}")
+        print(f"  Tolerance: {TOL:.4f}")
         print(f"  Questions: {results['num_questions']}")
         print(f"  Invalid rate: {results['invalid_rate']:.3f}")
         print(f"  Latency: {results['latency']:.1f}s")
         print(f"  QPS: {results['questions_per_second']:.1f}")

-        # Verify accuracy is within tolerance
-        assert measured_accuracy >= expected_accuracy - RTOL, (
-            f"Accuracy too low: {measured_accuracy:.3f} < "
-            f"{expected_accuracy:.3f} - {RTOL:.3f}"
+        # Verify metric is within tolerance
+        assert measured_metric >= expected_metric - TOL, (
+            f"GSM8K metric too low: {measured_metric:.4f} < "
+            f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
         )

         print(f"✅ GSM8K test passed for {eval_config['model_name']}")
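
The rename from `RTOL` to `TOL` matches what the assertion already computed: the tolerance is subtracted from the threshold, not multiplied into it. The sketch below spells out that check and, for contrast, what a relative tolerance would give. The function names are hypothetical; the constant `0.08` and the example threshold `0.75` (the new Qwen3-Next config) come from the diff.

```python
TOL = 0.08  # Absolute tolerance, as in the updated test


def accuracy_floor(expected: float, tol: float = TOL) -> float:
    """Lowest passing accuracy under an absolute tolerance."""
    return expected - tol


def meets_threshold(measured: float, expected: float, tol: float = TOL) -> bool:
    """Same comparison as the test's assertion: measured >= expected - TOL."""
    return measured >= accuracy_floor(expected, tol)


def relative_floor(expected: float, rtol: float = TOL) -> float:
    """For contrast only: the floor a relative tolerance would produce."""
    return expected * (1 - rtol)


expected = 0.75  # accuracy_threshold of the new Qwen3-Next config
print(f"absolute floor: {accuracy_floor(expected):.4f}")
print(f"relative floor: {relative_floor(expected):.4f}")
print(meets_threshold(0.68, expected), meets_threshold(0.66, expected))
```

For thresholds below 1.0 the absolute floor is the more lenient of the two, so the old name overstated how strict the check was.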
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -626,17 +626,11 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )
         else:
-            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
-            # only (no EP).
             from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

-            assert layer.expert_map is None, (
-                "Expert Parallelism / expert_map "
-                "is currently not supported for "
-                "CompressedTensorsW4A4Nvfp4MoEMethod."
-            )
             assert self.moe_quant_config is not None
-
+            # Cutlass moe takes in activations in BF16/Half precision
+            # and fp4 quantized weights loaded from the checkpoint
             return cutlass_moe_fp4(
                 a=x,
                 w1_fp4=layer.w13_weight,
@@ -644,6 +638,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 quant_config=self.moe_quant_config,
+                expert_map=layer.expert_map,
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
                 # TODO(bnell): derive these from arguments
                 m=x.shape[0],