Unverified Commit 62b8aebc authored by Cade Daniel's avatar Cade Daniel Committed by GitHub
Browse files

[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)

parent 050f285f
...@@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, ...@@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids, bonus_token_ids,
) )
# Bonus tokens are currently disabled. Verify they're set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
if which_tokens_accepted == "all_tokens_accepted": if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens. # Expect all tokens to be equal to draft tokens.
assert torch.equal(output_token_ids[:, :-1], draft_token_ids) assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
# Expect all bonus tokens to be included. # Expect all bonus tokens to be included.
assert torch.equal(output_token_ids[:, -1:], bonus_token_ids) assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
elif which_tokens_accepted == "no_tokens_accepted": elif which_tokens_accepted == "no_tokens_accepted":
# Expect first token to be equal to recovered tokens. # Expect first token to be equal to recovered tokens.
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0]) assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
...@@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, ...@@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
torch.ones_like(output_token_ids[:, 1:]) * -1) torch.ones_like(output_token_ids[:, 1:]) * -1)
elif which_tokens_accepted == "some_tokens_accepted": elif which_tokens_accepted == "some_tokens_accepted":
recovered_plus_bonus = torch.cat( recovered_plus_bonus = torch.cat(
(recovered_token_ids, bonus_token_ids), dim=-1) (recovered_token_ids, expected_bonus_token_ids), dim=-1)
# Assert first rejected token is a recovered token or bonus token. # Assert first rejected token is a recovered token or bonus token.
assert torch.equal( assert torch.equal(
recovered_plus_bonus[torch.arange(0, batch_size), recovered_plus_bonus[torch.arange(0, batch_size),
......
...@@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
def mock_sample(probs, *args, **kwargs): def mock_sample(probs, *args, **kwargs):
nonlocal sample_probs nonlocal sample_probs
sample_probs = probs sample_probs = probs
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None)
with patch("vllm.model_executor.layers.sampler._sample", mock_sample): with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata) sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
......
from typing import List, Tuple
import pytest import pytest
from tests.conftest import cleanup from tests.conftest import cleanup
...@@ -6,28 +8,34 @@ from vllm.model_executor.utils import set_random_seed ...@@ -6,28 +8,34 @@ from vllm.model_executor.utils import set_random_seed
@pytest.fixture @pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, def baseline_llm_generator(request, common_llm_kwargs,
baseline_llm_kwargs, seed): per_test_common_llm_kwargs, baseline_llm_kwargs,
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, seed):
return create_llm_generator("baseline", request, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, seed) baseline_llm_kwargs, seed)
@pytest.fixture @pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed): test_llm_kwargs, seed):
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, return create_llm_generator("test", request, common_llm_kwargs,
test_llm_kwargs, seed) per_test_common_llm_kwargs, test_llm_kwargs,
seed)
def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
distinct_llm_kwargs, seed): per_test_common_llm_kwargs, distinct_llm_kwargs,
seed):
kwargs = { kwargs = {
**common_llm_kwargs, **common_llm_kwargs,
**per_test_common_llm_kwargs, **per_test_common_llm_kwargs,
**distinct_llm_kwargs, **distinct_llm_kwargs,
} }
test_name = request.node.name
def generator_inner(): def generator_inner():
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = LLM(**kwargs) llm = LLM(**kwargs)
set_random_seed(seed) set_random_seed(seed)
...@@ -36,6 +44,23 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, ...@@ -36,6 +44,23 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
del llm del llm
cleanup() cleanup()
def generator_outer():
for llm in generator_inner(): for llm in generator_inner():
yield llm yield llm
del llm del llm
return generator_outer
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = []
token_ids = []
for llm in llm_generator():
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
del llm
return tokens, token_ids
import pytest
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"enable_chunked_prefill": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError,
match="Speculative decoding and chunked prefill"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Speculative max model len > overridden max model len should raise.
"max_model_len": 128,
"speculative_max_model_len": 129,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1,
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
"speculative_max_model_len": 4096 + 1,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
"""Verify that speculative decoding validates speculative_max_model_len.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
"""Verify that speculative decoding with block manager v1 fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError,
match="Speculative decoding requires usage of the V2"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
This diff is collapsed.
...@@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): ...@@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
num_draft_tokens = 0 num_draft_tokens = 0
k = 5 k = 5
num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens( max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
num_draft_tokens, k) num_draft_tokens, k)
rej_sampler = MagicMock() rej_sampler = MagicMock()
...@@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): ...@@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
assert (metrics.draft_acceptance_rate == num_accepted_tokens / assert (metrics.draft_acceptance_rate == num_accepted_tokens /
num_draft_tokens) num_draft_tokens)
assert (metrics.system_efficiency == num_emitted_tokens / assert (metrics.system_efficiency == num_emitted_tokens /
num_possible_tokens) max_num_emitted_tokens)
else: else:
assert math.isnan(metrics.draft_acceptance_rate) assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency) assert math.isnan(metrics.system_efficiency)
...@@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations(): ...@@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations():
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([0, k]) assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k]) assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size]) assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)] assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
......
import random import random
from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
...@@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the target model with correct """Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out. inputs. Everything else is mocked out.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker() target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
...@@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
""" """
vocab_size = 32_000 vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) draft_worker = mock_worker(cls=MultiStepWorker,
target_worker = mock_worker(vocab_size=vocab_size) vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
...@@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
num_lookahead_slots=k) num_lookahead_slots=k)
assert len(rejection_sampler.call_args_list) == 1 assert len(rejection_sampler.call_args_list) == 1
args, _ = rejection_sampler.call_args_list[0] _, kwargs = rejection_sampler.call_args_list[0]
(actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs, actual = SimpleNamespace(**kwargs)
actual_proposal_token_ids) = args
assert torch.equal(actual_bonus_token_ids, assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:]) target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal( assert torch.equal(
actual_proposal_scores, actual.target_probs,
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual_proposal_token_ids, proposal_token_ids) assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual_proposal_probs, proposal_probs) assert torch.equal(actual.draft_probs, proposal_probs)
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
...@@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int):
""" """
vocab_size = 32_000 vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) draft_worker = mock_worker(cls=MultiStepWorker,
target_worker = mock_worker(vocab_size=vocab_size) vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
...@@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
""" """
vocab_size = 32_000 vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) draft_worker = mock_worker(cls=MultiStepWorker,
target_worker = mock_worker(vocab_size=vocab_size) vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
...@@ -500,8 +506,8 @@ def test_init_device(): ...@@ -500,8 +506,8 @@ def test_init_device():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization. well as other GPU initialization.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker() target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
......
...@@ -63,11 +63,14 @@ def create_execute_model_data( ...@@ -63,11 +63,14 @@ def create_execute_model_data(
def mock_worker(cls=None, def mock_worker(cls=None,
vocab_size: int = 30_000, vocab_size: int = 30_000,
max_model_len: int = 2048, max_model_len: int = 2048,
rank: int = 0) -> MagicMock: rank: int = 0,
use_spec: bool = True) -> MagicMock:
if cls is None: if cls is None:
cls = Worker cls = Worker
worker = MagicMock(spec=cls) spec = cls if use_spec else None
worker = MagicMock(spec=spec)
worker.vocab_size = vocab_size worker.vocab_size = vocab_size
worker.max_model_len = max_model_len worker.max_model_len = max_model_len
worker.rank = rank worker.rank = rank
......
...@@ -655,6 +655,9 @@ class SpeculativeConfig: ...@@ -655,6 +655,9 @@ class SpeculativeConfig:
target_dtype: str, target_dtype: str,
speculative_model: Optional[str], speculative_model: Optional[str],
num_speculative_tokens: Optional[int], num_speculative_tokens: Optional[int],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
) -> Optional["SpeculativeConfig"]: ) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None. """Create a SpeculativeConfig if possible, else return None.
...@@ -672,6 +675,15 @@ class SpeculativeConfig: ...@@ -672,6 +675,15 @@ class SpeculativeConfig:
model, if provided. model, if provided.
num_speculative_tokens (Optional[int]): The number of speculative num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided. tokens, if provided.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since its not
yet compatible with spec decode.
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
Returns: Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...@@ -690,12 +702,21 @@ class SpeculativeConfig: ...@@ -690,12 +702,21 @@ class SpeculativeConfig:
assert (speculative_model is not None assert (speculative_model is not None
and num_speculative_tokens is not None) and num_speculative_tokens is not None)
if enable_chunked_prefill:
raise ValueError(
"Speculative decoding and chunked prefill are "
f"currently mutually exclusive ({enable_chunked_prefill=}).")
if not use_v2_block_manager:
raise ValueError(
"Speculative decoding requires usage of the V2 "
"block manager. Enable it with --use-v2-block-manager.")
# TODO: The user should be able to specify revision/quantization/max # TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported. # model len for the draft model. It is not currently supported.
draft_revision = None draft_revision = None
draft_code_revision = None draft_code_revision = None
draft_quantization = None draft_quantization = None
draft_max_model_len = None
draft_model_config = ModelConfig( draft_model_config = ModelConfig(
model=speculative_model, model=speculative_model,
...@@ -707,7 +728,7 @@ class SpeculativeConfig: ...@@ -707,7 +728,7 @@ class SpeculativeConfig:
revision=draft_revision, revision=draft_revision,
code_revision=draft_code_revision, code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision, tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=draft_max_model_len, max_model_len=None,
quantization=draft_quantization, quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager, enforce_eager=target_model_config.enforce_eager,
max_context_len_to_capture=target_model_config. max_context_len_to_capture=target_model_config.
...@@ -715,6 +736,13 @@ class SpeculativeConfig: ...@@ -715,6 +736,13 @@ class SpeculativeConfig:
max_logprobs=target_model_config.max_logprobs, max_logprobs=target_model_config.max_logprobs,
) )
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
draft_model_config.max_model_len,
target_model_config.max_model_len,
))
draft_parallel_config = ( draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config( SpeculativeConfig.create_draft_parallel_config(
target_parallel_config)) target_parallel_config))
...@@ -725,6 +753,41 @@ class SpeculativeConfig: ...@@ -725,6 +753,41 @@ class SpeculativeConfig:
num_speculative_tokens, num_speculative_tokens,
) )
@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: Optional[int],
draft_max_model_len: int,
target_max_model_len: int,
) -> int:
"""Determine the max sequence len for the draft model. This is usually
the draft_max_model_len, but may be the target_max_model_len if it is
less than the draft_max_model_len, or may be speculative_max_model_len
if it is specified.
This is necessary so that sequences do not exceed the capacity of the
draft model or the target model.
speculative_max_model_len is mainly used for testing that sequences can
skip speculation.
"""
if speculative_max_model_len is not None:
if speculative_max_model_len > draft_max_model_len:
raise ValueError(f"{speculative_max_model_len=} cannot be "
f"larger than {draft_max_model_len=}")
if speculative_max_model_len > target_max_model_len:
raise ValueError(f"{speculative_max_model_len=} cannot be "
f"larger than {target_max_model_len=}")
return speculative_max_model_len
return min(
draft_max_model_len,
target_max_model_len,
)
@staticmethod @staticmethod
def create_draft_parallel_config( def create_draft_parallel_config(
target_parallel_config: ParallelConfig) -> ParallelConfig: target_parallel_config: ParallelConfig) -> ParallelConfig:
......
...@@ -73,6 +73,7 @@ class EngineArgs: ...@@ -73,6 +73,7 @@ class EngineArgs:
# Speculative decoding configuration. # Speculative decoding configuration.
speculative_model: Optional[str] = None speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
...@@ -237,7 +238,7 @@ class EngineArgs: ...@@ -237,7 +238,7 @@ class EngineArgs:
parser.add_argument('--block-size', parser.add_argument('--block-size',
type=int, type=int,
default=EngineArgs.block_size, default=EngineArgs.block_size,
choices=[8, 16, 32, 128], choices=[8, 16, 32],
help='Token block size for contiguous chunks of ' help='Token block size for contiguous chunks of '
'tokens.') 'tokens.')
...@@ -420,17 +421,25 @@ class EngineArgs: ...@@ -420,17 +421,25 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--speculative-model', '--speculative-model',
type=str, type=str,
default=None, default=EngineArgs.speculative_model,
help= help=
'The name of the draft model to be used in speculative decoding.') 'The name of the draft model to be used in speculative decoding.')
parser.add_argument( parser.add_argument(
'--num-speculative-tokens', '--num-speculative-tokens',
type=int, type=int,
default=None, default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from ' help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.') 'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-max-model-len',
type=str,
default=EngineArgs.speculative_max_model_len,
help='The maximum sequence length supported by the '
'draft model. Sequences over this length will skip '
'speculation.')
parser.add_argument('--model-loader-extra-config', parser.add_argument('--model-loader-extra-config',
type=str, type=str,
default=EngineArgs.model_loader_extra_config, default=EngineArgs.model_loader_extra_config,
...@@ -481,6 +490,9 @@ class EngineArgs: ...@@ -481,6 +490,9 @@ class EngineArgs:
target_dtype=self.dtype, target_dtype=self.dtype,
speculative_model=self.speculative_model, speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens, num_speculative_tokens=self.num_speculative_tokens,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
) )
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
......
...@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest ...@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroup) SequenceGroup, SequenceStage)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group) get_tokenizer_group)
...@@ -480,9 +480,12 @@ class LLMEngine: ...@@ -480,9 +480,12 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size) scheduled_seq_group.token_chunk_size)
# If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case. # If all sequences in the sequence group are in DECODE, then we can
if seq_group.get_num_uncomputed_tokens() == 0: # process the output tokens. Otherwise, they are (chunked) prefill
# samples and should not be processed.
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
if all(stage == SequenceStage.DECODE for stage in stages):
self.output_processor.process_outputs(seq_group, outputs) self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
...@@ -569,7 +572,8 @@ class LLMEngine: ...@@ -569,7 +572,8 @@ class LLMEngine:
# Log stats. # Log stats.
if self.log_stats: if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs)) self.stat_logger.log(
self._get_stats(scheduler_outputs, model_output=output))
return request_outputs return request_outputs
...@@ -578,9 +582,18 @@ class LLMEngine: ...@@ -578,9 +582,18 @@ class LLMEngine:
if self.log_stats: if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs=None)) self.stat_logger.log(self._get_stats(scheduler_outputs=None))
def _get_stats(self, def _get_stats(
scheduler_outputs: Optional[SchedulerOutputs]) -> Stats: self,
"""Get Stats to be Logged to Prometheus.""" scheduler_outputs: Optional[SchedulerOutputs],
model_output: Optional[List[SamplerOutput]] = None) -> Stats:
"""Get Stats to be Logged to Prometheus.
Args:
scheduler_outputs: Optional, used to populate metrics related to
the scheduled batch,
model_output: Optional, used to emit speculative decoding metrics
which are created by the workers.
"""
now = time.time() now = time.time()
# KV Cache Usage in %. # KV Cache Usage in %.
...@@ -637,6 +650,14 @@ class LLMEngine: ...@@ -637,6 +650,14 @@ class LLMEngine:
time_to_first_tokens = time_last_iters if prompt_run else [] time_to_first_tokens = time_last_iters if prompt_run else []
time_per_output_tokens = [] if prompt_run else time_last_iters time_per_output_tokens = [] if prompt_run else time_last_iters
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if model_output and (model_output[0].spec_decode_worker_metrics
is not None):
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
else:
spec_decode_metrics = None
return Stats( return Stats(
now=now, now=now,
num_running=num_running, num_running=num_running,
...@@ -649,6 +670,7 @@ class LLMEngine: ...@@ -649,6 +670,7 @@ class LLMEngine:
time_to_first_tokens=time_to_first_tokens, time_to_first_tokens=time_to_first_tokens,
time_per_output_tokens=time_per_output_tokens, time_per_output_tokens=time_per_output_tokens,
time_e2e_requests=time_e2e_requests, time_e2e_requests=time_e2e_requests,
spec_decode_metrics=spec_decode_metrics,
) )
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
......
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Protocol from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
import numpy as np import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
...@@ -8,6 +8,9 @@ from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, ...@@ -8,6 +8,9 @@ from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__) logger = init_logger(__name__)
disable_created_metrics() disable_created_metrics()
...@@ -118,6 +121,8 @@ class Stats: ...@@ -118,6 +121,8 @@ class Stats:
time_per_output_tokens: List[float] time_per_output_tokens: List[float]
time_e2e_requests: List[float] time_e2e_requests: List[float]
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class SupportsMetricsInfo(Protocol): class SupportsMetricsInfo(Protocol):
...@@ -235,3 +240,19 @@ class StatLogger: ...@@ -235,3 +240,19 @@ class StatLogger:
self.num_prompt_tokens = [] self.num_prompt_tokens = []
self.num_generation_tokens = [] self.num_generation_tokens = []
self.last_local_log = stats.now self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
...@@ -83,6 +83,7 @@ class GPUExecutor(ExecutorBase): ...@@ -83,6 +83,7 @@ class GPUExecutor(ExecutorBase):
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
device_config=self.device_config, device_config=self.device_config,
cache_config=self.cache_config, cache_config=self.cache_config,
# TODO allow draft-model specific load config.
load_config=self.load_config, load_config=self.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
......
...@@ -144,6 +144,7 @@ class RejectionSampler(nn.Module): ...@@ -144,6 +144,7 @@ class RejectionSampler(nn.Module):
recovered_probs = self._get_recovered_probs( recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size) target_probs, draft_probs).reshape(batch_size * k, vocab_size)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(recovered_probs, recovered_token_ids = _multinomial(recovered_probs,
num_samples=1).reshape( num_samples=1).reshape(
batch_size, k) batch_size, k)
...@@ -307,6 +308,12 @@ class RejectionSampler(nn.Module): ...@@ -307,6 +308,12 @@ class RejectionSampler(nn.Module):
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1) bonus_token_ids, -1)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids. # Fill the recovered token ids.
output.mul_(~after_false_mask).add_( output.mul_(~after_false_mask).add_(
recovered_token_ids.mul(after_false_mask)) recovered_token_ids.mul(after_false_mask))
......
...@@ -35,6 +35,14 @@ class Sampler(nn.Module): ...@@ -35,6 +35,14 @@ class Sampler(nn.Module):
in logits for each token in the input prompt. in logits for each token in the input prompt.
""" """
def __init__(self):
super().__init__()
# Whether or not the SamplerOutput should have on-device tensors
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self.include_gpu_probs_tensor = False
def forward( def forward(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
...@@ -79,13 +87,45 @@ class Sampler(nn.Module): ...@@ -79,13 +87,45 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens. # Sample the next tokens.
sample_results = _sample(probs, logprobs, sampling_metadata, sample_results, maybe_sampled_tokens_tensor = _sample(
sampling_tensors) probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
)
if self.include_gpu_probs_tensor:
assert maybe_sampled_tokens_tensor is not None
sampled_tokens_tensor = maybe_sampled_tokens_tensor
on_device_tensors = (probs, sampled_tokens_tensor)
else:
on_device_tensors = None
# Get the logprobs query results. # Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs( prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results) logprobs, sampling_metadata, sample_results)
return _build_sampler_output(sample_results, sampling_metadata, return _build_sampler_output(sample_results,
prompt_logprobs, sample_logprobs) sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors)
@property
def _should_modify_greedy_probs_inplace(self) -> bool:
"""Whether or not the sampler should modify the probability distribution
of greedily-sampled tokens such that multinomial sampling would sample
the greedily-sampled token.
In other words, if True then we set the probability of the greedily-
sampled token to 1.
This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution.
"""
# Modify greedy probs if include_gpu_probs_tensor is set.
return self.include_gpu_probs_tensor
def _get_bin_counts_and_mask( def _get_bin_counts_and_mask(
...@@ -359,7 +399,9 @@ def _sample_with_torch( ...@@ -359,7 +399,9 @@ def _sample_with_torch(
probs: torch.Tensor, probs: torch.Tensor,
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]: include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
...@@ -371,6 +413,15 @@ def _sample_with_torch( ...@@ -371,6 +413,15 @@ def _sample_with_torch(
sample_metadata = {} sample_metadata = {}
multinomial_samples = {} multinomial_samples = {}
# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
1,
dtype=torch.long,
device=logprobs.device)
else:
sampled_token_ids_tensor = None
# Counterintiutively, having two loops here is actually faster. # Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync. # The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType: for sampling_type in SamplingType:
...@@ -383,9 +434,25 @@ def _sample_with_torch( ...@@ -383,9 +434,25 @@ def _sample_with_torch(
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups, sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices) is_prompts, sample_indices)
long_sample_indices = sample_indices.long()
if sampling_type == SamplingType.GREEDY: if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[sample_indices.long()], greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1) dim=-1)
if include_gpu_probs_tensor:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = greedy_samples.unsqueeze(-1)
if modify_greedy_probs:
# If required, modify the probabilities such that sampling from
# the modified distribution would always sample the argmax
# token id.
_modify_greedy_probs_inplace(logprobs, probs,
long_sample_indices,
greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1 max_best_of_in_batch = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts): for seq_group, is_prompt in zip(seq_groups, is_prompts):
...@@ -397,15 +464,23 @@ def _sample_with_torch( ...@@ -397,15 +464,23 @@ def _sample_with_torch(
"seq_groups": seq_groups, "seq_groups": seq_groups,
"generators": sampling_metadata.generators, "generators": sampling_metadata.generators,
} }
multinomial_samples[sampling_type] = _multinomial( multinomial_samples[sampling_type] = _multinomial(
probs[sample_indices.long()], max_best_of_in_batch, probs[long_sample_indices], max_best_of_in_batch,
**seeded_args) **seeded_args)
if include_gpu_probs_tensor:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = multinomial_samples[sampling_type]
elif sampling_type == SamplingType.BEAM: elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices] beam_search_logprobs = logprobs[sample_indices]
else: else:
raise ValueError(f"Unsupported sampling type: {sampling_type}") raise ValueError(f"Unsupported sampling type: {sampling_type}")
# GPU<->CPU sync happens in the loop below. # GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
for sampling_type in SamplingType: for sampling_type in SamplingType:
if sampling_type not in sample_metadata: if sampling_type not in sample_metadata:
...@@ -427,7 +502,7 @@ def _sample_with_torch( ...@@ -427,7 +502,7 @@ def _sample_with_torch(
sample_results_dict[i] sample_results_dict[i]
for i in range(len(sampling_metadata.seq_groups)) for i in range(len(sampling_metadata.seq_groups))
] ]
return sample_results return sample_results, sampled_token_ids_tensor
def _sample_with_triton_kernel( def _sample_with_triton_kernel(
...@@ -511,12 +586,17 @@ def _sample_with_triton_kernel( ...@@ -511,12 +586,17 @@ def _sample_with_triton_kernel(
def _sample( def _sample(
probs: torch.Tensor, probs: torch.Tensor, logprobs: torch.Tensor,
logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
sampling_metadata: SamplingMetadata, include_gpu_probs_tensor: bool, modify_greedy_probs: bool
sampling_tensors: SamplingTensors, ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
) -> List[Tuple[List[int], List[int]]]: return _sample_with_torch(
return _sample_with_torch(probs, logprobs, sampling_metadata) probs,
logprobs,
sampling_metadata,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
)
# TODO: Enable once Triton kernel & associated code is faster. # TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
...@@ -680,12 +760,73 @@ def _get_logprobs( ...@@ -680,12 +760,73 @@ def _get_logprobs(
return result_prompt_logprobs, result_sample_logprobs return result_prompt_logprobs, result_sample_logprobs
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
sample_indices: torch.Tensor,
greedy_samples: torch.Tensor) -> None:
"""Modify the probability distributions of the greedily-sampled tokens such
that each sampled token has a "probability" of 1.0. This is required by
speculative decoding, which depends on the sampling method being encoded
within the probability distribution for correctness.
# Why do we only need to do this for greedy sampling?
vLLM's sampler performs the following steps for greedy or multinomial
(random) sampling:
1. Get logits from model.
2. Modify logits according to per-sequence sampling parameters.
- Multiply by temperature, top-k and top-p masking, penalize tokens
according to their frequency, etc.
3. Sample a token.
- Random sampling simply samples from the modified probability
distribution.
- Greedy sampling performs `argmax` to obtain the token with the
highest likelihood.
Ignoring greedy sampling for a moment, we find that the computed probability
distribution has the following property: we can sample from it independently
and find that the token sampled by the Sampler has a frequency corresponding
to how often we see it in our sampling. In other words, for tokens sampled
with vLLM's random SamplingType, the computed probability distribution
encodes the sampling methodology completely.
Greedy sampling does not normally have this property. vLLM modifies logits
according to sampling params, then performs `argmax`, then returns the
sampled token and the computed probability distribution. If we sample from
the distribution, we'll find the likelihood of the greedily-sampled token
is not always 1.0.
Since lossless speculative decoding requires that the sampling methodology
be encoded within the probability distribution, we are motivated to modify
the probability distribution such that the sampled token has probability 1
when speculative decoding is used.
NOTE: Alternatively, we could use an extremely low temperature to achieve
greedy sampling using multinomial computation and unite the codepaths. This
has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
"""
logprobs[sample_indices, :] = -float('inf')
logprobs[sample_indices, greedy_samples] = 0.0
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0
def _build_sampler_output( def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]], sample_results: List[Tuple[List[int], List[int]]],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]], prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs], sample_logprobs: List[SampleLogprobs],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput: ) -> SamplerOutput:
"""Construct Python objects with the output of sampling.
Args:
on_device_tensors: Tuple containing on-device tensors with the
probabilities used in sampling and the sampled token ids. This
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output = [] sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs, for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups, group_sample_logprobs) in zip(sampling_metadata.seq_groups,
...@@ -701,4 +842,15 @@ def _build_sampler_output( ...@@ -701,4 +842,15 @@ def _build_sampler_output(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append( sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
return SamplerOutput(outputs=sampler_output)
# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
sampled_token_probs, sampled_token_ids = on_device_tensors
else:
sampled_token_probs, sampled_token_ids = (None, None)
return SamplerOutput(
outputs=sampler_output,
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
)
...@@ -6,8 +6,8 @@ import torch ...@@ -6,8 +6,8 @@ import torch
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors, from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
nvtx_range, sampler_output_to_torch, sampler_output_to_torch,
split_batch_by_proposal_len) split_batch_by_proposal_len)
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
...@@ -72,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -72,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
proposal_lens_list = proposals.proposal_lens.tolist() proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist()
# Filter the list to ignore -1 proposals.
proposal_token_ids_list_without_skips = [
proposals for proposals in proposal_token_ids_list
if -1 not in proposals
]
(spec_indices, non_spec_indices, target_seq_group_metadata_list, (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch( num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list, proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list, proposal_lens_list=proposal_lens_list,
) )
...@@ -89,7 +95,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -89,7 +95,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
target_sampler_output = target_sampler_output[0] target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch( all_tokens, all_probs = self._contract_batch(
original_bs=len(seq_group_metadata_list), contracted_bs=len(seq_group_metadata_list),
target_sampler_output=target_sampler_output, target_sampler_output=target_sampler_output,
proposals=proposals, proposals=proposals,
num_scoring_tokens=num_scoring_tokens, num_scoring_tokens=num_scoring_tokens,
...@@ -128,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -128,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
select_proposal_len_zero=True) select_proposal_len_zero=True)
target_seq_group_metadata_list = self._create_scoring_model_input( target_seq_group_metadata_list = self._create_scoring_model_input(
spec_seqs, proposal_token_ids_list) seq_group_metadata_list=spec_seqs,
proposal_token_ids=proposal_token_ids_list,
# NOTE: We determine the seq ids in the expanded batch using the
# full seq_group_metadata_list, instead of only spec_seqs.
target_seq_ids_iter=self._create_target_seq_id_iterator(
seq_ids=get_all_seq_ids(seq_group_metadata_list)),
)
num_scoring_tokens = len(target_seq_group_metadata_list) num_scoring_tokens = len(target_seq_group_metadata_list)
target_seq_group_metadata_list.extend(non_spec_seqs) target_seq_group_metadata_list.extend(non_spec_seqs)
return (spec_indices, non_spec_indices, target_seq_group_metadata_list, return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) num_scoring_tokens)
def _contract_batch(self, original_bs: int, def _contract_batch(self, contracted_bs: int,
target_sampler_output: List[SamplerOutput], target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals, proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int], num_scoring_tokens: int, non_spec_indices: List[int],
...@@ -144,42 +157,41 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -144,42 +157,41 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Contract the expanded batch back into its original size. """Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original This maps the scores of speculative tokens back to their original
sequences. sequences.
"""
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
maybe_mock_device_tensors(
sampler_output=target_sampler_output,
batch_size=len(non_spec_indices) + num_scoring_tokens,
vocab_size=self._vocab_size,
device=self._device,
)
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
(target_token_ids, target_probs, non_spec_target_token_ids, (target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output( non_spec_target_probs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens) target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token # Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1]. # of shape [batch_size * k + 1] back to [batch_size, k + 1].
batch_size, k = proposals.proposal_token_ids.shape expanded_batch_size, k = proposals.proposal_token_ids.shape
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.squeeze().reshape( target_token_ids = target_token_ids.squeeze().reshape(
batch_size, k + 1) spec_expanded_bs, k + 1)
target_probs = target_probs.squeeze().reshape(batch_size, k + 1, target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size) self._vocab_size)
all_tokens = torch.full(size=(original_bs, k + 1), all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1, fill_value=-1,
device=self._device, device=self._device,
dtype=torch.long) dtype=torch.long)
all_probs = torch.zeros(original_bs, all_probs = torch.zeros(contracted_bs,
k + 1, k + 1,
self._vocab_size, self._vocab_size,
device=self._device, device=self._device,
dtype=torch.float32) dtype=torch.float32)
if non_spec_indices: if non_spec_indices:
all_tokens[non_spec_indices, 0] = non_spec_target_token_ids all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs all_probs[non_spec_indices, :1, :] = non_spec_target_probs
if spec_indices: if spec_indices:
...@@ -192,17 +204,19 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -192,17 +204,19 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]: ) -> List[SequenceGroupMetadata]:
"""Given the original input sequences and proposed tokens from the draft """Given the original input sequences and proposed tokens from the draft
model, create a list of target sequences that can be used for scoring. model, create a list of target sequences that can be used for scoring.
target_seq_ids_iter provides sequence ids for the expanded batch,
fulfilling the requirement that no seq id in the expanded batch is equal
to the seq id in the original batch.
""" """
if not seq_group_metadata_list: if not seq_group_metadata_list:
return [] return []
target_seq_ids_iter = self._create_target_seq_id_iterator(
get_all_seq_ids(seq_group_metadata_list))
target_seq_group_metadata = list( target_seq_group_metadata = list(
chain.from_iterable( chain.from_iterable(
self._create_target_seq_group_metadata( self._create_target_seq_group_metadata(
......
...@@ -24,9 +24,9 @@ class SpeculativeProposals: ...@@ -24,9 +24,9 @@ class SpeculativeProposals:
def __repr__(self): def __repr__(self):
return (f"SpeculativeProposals(" return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids.shape}, " f"proposal_token_ids={self.proposal_token_ids}, "
f"proposal_probs={self.proposal_probs.shape}, " f"proposal_probs={self.proposal_probs.shape}, "
f"proposal_lens={self.proposal_lens.shape})") f"proposal_lens={self.proposal_lens})")
@dataclass @dataclass
......
...@@ -147,15 +147,16 @@ class AsyncMetricsCollector: ...@@ -147,15 +147,16 @@ class AsyncMetricsCollector:
emitted_tokens = self._aggregate_num_emitted_tokens.item() emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens draft_tokens = self._aggregate_num_draft_tokens
num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k) max_num_emitted_tokens = self.get_max_num_emitted_tokens(
draft_tokens, k)
if draft_tokens > 0: if draft_tokens > 0:
draft_acceptance_rate = accepted_tokens / draft_tokens draft_acceptance_rate = accepted_tokens / draft_tokens
else: else:
draft_acceptance_rate = float("nan") draft_acceptance_rate = float("nan")
if num_possible_tokens > 0: if max_num_emitted_tokens > 0:
system_efficiency = emitted_tokens / num_possible_tokens system_efficiency = emitted_tokens / max_num_emitted_tokens
else: else:
system_efficiency = float("nan") system_efficiency = float("nan")
...@@ -169,8 +170,22 @@ class AsyncMetricsCollector: ...@@ -169,8 +170,22 @@ class AsyncMetricsCollector:
) )
@staticmethod @staticmethod
def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int: def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
# Divide by k since batch size can be variable. """Calculate the number of emitted tokens, assuming all tokens are
total_num_spec_seqs = draft_tokens / k accepted.
num_accepted_per_seq_if_all_accepted = k + 1
return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted) This is equal to the number of sequences that have been speculated on,
times (speculation len + 1). The +1 comes from the bonus token.
"""
# Determine the number of sequences that have been speculated on. Since
# the batch size can be variable, we divide by k.
assert draft_tokens % k == 0
total_num_spec_seqs = draft_tokens // k
# A single sequence may emit k accepted tokens and one bonus token in
# the best case.
num_emitted_per_seq_if_all_accepted = k + 1
# The max num of emitted tokens is the number of speculated sequences
# times the max emitted per seq.
return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment