Unverified Commit d5a16977 authored by Lily Liu's avatar Lily Liu Committed by GitHub
Browse files

[Dynamic Spec Decoding] Minor fix for disabling speculative decoding (#5000)

parent 325c1199
...@@ -170,3 +170,44 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator, ...@@ -170,3 +170,44 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
batch_size, batch_size,
max_output_len=output_len, max_output_len=output_len,
force_output_len=True) force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
from unittest.mock import MagicMock from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
...@@ -13,9 +13,9 @@ from vllm.spec_decode.top1_proposer import Top1Proposer ...@@ -13,9 +13,9 @@ from vllm.spec_decode.top1_proposer import Top1Proposer
from .utils import create_batch, mock_worker from .utils import create_batch, mock_worker
@pytest.mark.parametrize('queue_size', [2, 4]) @pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1, 2, 3, 6]) @pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1, 2, 5, 7, 10]) @pytest.mark.parametrize('k', [1])
@torch.inference_mode() @torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
"""Verify that speculative tokens are disabled when the batch size """Verify that speculative tokens are disabled when the batch size
...@@ -42,8 +42,12 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): ...@@ -42,8 +42,12 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
num_lookahead_slots=k, num_lookahead_slots=k,
running_queue_size=queue_size) running_queue_size=queue_size)
with pytest.raises(ValueError, match=exception_secret): if queue_size > disable_by_batch_size:
worker.execute_model(execute_model_req=execute_model_req) with patch.object(worker,
'_run_no_spec',
side_effect=ValueError(exception_secret)), \
pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# When the batch size is larger than the threshold, # When the batch size is larger than the threshold,
# we expect no speculative tokens (0). # we expect no speculative tokens (0).
......
...@@ -273,10 +273,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -273,10 +273,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._maybe_disable_speculative_tokens( self._maybe_disable_speculative_tokens(
disable_all_speculation, execute_model_req.seq_group_metadata_list) disable_all_speculation, execute_model_req.seq_group_metadata_list)
# If no spec tokens, call the proposer and scorer workers normally. # Speculative decoding is disabled in the following cases:
# Used for prefill. # 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch.
# In any of these cases, the proposer and scorer workers
# are called normally.
if num_lookahead_slots == 0 or len( if num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0: execute_model_req.seq_group_metadata_list
) == 0 or disable_all_speculation:
return self._run_no_spec(execute_model_req, return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation) skip_proposer=disable_all_speculation)
...@@ -316,8 +323,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -316,8 +323,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest, def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]: skip_proposer: bool) -> List[SamplerOutput]:
"""Run a prefill step, without any speculation. The input is sent to """Run a single generation step without any speculation. The input is
the proposer and scorer model so that the KV cache is consistent sent to the proposer and scorer model so that the KV cache is consistent
between the two. When skip_proposer is True, the proposer model is between the two. When skip_proposer is True, the proposer model is
not called, meaning that the kv-cache in proposer for requests is not not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding. updated, so they cannot enable spec decode in the rest decoding.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment