Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6f5c0931
Unverified
Commit
6f5c0931
authored
Sep 27, 2025
by
Jonas M. Kübler
Committed by
GitHub
Sep 27, 2025
Browse files
[Spec decode] automatically disable mm for text-only draft models (#25667)
Signed-off-by:
Jonas Kuebler
<
kuebj@amazon.com
>
parent
4e33a7ea
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
67 deletions
+83
-67
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+69
-67
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+14
-0
No files found.
tests/v1/e2e/test_spec_decode.py
View file @
6f5c0931
...
@@ -8,7 +8,7 @@ from typing import Any, Union
...
@@ -8,7 +8,7 @@ from typing import Any, Union
import
pytest
import
pytest
import
torch
import
torch
from
tests.utils
import
get_attn_backend_list_based_on_platform
from
tests.utils
import
get_attn_backend_list_based_on_platform
,
large_gpu_mark
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.base
import
VLLM_S3_BUCKET_URL
from
vllm.assets.base
import
VLLM_S3_BUCKET_URL
from
vllm.assets.image
import
VLM_IMAGES_DIR
from
vllm.assets.image
import
VLM_IMAGES_DIR
...
@@ -88,69 +88,66 @@ def test_ngram_correctness(
...
@@ -88,69 +88,66 @@ def test_ngram_correctness(
Compare the outputs of an original LLM and a speculative LLM
Compare the outputs of an original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
should be the same when using ngram speculative decoding.
'''
'''
with
monkeypatch
.
context
()
as
m
:
test_prompts
=
get_test_prompts
(
mm_enabled
=
False
)
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
test_prompts
=
get_test_prompts
(
mm_enabled
=
False
)
ref_llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
ref_llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
del
ref_llm
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
torch
.
cuda
.
empty_cache
()
del
ref_llm
cleanup_dist_env_and_memory
()
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
spec_llm
=
LLM
(
model
=
model_name
,
spec_llm
=
LLM
(
speculative_config
=
{
model
=
model_name
,
"method"
:
"ngram"
,
speculative_config
=
{
"prompt_lookup_max"
:
5
,
"method"
:
"ngram"
,
"prompt_lookup_min"
:
3
,
"prompt_lookup_max"
:
5
,
"num_speculative_tokens"
:
3
,
"prompt_lookup_min"
:
3
,
},
"num_speculative_tokens"
:
3
,
max_model_len
=
1024
,
},
)
max_model_len
=
1024
,
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
)
matches
=
0
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
misses
=
0
matches
=
0
for
ref_output
,
spec_output
in
zip
(
ref_outputs
,
spec_outputs
):
misses
=
0
if
ref_output
.
outputs
[
0
].
text
==
spec_output
.
outputs
[
0
].
text
:
for
ref_output
,
spec_output
in
zip
(
ref_outputs
,
spec_outputs
):
matches
+=
1
if
ref_output
.
outputs
[
0
].
text
==
spec_output
.
outputs
[
0
].
text
:
else
:
matches
+=
1
misses
+=
1
else
:
print
(
f
"ref_output:
{
ref_output
.
outputs
[
0
].
text
}
"
)
misses
+=
1
print
(
f
"spec_output:
{
spec_output
.
outputs
[
0
].
text
}
"
)
print
(
f
"ref_output:
{
ref_output
.
outputs
[
0
].
text
}
"
)
print
(
f
"spec_output:
{
spec_output
.
outputs
[
0
].
text
}
"
)
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
# Heuristic: expect at least 66% of the prompts to match exactly
assert
matches
>=
int
(
0.66
*
len
(
ref_outputs
))
# Upon failure, inspect the outputs to check for inaccuracy.
del
spec_llm
assert
matches
>=
int
(
0.66
*
len
(
ref_outputs
))
torch
.
cuda
.
empty_cache
()
del
spec_llm
cleanup_dist_env_and_memory
()
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
parametrize
(
[
"model_setup"
,
"mm_enabled"
],
@
pytest
.
mark
.
parametrize
([
"model_setup"
,
"mm_enabled"
],
[
[
((
"eagle3"
,
"Qwen/Qwen3-8B"
,
"AngelSlim/Qwen3-8B_eagle3"
,
1
),
False
),
((
"eagle3"
,
"Qwen/Qwen3-8B"
,
"AngelSlim/Qwen3-8B_eagle3"
,
1
),
False
),
((
"eagle"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
((
"eagle"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
,
1
),
False
),
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
,
1
),
False
),
((
"eagle3"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
((
"eagle3"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
,
1
),
False
),
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
,
1
),
False
),
pytest
.
param
(
pytest
.
param
((
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
(
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
False
,
False
,
marks
=
large_gpu_mark
(
min_gb
=
80
)),
# works on 4x H100
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
pytest
.
param
((
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
pytest
.
param
(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
(
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
True
,
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
marks
=
large_gpu_mark
(
min_gb
=
80
)),
# works on 4x H100
True
,
((
"eagle"
,
"eagle618/deepseek-v3-random"
,
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
"eagle618/eagle-deepseek-v3-random"
,
1
),
False
),
((
"eagle"
,
"eagle618/deepseek-v3-random"
,
],
"eagle618/eagle-deepseek-v3-random"
,
1
),
False
),
ids
=
[
],
"qwen3_eagle3"
,
"llama3_eagle"
,
"llama3_eagle3"
,
"llama4_eagle"
,
ids
=
[
"llama4_eagle_mm"
,
"deepseek_eagle"
"qwen3_eagle3"
,
"llama3_eagle"
,
"llama3_eagle3"
,
])
"llama4_eagle"
,
"llama4_eagle_mm"
,
"deepseek_eagle"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
get_attn_backend_list_based_on_platform
())
def
test_eagle_correctness
(
def
test_eagle_correctness
(
...
@@ -174,9 +171,14 @@ def test_eagle_correctness(
...
@@ -174,9 +171,14 @@ def test_eagle_correctness(
model_setup: (method, model_name, eagle_model_name, tp_size)
model_setup: (method, model_name, eagle_model_name, tp_size)
'''
'''
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
if
"Llama-4-Scout"
in
model_setup
[
1
]
and
attn_backend
==
"FLASH_ATTN"
:
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
)
# Scout requires default backend selection
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
pass
else
:
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
(
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
()):
if
(
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
()):
pytest
.
skip
(
"TRITON_ATTN does not support "
pytest
.
skip
(
"TRITON_ATTN does not support "
...
...
vllm/v1/spec_decode/eagle.py
View file @
6f5c0931
...
@@ -804,6 +804,20 @@ class EagleProposer:
...
@@ -804,6 +804,20 @@ class EagleProposer:
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
)
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
)
if
self
.
is_multimodal_model
:
# Even if the target model is multimodal, we can also use
# text-only draft models
try
:
dummy_input_ids
=
torch
.
tensor
([[
1
]],
device
=
self
.
input_ids
.
device
)
self
.
model
.
get_input_embeddings
(
dummy_input_ids
,
multimodal_embeddings
=
None
)
except
(
NotImplementedError
,
AttributeError
,
TypeError
):
logger
.
warning
(
"Draft model does not support multimodal inputs, "
"falling back to text-only mode"
)
self
.
is_multimodal_model
=
False
if
supports_multimodal
(
target_model
):
if
supports_multimodal
(
target_model
):
# handle multimodality
# handle multimodality
self
.
model
.
config
.
image_token_index
=
(
self
.
model
.
config
.
image_token_index
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment