Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ee59a7c6
Unverified
Commit
ee59a7c6
authored
Feb 25, 2026
by
Benjamin Chislett
Committed by
GitHub
Feb 25, 2026
Browse files
[Tests] Add GSM8k check to SpecDec E2E tests (#34772)
Signed-off-by:
Benjamin Chislett
<
bchislett@nvidia.com
>
parent
709eadbb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
237 additions
and
143 deletions
+237
-143
tests/evals/gsm8k/gsm8k_eval.py
tests/evals/gsm8k/gsm8k_eval.py
+95
-43
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+142
-100
No files found.
tests/evals/gsm8k/gsm8k_eval.py
View file @
ee59a7c6
...
@@ -110,29 +110,16 @@ async def call_vllm_api(
...
@@ -110,29 +110,16 @@ async def call_vllm_api(
return
""
,
0
return
""
,
0
def
evaluate_gsm8k
(
def
_build_gsm8k_prompts
(
num_questions
:
int
=
1319
,
num_questions
:
int
=
1319
,
num_shots
:
int
=
5
,
num_shots
:
int
=
5
,
max_tokens
:
int
=
256
,
)
->
tuple
[
list
[
str
],
list
[
int
]]:
host
:
str
=
"http://127.0.0.1"
,
"""Build few-shot GSM8K completion prompts and ground-truth labels."""
port
:
int
=
8000
,
if
num_questions
==
0
:
temperature
:
float
=
0.0
,
return
[],
[]
seed
:
int
|
None
=
42
,
)
->
dict
[
str
,
float
|
int
]:
"""
Evaluate GSM8K accuracy using vLLM serve endpoint.
Returns dict with accuracy, invalid_rate, latency, etc.
"""
base_url
=
f
"
{
host
}
:
{
port
}
"
# Load GSM8K train and test data
train_data
,
test_data
=
load_gsm8k_data
()
train_data
,
test_data
=
load_gsm8k_data
()
# Limit to available test questions
num_questions
=
min
(
num_questions
,
len
(
test_data
))
num_questions
=
min
(
num_questions
,
len
(
test_data
))
# Build few-shot examples from train split (like lm-eval does)
few_shot_examples
=
""
few_shot_examples
=
""
for
i
in
range
(
num_shots
):
for
i
in
range
(
num_shots
):
few_shot_examples
+=
(
few_shot_examples
+=
(
...
@@ -140,25 +127,74 @@ def evaluate_gsm8k(
...
@@ -140,25 +127,74 @@ def evaluate_gsm8k(
f
"Answer:
{
train_data
[
i
][
'answer'
]
}
\n\n
"
f
"Answer:
{
train_data
[
i
][
'answer'
]
}
\n\n
"
)
)
# Prepare test questions and labels from test split
prompts
=
[]
questions
=
[]
labels
=
[]
labels
=
[]
for
i
in
range
(
num_questions
):
for
i
in
range
(
num_questions
):
questions
.
append
(
f
"Question:
{
test_data
[
i
][
'question'
]
}
\n
Answer:"
)
prompts
.
append
(
few_shot_examples
+
f
"Question:
{
test_data
[
i
][
'question'
]
}
\n
Answer:"
)
labels
.
append
(
get_answer_value
(
test_data
[
i
][
"answer"
]))
labels
.
append
(
get_answer_value
(
test_data
[
i
][
"answer"
]))
assert
all
(
label
!=
INVALID
for
label
in
labels
),
"Some labels are invalid"
assert
all
(
label
!=
INVALID
for
label
in
labels
),
"Some labels are invalid"
return
prompts
,
labels
def
_score_gsm8k
(
states
:
list
[
str
],
output_tokens
:
list
[
int
],
labels
:
list
[
int
],
num_shots
:
int
,
max_tokens
:
int
,
latency
:
float
,
)
->
dict
[
str
,
float
|
int
]:
"""Score GSM8K responses and return a results dict."""
num_questions
=
len
(
labels
)
preds
=
[
get_answer_value
(
state
)
for
state
in
states
]
accuracy
=
np
.
mean
(
np
.
array
(
preds
)
==
np
.
array
(
labels
))
invalid_rate
=
np
.
mean
(
np
.
array
(
preds
)
==
INVALID
)
total_output_tokens
=
sum
(
output_tokens
)
tokens_per_second
=
total_output_tokens
/
latency
if
latency
>
0
else
0.0
return
{
"accuracy"
:
accuracy
,
"invalid_rate"
:
invalid_rate
,
"latency"
:
latency
,
"questions_per_second"
:
num_questions
/
latency
if
latency
>
0
else
0.0
,
"total_output_tokens"
:
total_output_tokens
,
"tokens_per_second"
:
tokens_per_second
,
"num_questions"
:
num_questions
,
"num_shots"
:
num_shots
,
"max_tokens"
:
max_tokens
,
"timestamp"
:
time
.
time
(),
}
def
evaluate_gsm8k
(
num_questions
:
int
=
1319
,
num_shots
:
int
=
5
,
max_tokens
:
int
=
256
,
host
:
str
=
"http://127.0.0.1"
,
port
:
int
=
8000
,
temperature
:
float
=
0.0
,
seed
:
int
|
None
=
42
,
)
->
dict
[
str
,
float
|
int
]:
"""
Evaluate GSM8K accuracy using vLLM serve endpoint.
Returns dict with accuracy, invalid_rate, latency, etc.
"""
base_url
=
f
"
{
host
}
:
{
port
}
"
prompts
,
labels
=
_build_gsm8k_prompts
(
num_questions
,
num_shots
)
num_questions
=
len
(
prompts
)
# Run evaluation
async
def
run_async_evaluation
():
async
def
run_async_evaluation
():
states
:
list
[
str
]
=
[
""
]
*
num_questions
states
:
list
[
str
]
=
[
""
]
*
num_questions
output_tokens
:
list
[
int
]
=
[
0
]
*
num_questions
output_tokens
:
list
[
int
]
=
[
0
]
*
num_questions
async
def
get_answer
(
session
:
aiohttp
.
ClientSession
,
i
:
int
)
->
tuple
[
str
,
int
]:
async
def
get_answer
(
session
:
aiohttp
.
ClientSession
,
i
:
int
)
->
tuple
[
str
,
int
]:
prompt
=
few_shot_examples
+
questions
[
i
]
answer
,
tokens
=
await
call_vllm_api
(
answer
,
tokens
=
await
call_vllm_api
(
session
=
session
,
session
=
session
,
prompt
=
prompt
,
prompt
=
prompt
s
[
i
]
,
temperature
=
temperature
,
temperature
=
temperature
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
stop
=
[
"Question"
,
"Assistant:"
,
"<|separator|>"
],
stop
=
[
"Question"
,
"Assistant:"
,
"<|separator|>"
],
...
@@ -183,27 +219,43 @@ def evaluate_gsm8k(
...
@@ -183,27 +219,43 @@ def evaluate_gsm8k(
states
,
output_tokens
=
asyncio
.
run
(
run_async_evaluation
())
states
,
output_tokens
=
asyncio
.
run
(
run_async_evaluation
())
latency
=
time
.
perf_counter
()
-
tic
latency
=
time
.
perf_counter
()
-
tic
# Compute metrics
return
_score_gsm8k
(
states
,
output_tokens
,
labels
,
num_shots
,
max_tokens
,
latency
)
preds
=
[
get_answer_value
(
state
)
for
state
in
states
]
accuracy
=
np
.
mean
(
np
.
array
(
preds
)
==
np
.
array
(
labels
))
invalid_rate
=
np
.
mean
(
np
.
array
(
preds
)
==
INVALID
)
total_output_tokens
=
sum
(
output_tokens
)
tokens_per_second
=
total_output_tokens
/
latency
if
latency
>
0
else
0.0
result
=
{
"accuracy"
:
accuracy
,
"invalid_rate"
:
invalid_rate
,
"latency"
:
latency
,
"questions_per_second"
:
num_questions
/
latency
,
"total_output_tokens"
:
total_output_tokens
,
"tokens_per_second"
:
tokens_per_second
,
"num_questions"
:
num_questions
,
"num_shots"
:
num_shots
,
"max_tokens"
:
max_tokens
,
"timestamp"
:
time
.
time
(),
}
return
result
def
evaluate_gsm8k_offline
(
llm
,
num_questions
:
int
=
1319
,
num_shots
:
int
=
5
,
max_tokens
:
int
=
256
,
temperature
:
float
=
0.0
,
)
->
dict
[
str
,
float
|
int
]:
"""Evaluate GSM8K accuracy using an offline vllm.LLM object.
Same prompts and scoring as evaluate_gsm8k(), but runs generation
directly via llm.generate() instead of calling a server over HTTP.
"""
from
vllm
import
SamplingParams
prompts
,
labels
=
_build_gsm8k_prompts
(
num_questions
,
num_shots
)
sampling_params
=
SamplingParams
(
temperature
=
temperature
,
max_tokens
=
max_tokens
,
stop
=
[
"Question"
,
"Assistant:"
,
"<|separator|>"
],
)
print
(
f
"Running offline GSM8K evaluation:
{
len
(
prompts
)
}
questions,
{
num_shots
}
-shot"
)
tic
=
time
.
perf_counter
()
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
latency
=
time
.
perf_counter
()
-
tic
states
=
[
o
.
outputs
[
0
].
text
for
o
in
outputs
]
output_tokens
=
[
len
(
o
.
outputs
[
0
].
token_ids
)
for
o
in
outputs
]
return
_score_gsm8k
(
states
,
output_tokens
,
labels
,
num_shots
,
max_tokens
,
latency
)
def
main
()
->
None
:
def
main
()
->
None
:
...
...
tests/v1/e2e/test_spec_decode.py
View file @
ee59a7c6
...
@@ -8,6 +8,7 @@ from typing import Any
...
@@ -8,6 +8,7 @@ from typing import Any
import
pytest
import
pytest
import
torch
import
torch
from
tests.evals.gsm8k.gsm8k_eval
import
_build_gsm8k_prompts
,
evaluate_gsm8k_offline
from
tests.utils
import
get_attn_backend_list_based_on_platform
,
large_gpu_mark
from
tests.utils
import
get_attn_backend_list_based_on_platform
,
large_gpu_mark
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.base
import
VLLM_S3_BUCKET_URL
from
vllm.assets.base
import
VLLM_S3_BUCKET_URL
...
@@ -35,53 +36,57 @@ def _skip_if_insufficient_gpus_for_tp(tp_size: int):
...
@@ -35,53 +36,57 @@ def _skip_if_insufficient_gpus_for_tp(tp_size: int):
Messages
=
list
[
dict
[
str
,
Any
]]
Messages
=
list
[
dict
[
str
,
Any
]]
def
get_test_prompts
(
def
get_test_prompts
(
mm_enabled
:
bool
,
num_prompts
:
int
=
100
)
->
list
[
Messages
]:
mm_enabled
:
bool
,
quiet
:
bool
=
False
,
num_prompts
:
int
=
100
prompt_types
=
[
"repeat"
,
"gsm8k"
]
)
->
list
[
Messages
]:
prompt_types
=
[
"repeat"
,
"sentence"
]
if
mm_enabled
:
if
mm_enabled
:
prompt_types
.
append
(
"mm"
)
prompt_types
.
append
(
"mm"
)
prompts
=
[]
prompts
:
list
[
Messages
]
=
[]
random
.
seed
(
0
)
num_repeat_prompts
=
num_prompts
//
len
(
prompt_types
)
random_prompt_type_choices
=
random
.
choices
(
prompt_types
,
k
=
num_prompts
)
if
mm_enabled
:
num_gsm8k_prompts
=
num_prompts
//
len
(
prompt_types
)
if
not
quiet
:
num_mm_prompts
=
num_prompts
-
num_repeat_prompts
-
num_gsm8k_prompts
print
(
f
"Prompt types:
{
random_prompt_type_choices
}
"
)
else
:
num_mm_prompts
=
0
num_gsm8k_prompts
=
num_prompts
-
num_repeat_prompts
# Generate a mixed batch of prompts, some of which can be easily
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
# predicted by n-gram matching and some which likely cannot.
for
kind
in
random_prompt_type_choices
:
random
.
seed
(
0
)
for
_
in
range
(
num_repeat_prompts
):
word_choices
=
[
"test"
,
"temp"
,
"hello"
,
"where"
]
word_choices
=
[
"test"
,
"temp"
,
"hello"
,
"where"
]
word
=
random
.
choice
(
word_choices
)
word
=
random
.
choice
(
word_choices
)
prompt
:
str
|
list
[
dict
[
str
,
Any
]]
=
""
prompts
.
append
(
if
kind
==
"repeat"
:
[
prompt
=
f
"""
please repeat the word '
{
word
}
' 10 times.
give no other output than the word at least ten times in a row,
in lowercase with spaces between each word and without quotes.
"""
elif
kind
==
"sentence"
:
prompt
=
f
"""
please give a ten-word sentence that
uses the word
{
word
}
at least once.
give no other output than that simple sentence without quotes.
"""
elif
kind
==
"mm"
:
placeholders
=
[
{
{
"type"
:
"image_url"
,
"role"
:
"user"
,
"image_url"
:
{
"content"
:
f
"""
"url"
:
f
"
{
VLLM_S3_BUCKET_URL
}
/
{
VLM_IMAGES_DIR
}
/stop_sign.jpg"
please repeat the word '
{
word
}
' 10 times.
},
give no other output than the word at least ten times in a row,
in lowercase with spaces between each word and without quotes.
"""
,
}
}
]
]
prompt
=
[
)
*
placeholders
,
prompts
.
extend
(
{
"type"
:
"text"
,
"text"
:
"The meaning of the image is"
},
[{
"role"
:
"user"
,
"content"
:
prompt
}]
]
for
prompt
in
_build_gsm8k_prompts
(
else
:
num_questions
=
num_gsm8k_prompts
,
num_shots
=
5
raise
ValueError
(
f
"Unknown prompt type:
{
kind
}
"
)
)[
0
]
)
for
_
in
range
(
num_mm_prompts
):
placeholders
=
[
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
f
"
{
VLLM_S3_BUCKET_URL
}
/
{
VLM_IMAGES_DIR
}
/stop_sign.jpg"
},
}
]
prompt
=
[
*
placeholders
,
{
"type"
:
"text"
,
"text"
:
"The meaning of the image is"
},
]
prompts
.
append
([{
"role"
:
"user"
,
"content"
:
prompt
}])
prompts
.
append
([{
"role"
:
"user"
,
"content"
:
prompt
}])
return
prompts
return
prompts
...
@@ -113,6 +118,25 @@ def model_name():
...
@@ -113,6 +118,25 @@ def model_name():
return
"meta-llama/Llama-3.1-8B-Instruct"
return
"meta-llama/Llama-3.1-8B-Instruct"
def
evaluate_llm_for_gsm8k
(
llm
:
LLM
,
expected_accuracy_threshold
:
float
=
0.70
)
->
None
:
"""Evaluate the LLM on GSM8K and check that accuracy is above a sanity threshold.
The default threshold assumes the LLM uses the same target model as the "model_name"
fixture, with max model len == 4096. Precomputed reference value is 75% to 80%
on GSM8K with greedy decoding, so we check that it's above a sanity threshold of 70%
to verify that the model is correct.
"""
if
expected_accuracy_threshold
<=
0.0
:
print
(
"Skipping GSM8K evaluation"
)
return
results
=
evaluate_gsm8k_offline
(
llm
)
accuracy
=
results
[
"accuracy"
]
print
(
f
"GSM8K accuracy:
{
accuracy
:.
3
f
}
"
)
assert
accuracy
>=
expected_accuracy_threshold
,
(
f
"Expected GSM8K accuracy >=
{
expected_accuracy_threshold
}
, got
{
accuracy
:.
3
f
}
"
)
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
def
reset_torch_dynamo
():
def
reset_torch_dynamo
():
"""Reset torch dynamo cache before each test"""
"""Reset torch dynamo cache before each test"""
...
@@ -138,41 +162,14 @@ def reset_torch_dynamo():
...
@@ -138,41 +162,14 @@ def reset_torch_dynamo():
)
)
def
test_ngram_and_suffix_correctness
(
def
test_ngram_and_suffix_correctness
(
speculative_config
:
dict
,
speculative_config
:
dict
,
monkeypatch
:
pytest
.
MonkeyPatch
,
sampling_config
:
SamplingParams
,
model_name
:
str
,
model_name
:
str
,
):
):
"""
Compare the outputs of an original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
"""
test_prompts
=
get_test_prompts
(
mm_enabled
=
False
)
ref_llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
del
ref_llm
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
spec_llm
=
LLM
(
spec_llm
=
LLM
(
model
=
model_name
,
model
=
model_name
,
speculative_config
=
speculative_config
,
speculative_config
=
speculative_config
,
max_model_len
=
1024
,
max_model_len
=
4096
,
)
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
evaluate_llm_for_gsm8k
(
spec_llm
)
matches
=
0
misses
=
0
for
ref_output
,
spec_output
in
zip
(
ref_outputs
,
spec_outputs
):
if
ref_output
.
outputs
[
0
].
text
==
spec_output
.
outputs
[
0
].
text
:
matches
+=
1
else
:
misses
+=
1
print
(
f
"ref_output:
{
ref_output
.
outputs
[
0
].
text
}
"
)
print
(
f
"spec_output:
{
spec_output
.
outputs
[
0
].
text
}
"
)
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert
matches
>=
int
(
0.66
*
len
(
ref_outputs
))
del
spec_llm
del
spec_llm
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
...
@@ -238,10 +235,10 @@ def test_suffix_decoding_acceptance(
...
@@ -238,10 +235,10 @@ def test_suffix_decoding_acceptance(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model_path"
,
[
"model_path"
,
"expected_accuracy_threshold"
],
[
[
"RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3"
,
(
"RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3"
,
0.7
),
# ref: 75%-80%
"RedHatAI/Qwen3-8B-speculator.eagle3"
,
(
"RedHatAI/Qwen3-8B-speculator.eagle3"
,
0.8
),
# ref: 87%-92%
],
],
ids
=
[
"llama3_eagle3_speculator"
,
"qwen3_eagle3_speculator"
],
ids
=
[
"llama3_eagle3_speculator"
,
"qwen3_eagle3_speculator"
],
)
)
...
@@ -249,6 +246,7 @@ def test_speculators_model_integration(
...
@@ -249,6 +246,7 @@ def test_speculators_model_integration(
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
sampling_config
:
SamplingParams
,
sampling_config
:
SamplingParams
,
model_path
:
str
,
model_path
:
str
,
expected_accuracy_threshold
:
float
,
):
):
"""
"""
Test that speculators models work with the simplified integration.
Test that speculators models work with the simplified integration.
...
@@ -262,7 +260,8 @@ def test_speculators_model_integration(
...
@@ -262,7 +260,8 @@ def test_speculators_model_integration(
2. Verifier model is extracted from speculator config
2. Verifier model is extracted from speculator config
3. Speculative decoding is automatically enabled
3. Speculative decoding is automatically enabled
4. Text generation works correctly
4. Text generation works correctly
5. Output matches reference (non-speculative) generation
5. GSM8k accuracy of the model passes a sanity check when speculative decoding on
6. Output matches reference (non-speculative) generation
"""
"""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
...
@@ -270,7 +269,10 @@ def test_speculators_model_integration(
...
@@ -270,7 +269,10 @@ def test_speculators_model_integration(
test_prompts
=
get_test_prompts
(
mm_enabled
=
False
)
test_prompts
=
get_test_prompts
(
mm_enabled
=
False
)
# First run: Direct speculator model (simplified integration)
# First run: Direct speculator model (simplified integration)
spec_llm
=
LLM
(
model
=
model_path
,
max_model_len
=
1024
)
spec_llm
=
LLM
(
model
=
model_path
,
max_model_len
=
4096
)
evaluate_llm_for_gsm8k
(
spec_llm
,
expected_accuracy_threshold
=
expected_accuracy_threshold
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
# Verify speculative config was auto-detected
# Verify speculative config was auto-detected
...
@@ -297,7 +299,7 @@ def test_speculators_model_integration(
...
@@ -297,7 +299,7 @@ def test_speculators_model_integration(
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
# Second run: Reference without speculative decoding
# Second run: Reference without speculative decoding
ref_llm
=
LLM
(
model
=
verifier_model
,
max_model_len
=
1024
)
ref_llm
=
LLM
(
model
=
verifier_model
,
max_model_len
=
4096
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
del
ref_llm
del
ref_llm
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
@@ -318,19 +320,27 @@ def test_speculators_model_integration(
...
@@ -318,19 +320,27 @@ def test_speculators_model_integration(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
[
"model_setup"
,
"mm_enabled"
,
"enable_chunked_prefill"
,
"model_impl"
],
[
"model_setup"
,
"mm_enabled"
,
"enable_chunked_prefill"
,
"model_impl"
,
"expected_accuracy_threshold"
,
],
[
[
(
(
(
"eagle3"
,
"Qwen/Qwen3-8B"
,
"AngelSlim/Qwen3-8B_eagle3"
,
1
),
(
"eagle3"
,
"Qwen/Qwen3-8B"
,
"AngelSlim/Qwen3-8B_eagle3"
,
1
),
False
,
False
,
False
,
False
,
"auto"
,
"auto"
,
0.8
,
# ref: 90%
),
),
(
(
(
"eagle3"
,
"Qwen/Qwen3-8B"
,
"AngelSlim/Qwen3-8B_eagle3"
,
1
),
(
"eagle3"
,
"Qwen/Qwen3-8B"
,
"AngelSlim/Qwen3-8B_eagle3"
,
1
),
False
,
False
,
False
,
False
,
"transformers"
,
"transformers"
,
0.8
,
# ref: 90%
),
),
pytest
.
param
(
pytest
.
param
(
(
(
...
@@ -342,6 +352,7 @@ def test_speculators_model_integration(
...
@@ -342,6 +352,7 @@ def test_speculators_model_integration(
False
,
False
,
False
,
False
,
"auto"
,
"auto"
,
0.8
,
# ref: 90%
marks
=
pytest
.
mark
.
skip
(
marks
=
pytest
.
mark
.
skip
(
reason
=
"architecture of its eagle3 is LlamaForCausalLMEagle3"
reason
=
"architecture of its eagle3 is LlamaForCausalLMEagle3"
),
),
...
@@ -356,6 +367,7 @@ def test_speculators_model_integration(
...
@@ -356,6 +367,7 @@ def test_speculators_model_integration(
False
,
False
,
False
,
False
,
"auto"
,
"auto"
,
0.7
,
# TODO, update this with a reference value when re-enabling this case
marks
=
pytest
.
mark
.
skip
(
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to its head_dim not being a a multiple of 32"
reason
=
"Skipping due to its head_dim not being a a multiple of 32"
),
),
...
@@ -370,6 +382,7 @@ def test_speculators_model_integration(
...
@@ -370,6 +382,7 @@ def test_speculators_model_integration(
False
,
False
,
True
,
True
,
"auto"
,
"auto"
,
0.7
,
# ref: 75%-80%
marks
=
large_gpu_mark
(
min_gb
=
40
),
marks
=
large_gpu_mark
(
min_gb
=
40
),
),
# works on 4x H100
),
# works on 4x H100
(
(
...
@@ -382,6 +395,7 @@ def test_speculators_model_integration(
...
@@ -382,6 +395,7 @@ def test_speculators_model_integration(
False
,
False
,
False
,
False
,
"auto"
,
"auto"
,
0.7
,
# ref: 75%-80%
),
),
pytest
.
param
(
pytest
.
param
(
(
(
...
@@ -393,7 +407,8 @@ def test_speculators_model_integration(
...
@@ -393,7 +407,8 @@ def test_speculators_model_integration(
False
,
False
,
False
,
False
,
"auto"
,
"auto"
,
marks
=
large_gpu_mark
(
min_gb
=
80
),
0.8
,
# ref: 90%
# marks=large_gpu_mark(min_gb=80),
),
# works on 4x H100
),
# works on 4x H100
pytest
.
param
(
pytest
.
param
(
(
(
...
@@ -405,6 +420,7 @@ def test_speculators_model_integration(
...
@@ -405,6 +420,7 @@ def test_speculators_model_integration(
True
,
True
,
True
,
True
,
"auto"
,
"auto"
,
0.8
,
# ref: 90%
marks
=
large_gpu_mark
(
min_gb
=
80
),
marks
=
large_gpu_mark
(
min_gb
=
80
),
),
# works on 4x H100
),
# works on 4x H100
(
(
...
@@ -417,6 +433,7 @@ def test_speculators_model_integration(
...
@@ -417,6 +433,7 @@ def test_speculators_model_integration(
False
,
False
,
False
,
False
,
"auto"
,
"auto"
,
0.0
,
# dummy model, skip gsm8k check
),
),
],
],
ids
=
[
ids
=
[
...
@@ -437,10 +454,18 @@ def test_eagle_correctness(
...
@@ -437,10 +454,18 @@ def test_eagle_correctness(
sampling_config
:
SamplingParams
,
sampling_config
:
SamplingParams
,
model_setup
:
tuple
[
str
,
str
,
str
,
int
],
model_setup
:
tuple
[
str
,
str
,
str
,
int
],
mm_enabled
:
bool
,
mm_enabled
:
bool
,
expected_accuracy_threshold
:
float
,
enable_chunked_prefill
:
bool
,
enable_chunked_prefill
:
bool
,
model_impl
:
str
,
model_impl
:
str
,
attn_backend
:
str
,
attn_backend
:
str
,
):
):
"""
Compare the outputs of a original LLM and a speculative LLM
which should be the same when using eagle speculative decoding. Due to some variance
in the engine, it is possible for some outputs to differ, so we expect that at least
6/10 output tokens match exactly, and that the GSM8k accuracy is above
a precomputed reference threshold for each model.
"""
if
attn_backend
==
"TREE_ATTN"
:
if
attn_backend
==
"TREE_ATTN"
:
# TODO: Fix this flaky test
# TODO: Fix this flaky test
pytest
.
skip
(
pytest
.
skip
(
...
@@ -461,11 +486,6 @@ def test_eagle_correctness(
...
@@ -461,11 +486,6 @@ def test_eagle_correctness(
# Generate test prompts inside the function instead of using fixture
# Generate test prompts inside the function instead of using fixture
test_prompts
=
get_test_prompts
(
mm_enabled
)
test_prompts
=
get_test_prompts
(
mm_enabled
)
"""
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size)
"""
# Determine attention config
# Determine attention config
# Scout requires default backend selection because vision encoder has
# Scout requires default backend selection because vision encoder has
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
...
@@ -505,6 +525,9 @@ def test_eagle_correctness(
...
@@ -505,6 +525,9 @@ def test_eagle_correctness(
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
attention_config
=
attention_config
,
attention_config
=
attention_config
,
)
)
evaluate_llm_for_gsm8k
(
ref_llm
,
expected_accuracy_threshold
=
expected_accuracy_threshold
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
del
ref_llm
del
ref_llm
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
@@ -526,6 +549,9 @@ def test_eagle_correctness(
...
@@ -526,6 +549,9 @@ def test_eagle_correctness(
model_impl
=
model_impl
,
model_impl
=
model_impl
,
attention_config
=
attention_config
,
attention_config
=
attention_config
,
)
)
evaluate_llm_for_gsm8k
(
spec_llm
,
expected_accuracy_threshold
=
expected_accuracy_threshold
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
matches
=
0
matches
=
0
misses
=
0
misses
=
0
...
@@ -546,10 +572,10 @@ def test_eagle_correctness(
...
@@ -546,10 +572,10 @@ def test_eagle_correctness(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
[
"model_setup"
,
"mm_enabled"
],
[
"model_setup"
,
"mm_enabled"
,
"expected_accuracy_threshold"
],
[
[
((
"mtp"
,
"XiaomiMiMo/MiMo-7B-Base"
,
1
),
False
)
,
((
"mtp"
,
"XiaomiMiMo/MiMo-7B-Base"
,
1
),
False
,
0.5
),
# ref: 65%-70%
((
"mtp"
,
"ZixiQi/DeepSeek-V3-4layers-MTP-FP8"
,
1
),
False
)
,
((
"mtp"
,
"ZixiQi/DeepSeek-V3-4layers-MTP-FP8"
,
1
),
False
,
0.0
),
# dummy model
],
],
ids
=
[
"mimo"
,
"deepseek"
],
ids
=
[
"mimo"
,
"deepseek"
],
)
)
...
@@ -558,14 +584,17 @@ def test_mtp_correctness(
...
@@ -558,14 +584,17 @@ def test_mtp_correctness(
sampling_config
:
SamplingParams
,
sampling_config
:
SamplingParams
,
model_setup
:
tuple
[
str
,
str
,
int
],
model_setup
:
tuple
[
str
,
str
,
int
],
mm_enabled
:
bool
,
mm_enabled
:
bool
,
expected_accuracy_threshold
:
float
,
):
):
# Generate test prompts inside the function instead of using fixture
test_prompts
=
get_test_prompts
(
mm_enabled
)
"""
"""
Compare the outputs of a original LLM and a speculative LLM
Compare the outputs of a original LLM and a speculative LLM
should be the same when using MTP speculative decoding.
which should be the same when using MTP speculative decoding. Due to some variance
model_setup: (method, model_name, tp_size)
in the engine, it is possible for some outputs to differ, so we expect that at least
6/10 output tokens match exactly, and that the GSM8k accuracy is above a precomputed
reference threshold for each model.
"""
"""
# Generate test prompts inside the function instead of using fixture
test_prompts
=
get_test_prompts
(
mm_enabled
)
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
)
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
)
...
@@ -579,6 +608,9 @@ def test_mtp_correctness(
...
@@ -579,6 +608,9 @@ def test_mtp_correctness(
trust_remote_code
=
True
,
trust_remote_code
=
True
,
)
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
evaluate_llm_for_gsm8k
(
ref_llm
,
expected_accuracy_threshold
=
expected_accuracy_threshold
)
del
ref_llm
del
ref_llm
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
...
@@ -594,6 +626,9 @@ def test_mtp_correctness(
...
@@ -594,6 +626,9 @@ def test_mtp_correctness(
},
},
max_model_len
=
2048
,
max_model_len
=
2048
,
)
)
evaluate_llm_for_gsm8k
(
spec_llm
,
expected_accuracy_threshold
=
expected_accuracy_threshold
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
matches
=
0
matches
=
0
misses
=
0
misses
=
0
...
@@ -621,12 +656,13 @@ class ArgsTest:
...
@@ -621,12 +656,13 @@ class ArgsTest:
num_speculative_tokens
:
int
num_speculative_tokens
:
int
expected_acceptance_rate
:
float
expected_acceptance_rate
:
float
expected_acceptance_len
:
float
expected_acceptance_len
:
float
expected_gsm8k_accuracy
:
float
=
0.0
# skip by default
# Defaults
# Defaults
enforce_eager
:
bool
=
True
enforce_eager
:
bool
=
True
parallel_drafting
:
bool
=
False
parallel_drafting
:
bool
=
False
target_tensor_parallel_size
:
int
=
1
target_tensor_parallel_size
:
int
=
1
draft_tensor_parallel_size
:
int
=
1
draft_tensor_parallel_size
:
int
=
1
max_model_len
:
int
=
1024
max_model_len
:
int
=
2048
gpu_memory_utilization
:
float
=
0.5
gpu_memory_utilization
:
float
=
0.5
dataset
:
str
=
"test_prompts"
dataset
:
str
=
"test_prompts"
num_prompts
:
int
=
100
num_prompts
:
int
=
100
...
@@ -639,8 +675,9 @@ cases = [
...
@@ -639,8 +675,9 @@ cases = [
draft_model
=
"Qwen/Qwen3-0.6B"
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
sampling_config
=
greedy_sampling
(),
sampling_config
=
greedy_sampling
(),
num_speculative_tokens
=
3
,
# K
num_speculative_tokens
=
3
,
# K
expected_acceptance_len
=
3
+
1
,
# K + 1
expected_acceptance_len
=
0.98
*
(
3
+
1
),
# epsilon discount of K + 1
expected_acceptance_rate
=
1.0
,
expected_acceptance_rate
=
0.98
,
# slight epsilon
expected_gsm8k_accuracy
=
0.25
,
# ref: 35-40%
),
),
# Smaller draft model, stochastic sampling.
# Smaller draft model, stochastic sampling.
ArgsTest
(
ArgsTest
(
...
@@ -648,8 +685,9 @@ cases = [
...
@@ -648,8 +685,9 @@ cases = [
draft_model
=
"Qwen/Qwen3-0.6B"
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
sampling_config
=
stochastic_sampling
(),
sampling_config
=
stochastic_sampling
(),
num_speculative_tokens
=
3
,
num_speculative_tokens
=
3
,
expected_acceptance_len
=
2.8
+
1
,
expected_acceptance_len
=
3.4
,
# ref: 3.7
expected_acceptance_rate
=
0.9
,
expected_acceptance_rate
=
0.80
,
# ref: 0.90
expected_gsm8k_accuracy
=
0.5
,
# ref: 60%. Note gsm8k always runs greedy sampling
),
),
]
]
...
@@ -669,9 +707,8 @@ def test_draft_model_realistic_example():
...
@@ -669,9 +707,8 @@ def test_draft_model_realistic_example():
num_speculative_tokens
=
3
,
num_speculative_tokens
=
3
,
sampling_config
=
greedy_sampling
(),
sampling_config
=
greedy_sampling
(),
enforce_eager
=
False
,
enforce_eager
=
False
,
# values below are not derived, but just prevent a regression
expected_acceptance_len
=
2.6
,
# ref: 2.86
expected_acceptance_len
=
2.8
,
expected_acceptance_rate
=
0.5
,
# ref: 0.62
expected_acceptance_rate
=
0.55
,
)
)
assert_draft_model_correctness
(
args
)
assert_draft_model_correctness
(
args
)
...
@@ -685,9 +722,8 @@ def test_draft_model_parallel_drafting():
...
@@ -685,9 +722,8 @@ def test_draft_model_parallel_drafting():
sampling_config
=
greedy_sampling
(),
sampling_config
=
greedy_sampling
(),
parallel_drafting
=
True
,
parallel_drafting
=
True
,
enforce_eager
=
False
,
enforce_eager
=
False
,
# values below are collected from a stable run, with ~5% tolerance
expected_acceptance_len
=
2.3
,
# ref: 2.52
expected_acceptance_len
=
2.375
,
expected_acceptance_rate
=
0.4
,
# ref: 0.51
expected_acceptance_rate
=
0.45
,
)
)
assert_draft_model_correctness
(
args
)
assert_draft_model_correctness
(
args
)
...
@@ -723,6 +759,7 @@ def test_draft_model_tensor_parallelism():
...
@@ -723,6 +759,7 @@ def test_draft_model_tensor_parallelism():
draft_tensor_parallel_size
=
2
,
draft_tensor_parallel_size
=
2
,
**
some_high_acceptance_metrics
(),
**
some_high_acceptance_metrics
(),
enforce_eager
=
False
,
enforce_eager
=
False
,
expected_gsm8k_accuracy
=
0.5
,
)
)
assert_draft_model_correctness
(
sd_case
)
assert_draft_model_correctness
(
sd_case
)
...
@@ -797,9 +834,14 @@ def assert_draft_model_correctness(args: ArgsTest):
...
@@ -797,9 +834,14 @@ def assert_draft_model_correctness(args: ArgsTest):
# we don't check the outputs, only check the metrics
# we don't check the outputs, only check the metrics
spec_llm
.
chat
(
test_prompts
,
args
.
sampling_config
)
spec_llm
.
chat
(
test_prompts
,
args
.
sampling_config
)
metrics
=
spec_llm
.
get_metrics
()
metrics
=
spec_llm
.
get_metrics
()
acceptance_rate
:
float
=
compute_acceptance_rate
(
metrics
)
acceptance_rate
:
float
=
compute_acceptance_rate
(
metrics
)
acceptance_len
:
float
=
compute_acceptance_len
(
metrics
)
acceptance_len
:
float
=
compute_acceptance_len
(
metrics
)
# Need to evaluate after getting metrics to avoid polluting the AR
evaluate_llm_for_gsm8k
(
spec_llm
,
expected_accuracy_threshold
=
args
.
expected_gsm8k_accuracy
)
del
spec_llm
# CLEANUP
del
spec_llm
# CLEANUP
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
...
@@ -817,7 +859,7 @@ def assert_draft_model_correctness(args: ArgsTest):
...
@@ -817,7 +859,7 @@ def assert_draft_model_correctness(args: ArgsTest):
def
get_messages
(
dataset
:
str
,
n
:
int
)
->
list
[
Messages
]:
def
get_messages
(
dataset
:
str
,
n
:
int
)
->
list
[
Messages
]:
if
dataset
==
"test_prompts"
:
if
dataset
==
"test_prompts"
:
return
get_test_prompts
(
mm_enabled
=
False
,
quiet
=
True
,
num_prompts
=
n
)
return
get_test_prompts
(
mm_enabled
=
False
,
num_prompts
=
n
)
elif
dataset
==
"likaixin/InstructCoder"
:
elif
dataset
==
"likaixin/InstructCoder"
:
return
get_instruct_coder_messages
(
n
=
n
)
return
get_instruct_coder_messages
(
n
=
n
)
else
:
else
:
...
@@ -828,8 +870,8 @@ def some_high_acceptance_metrics() -> dict:
...
@@ -828,8 +870,8 @@ def some_high_acceptance_metrics() -> dict:
return
{
return
{
"sampling_config"
:
greedy_sampling
(),
"sampling_config"
:
greedy_sampling
(),
"num_speculative_tokens"
:
3
,
"num_speculative_tokens"
:
3
,
"expected_acceptance_len"
:
2.8
+
1
,
"expected_acceptance_len"
:
3.4
,
# ref: 3.75
"expected_acceptance_rate"
:
0.
90
,
"expected_acceptance_rate"
:
0.
8
,
# ref: 0.9
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment