Unverified Commit cec7c288 authored by Rémi Delacourt's avatar Rémi Delacourt Committed by GitHub
Browse files

[Bugfix] Padded Eagle Specdec with Chunked Prefill (#26263)


Signed-off-by: default avatarRémi Delacourt <remi@mistral.ai>
Signed-off-by: default avatarRémi Delacourt <54138269+Flechman@users.noreply.github.com>
Signed-off-by: default avatarremi <remi@mistral.ai>
Co-authored-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent 18961c5e
...@@ -202,9 +202,9 @@ def test_speculators_model_integration( ...@@ -202,9 +202,9 @@ def test_speculators_model_integration(
@pytest.mark.parametrize( @pytest.mark.parametrize(
["model_setup", "mm_enabled"], ["model_setup", "mm_enabled", "chunked_prefill_enabled"],
[ [
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
pytest.param( pytest.param(
( (
"eagle3", "eagle3",
...@@ -213,11 +213,12 @@ def test_speculators_model_integration( ...@@ -213,11 +213,12 @@ def test_speculators_model_integration(
1, 1,
), ),
False, False,
False,
marks=pytest.mark.skip( marks=pytest.mark.skip(
reason="Skipping due to its head_dim not being a a multiple of 32" reason="Skipping due to its head_dim not being a a multiple of 32"
), ),
), ),
( pytest.param(
( (
"eagle", "eagle",
"meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct",
...@@ -225,7 +226,9 @@ def test_speculators_model_integration( ...@@ -225,7 +226,9 @@ def test_speculators_model_integration(
1, 1,
), ),
False, False,
), True,
marks=large_gpu_mark(min_gb=40),
), # works on 4x H100
( (
( (
"eagle3", "eagle3",
...@@ -234,6 +237,7 @@ def test_speculators_model_integration( ...@@ -234,6 +237,7 @@ def test_speculators_model_integration(
1, 1,
), ),
False, False,
False,
), ),
pytest.param( pytest.param(
( (
...@@ -243,6 +247,7 @@ def test_speculators_model_integration( ...@@ -243,6 +247,7 @@ def test_speculators_model_integration(
4, 4,
), ),
False, False,
False,
marks=large_gpu_mark(min_gb=80), marks=large_gpu_mark(min_gb=80),
), # works on 4x H100 ), # works on 4x H100
pytest.param( pytest.param(
...@@ -253,6 +258,7 @@ def test_speculators_model_integration( ...@@ -253,6 +258,7 @@ def test_speculators_model_integration(
4, 4,
), ),
True, True,
True,
marks=large_gpu_mark(min_gb=80), marks=large_gpu_mark(min_gb=80),
), # works on 4x H100 ), # works on 4x H100
( (
...@@ -263,6 +269,7 @@ def test_speculators_model_integration( ...@@ -263,6 +269,7 @@ def test_speculators_model_integration(
1, 1,
), ),
False, False,
False,
), ),
], ],
ids=[ ids=[
...@@ -281,6 +288,7 @@ def test_eagle_correctness( ...@@ -281,6 +288,7 @@ def test_eagle_correctness(
sampling_config: SamplingParams, sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int], model_setup: tuple[str, str, str, int],
mm_enabled: bool, mm_enabled: bool,
chunked_prefill_enabled: bool,
attn_backend: str, attn_backend: str,
): ):
if attn_backend == "TREE_ATTN": if attn_backend == "TREE_ATTN":
...@@ -317,9 +325,13 @@ def test_eagle_correctness( ...@@ -317,9 +325,13 @@ def test_eagle_correctness(
m.setenv("VLLM_ROCM_USE_AITER", "1") m.setenv("VLLM_ROCM_USE_AITER", "1")
method, model_name, spec_model_name, tp_size = model_setup method, model_name, spec_model_name, tp_size = model_setup
max_model_len = 2048
max_num_batched_tokens = max_model_len
if chunked_prefill_enabled:
max_num_batched_tokens = 128
ref_llm = LLM( ref_llm = LLM(
model=model_name, max_model_len=2048, tensor_parallel_size=tp_size model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
) )
ref_outputs = ref_llm.chat(test_prompts, sampling_config) ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm del ref_llm
...@@ -334,9 +346,11 @@ def test_eagle_correctness( ...@@ -334,9 +346,11 @@ def test_eagle_correctness(
"method": method, "method": method,
"model": spec_model_name, "model": spec_model_name,
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"max_model_len": 2048, "max_model_len": max_model_len,
}, },
max_model_len=2048, max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=chunked_prefill_enabled,
) )
spec_outputs = spec_llm.chat(test_prompts, sampling_config) spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0 matches = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment