"vllm/vscode:/vscode.git/clone" did not exist on "96266f119bb93516703328f9e37ec99cce45f792"
Commit 84dfdb17 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove unused code

parent f137e58c
from array import array
from typing import Callable, Dict, Mapping, Optional
from unittest.mock import patch
import pytest
import torch
import os
from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
InputRegistry, ProcessorInputs, token_inputs)
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from ..models.utils import build_model_context
from ..utils import models_path_prefix
# Used for fast tests where the model doesn't matter
DUMMY_MODEL_ID = os.path.join(models_path_prefix, "facebook/opt-125m")
# Used for tests that need a multimodal model
MULTIMODAL_MODEL_ID = os.path.join(models_path_prefix, "OpenGVLab/InternVL2-2B")
# For mm_processor_kwargs - we test overrides by defining mocks for each place
# it is used, and ensuring that we can pass processor kwargs an override value
# to receive the intended result for things like sequence length etc.
DEFAULT_MAX_DYNAMIC_PATCH = 6
MAX_DYNAMIC_PATCH_OVERRIDE = 4
# Mocks for all of the places that we use the mm_processor_kwargs
# to override values in different callables
@pytest.fixture
def use_processor_mock():
"""Patches the internal model input processor with an override callable."""
def custom_processor(ctx: InputContext,
inputs: DecoderOnlyInputs,
*,
max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
# For testing purposes, we don't worry about the prompt
return token_inputs(
prompt_token_ids=[],
mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch})
with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
return_value=custom_processor):
yield
@pytest.fixture
def use_dummy_data_mock():
"""Patches the internal model input processor with an override callable."""
def custom_dummy_data_factory(self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * max_dynamic_patch))
return DummyData(seq_data, None)
with patch(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
custom_dummy_data_factory):
yield
# Lazy import to avoid CUDA reinitialization error
def mm_model_cls():
from vllm.model_executor.models.internvl import InternVLChatModel
return InternVLChatModel
# lambda whose signature matches max token calcs extra & mapper + extra kwargs
get_max_dynamic_patch = lambda ctx, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: max_dynamic_patch # noqa: E501
custom_mapper = lambda ctx, data, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: { # noqa: E501
"pixel_values": torch.zeros(size=(1, max_dynamic_patch + 1, 3, 448, 448))
}
### Tests for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
"""Ensure that by default, there is no processor override."""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID)
processor = dummy_registry.create_input_processor(ctx.model_config)
proc_inputs = token_inputs(prompt_token_ids=[], prompt="")
proc_outputs = processor(inputs=proc_inputs)
assert proc_inputs is proc_outputs
def _get_max_dynamic_patch_info(init_max_dynamic_patch: int,
inference_max_dynamic_patch: int):
"""Get the init / inference kwargs and expected max_dynamic_patch."""
# If we have a value for max_dynamic_patch, pass the override value and make
# sure we get that value as a return-value from out mock processor,
# otherwise fall back to the default value
init_kwargs = None if init_max_dynamic_patch is None else {
"max_dynamic_patch": init_max_dynamic_patch
}
inference_kwargs = None if inference_max_dynamic_patch is None else {
"max_dynamic_patch": inference_max_dynamic_patch
}
if inference_max_dynamic_patch is not None:
expected_seq_count = inference_max_dynamic_patch
elif init_max_dynamic_patch is not None:
expected_seq_count = init_max_dynamic_patch
else:
expected_seq_count = DEFAULT_MAX_DYNAMIC_PATCH
return init_kwargs, inference_kwargs, expected_seq_count
def _get_processed_max_dynamic_patch(
processor: Callable[[ProcessorInputs], ProcessorInputs],
inference_kwargs: Optional[Dict[str, int]],
) -> int:
processed_inputs = processor(
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))
assert "type" in processed_inputs
assert processed_inputs["type"] == "token"
assert "mm_processor_kwargs" in processed_inputs
return processed_inputs["mm_processor_kwargs"]["max_dynamic_patch"]
@pytest.mark.parametrize(
"init_max_dynamic_patch,inference_max_dynamic_patch", [
(None, None),
(MAX_DYNAMIC_PATCH_OVERRIDE, None),
(DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE),
])
def test_input_processor_kwargs(use_processor_mock, init_max_dynamic_patch,
inference_max_dynamic_patch):
"""Ensure input processors can use processor kwargs."""
dummy_registry = InputRegistry()
(init_kwargs, inference_kwargs,
expected_seq_count) = _get_max_dynamic_patch_info(
init_max_dynamic_patch, inference_max_dynamic_patch)
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
max_dynamic_patch_val = _get_processed_max_dynamic_patch(
processor, inference_kwargs)
assert max_dynamic_patch_val == expected_seq_count
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_processor_with_sad_kwarg_overrides(use_processor_mock,
mm_processor_kwargs):
"""Ensure that input processors filter out invalid mm_processor_kwargs"""
dummy_registry = InputRegistry()
# Should filter out the init time kwargs
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
# Should filter out the inference time kwargs
max_dynamic_patch_val = _get_processed_max_dynamic_patch(
processor, mm_processor_kwargs)
assert max_dynamic_patch_val == DEFAULT_MAX_DYNAMIC_PATCH
### Test overrides for the dummy data
@pytest.mark.parametrize("max_dynamic_patch",
[None, MAX_DYNAMIC_PATCH_OVERRIDE])
def test_dummy_data_kwarg_overrides(use_dummy_data_mock, max_dynamic_patch):
"""Ensure dummy data factories can use processor kwargs."""
mm_processor_kwargs = None if max_dynamic_patch is None else {
"max_dynamic_patch": max_dynamic_patch
}
expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH
if max_dynamic_patch is None else max_dynamic_patch)
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
mm_processor_kwargs):
"""Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(
dummy_data.seq_data.prompt_token_ids) == DEFAULT_MAX_DYNAMIC_PATCH
### Test overrides for the max token count per multimodal instance
@pytest.mark.parametrize("max_dynamic_patch",
[None, MAX_DYNAMIC_PATCH_OVERRIDE])
def test_max_tokens_kwarg_overrides(max_dynamic_patch):
"""Ensure max token calcs can use processor kwargs."""
mm_processor_kwargs = None if max_dynamic_patch is None else {
"max_dynamic_patch": max_dynamic_patch
}
expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH
if max_dynamic_patch is None else max_dynamic_patch)
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our max_dynamic_patch value back from the mm_processor_kwargs.
with patch.object(
mm_registry._get_plugin("image"),
"_max_mm_tokens",
{mm_model_cls(): get_max_dynamic_patch},
):
max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
ctx.model_config)
assert expected_seq_count == max_multimodal_tokens
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
"""Ensure that max token calcs filters out invalid mm_processor_kwargs"""
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Similar before, but since these kwargs get filtered,
# we always get our default value back.
with patch.object(
mm_registry._get_plugin("image"),
"_max_mm_tokens",
{mm_model_cls(): get_max_dynamic_patch},
):
max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
ctx.model_config)
assert max_multimodal_tokens == DEFAULT_MAX_DYNAMIC_PATCH
### Test overrides for the mapper
@pytest.mark.parametrize(
"max_dynamic_patch",
[DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE])
def test_default_mapper_with_processor_kwargs(image_assets, max_dynamic_patch):
"""Ensure that the mapper processor kwargs can fall back to HF models."""
# NOTE - we don't validate bad inputs for the default mapper, because it's
# through the automodel interface in transformers, so we can't easily
# inspect what kwargs are or are not allowed.
ctx = build_model_context(
MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch},
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
image = image_assets[0].pil_image
mm_inputs = {"image": image}
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
# pixel vals should have shape: [batch, max_dynamic_patch+1, ...]
assert mapped_inputs["pixel_values"].shape[1] == max_dynamic_patch + 1
@pytest.mark.parametrize(
"init_max_dynamic_patch,inference_max_dynamic_patch", [
(None, None),
(MAX_DYNAMIC_PATCH_OVERRIDE, None),
(DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE),
])
def test_custom_mapper_kwarg_overrides(image_assets, init_max_dynamic_patch,
inference_max_dynamic_patch):
"""Ensure custom mappers can use processor kwargs."""
(init_kwargs, inference_kwargs,
expected_seq_count) = _get_max_dynamic_patch_info(
init_max_dynamic_patch, inference_max_dynamic_patch)
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=init_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
image = image_assets[0].pil_image
mm_inputs = {"image": image}
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our max_dynamic_patch value back from the mm_processor_kwargs.
mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
mm_model_cls())
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
inference_kwargs)
assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
mm_processor_kwargs):
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
# Should filter out the init time kwargs
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
image = image_assets[0].pil_image
mm_inputs = {"image": image}
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our max_dynamic_patch value back from the mm_processor_kwargs.
mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
mm_model_cls())
# Should filter out the inference time kwargs
mapped_inputs = mm_registry.map_input(
ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)
assert mapped_inputs["pixel_values"].shape[1] == (
DEFAULT_MAX_DYNAMIC_PATCH + 1)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import os
import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest
from ..utils import models_path_prefix
MODEL_PATH = os.path.join(models_path_prefix, "bigscience/bloomz-560m")
PA_PATH = os.path.join(models_path_prefix, 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM')
def do_sample(llm, pa_name: str, pa_id: int):
prompts = [
"Tweet text : @nationalgridus I have no water and the bill is \
current and paid. Can you do something about this? Label : ",
"Tweet text : @nationalgridus Looks good thanks! Label : "
]
sampling_params = vllm.SamplingParams(temperature=0.0,
max_tokens=3,
stop_token_ids=[3])
outputs = llm.generate(prompts,
sampling_params,
prompt_adapter_request=PromptAdapterRequest(
pa_name, pa_id, PA_PATH, 8) if pa_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
llm = vllm.LLM(MODEL_PATH,
enforce_eager=enforce_eager,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
expected_output = ['complaint', 'no complaint']
assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from ..utils import models_path_prefix
import os
MODEL_PATH = os.path.join(models_path_prefix, "bigscience/bloomz-560m")
pa_path = os.path.join(models_path_prefix, 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM')
pa_path2 = os.path.join(models_path_prefix, 'swapnilbp/angry_tweet_ptune')
def do_sample(engine):
prompts = [
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3), None),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("complain", 3, pa_path, 8)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_multi_prompt_adapters():
engine_args = EngineArgs(model=MODEL_PATH,
max_prompt_adapters=3,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
expected_output = {
' quot;I', 'hate speech', 'no complaint', 'not hate speech'
}
assert do_sample(engine) == expected_output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from ..utils import models_path_prefix
import os
MODEL_PATH = os.path.join(models_path_prefix, "meta-llama/Llama-2-7b-hf")
# pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
# lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
pa_path = os.path.join(models_path_prefix, "swapnilbp/llama_tweet_ptune")
lora_path = os.path.join(models_path_prefix, "yard1/llama-2-7b-sql-lora-test")
def do_sample(engine):
prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501
# first prompt with a prompt adapter and second without adapter
prompts = [
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]),
PromptAdapterRequest("hate_speech", 1, pa_path,
8), LoRARequest("sql_test", 1, lora_path)),
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]), None,
LoRARequest("sql_test", 1, lora_path)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request,
lora_request=lora_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_lora_prompt_adapter():
engine_args = EngineArgs(model=MODEL_PATH,
enable_prompt_adapter=True,
enable_lora=True,
max_num_seqs=60,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
result = do_sample(engine)
expected_output = {
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
}
assert result == expected_output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for rejection sampling."""
import pytest
import torch
import torch.nn.functional as F
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import current_platform
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
This file tests V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
def mock_causal_accepted_tensor(
k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
"""Generate an "accepted" tensor which should yield causally-accepted tokens
up to last accepted indices.
Tokens after last_accepted_indices+1 may also be accepted, although they
will not be causally accepted.
"""
batch_size = last_accepted_indices.shape[0]
accepted = (torch.arange(k).expand(batch_size, k)
<= last_accepted_indices.unsqueeze(-1).broadcast_to(
batch_size, k))
# Sprinkle accepted values after the contiguous initial accepted values.
# This replicates the behavior of rejection sampling, which may "accept"
# a token that cannot be accepted because of causality.
sprinkle_candidates = (torch.arange(k).expand(
batch_size,
k) > last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) +
1)
sprinkle = torch.rand(batch_size, k) > 0.5
accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
return accepted
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize(
"which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False] if not current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
device: str, use_flashinfer: bool):
"""Verify the output has correct format given predetermined accepted matrix.
"""
set_random_seed(seed)
torch.set_default_device(device)
batch_size = 10
k = 5
vocab_size = 3000
if which_tokens_accepted == "all_tokens_accepted":
accepted = mock_causal_accepted_tensor(
k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
elif which_tokens_accepted == "no_tokens_accepted":
accepted = mock_causal_accepted_tensor(
k, -torch.ones((batch_size, ), dtype=torch.long))
elif which_tokens_accepted == "some_tokens_accepted":
last_accepted_indices = torch.randint(low=-1,
high=k,
size=(batch_size, ))
accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
else:
raise AssertionError()
recovered_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
recovered_token_ids,
draft_token_ids,
bonus_token_ids,
)
expected_bonus_token_ids = bonus_token_ids.clone()
if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
# Expect all bonus tokens to be included.
assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
elif which_tokens_accepted == "no_tokens_accepted":
# Expect first token to be equal to recovered tokens.
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
# Expect everything else to be -1.
assert torch.equal(output_token_ids[:, 1:],
torch.ones_like(output_token_ids[:, 1:]) * -1)
elif which_tokens_accepted == "some_tokens_accepted":
recovered_plus_bonus = torch.cat(
(recovered_token_ids, expected_bonus_token_ids), dim=-1)
# Assert first rejected token is a recovered token or bonus token.
assert torch.equal(
recovered_plus_bonus[torch.arange(0, batch_size),
last_accepted_indices + 1],
output_token_ids[torch.arange(0, batch_size),
last_accepted_indices + 1])
# Assert every subsequent token is -1.
subsequent_mask = torch.arange(0, k + 1).expand(
batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
assert torch.all(output_token_ids[subsequent_mask] == -1)
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False] if not current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.parametrize("use_flashinfer", [True, False])
# Not testing FlashInfer now, since 0.2.3 API removed the ability
# to pass in uniform samples.
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int, device: str,
use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
results = []
for _ in range(n_rep):
seeded_seqs = {
i: torch.Generator(device=device).manual_seed(i)
for i in range(batch_size) if seeded_mask[i]
}
results.append(
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, seeded_seqs))
for i in range(batch_size):
if seeded_mask[i]:
for j in range(1, n_rep):
assert torch.equal(results[j][i], results[0][i])
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Consistent with NV.")
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.parametrize("use_flashinfer", [True, False])
# Not testing FlashInfer now, since 0.2.3 API removed the ability
# to pass in uniform samples.
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool):
torch.set_default_device(device)
set_random_seed(0)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
single_batches = []
for i in range(batch_size):
single_batches.append((draft_probs[i].clone().unsqueeze(0),
draft_token_ids[i].clone().unsqueeze(0),
target_probs[i].clone().unsqueeze(0),
bonus_token_ids[i].clone().unsqueeze(0),
draft_token_ids[i].clone().unsqueeze(0)))
set_random_seed(0)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
results = []
seeded_seqs = {
i: torch.Generator(device=device).manual_seed(i)
for i in range(1, batch_size) # 0 is seed None
}
batch_result = rejection_sampler(target_probs.clone(),
bonus_token_ids.clone(),
draft_probs.clone(),
draft_token_ids.clone(), seeded_seqs)
set_random_seed(0)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
for i in range(batch_size):
request_seeded_seqs = {
0: torch.Generator(device=device).manual_seed(i)
} if seeded_seqs.get(i) is not None else None
(draft_probs, draft_token_ids, target_probs, bonus_token_ids,
draft_token_ids) = single_batches[i]
results.append(
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, request_seeded_seqs))
for i in range(batch_size):
assert torch.equal(batch_result[i], results[i].squeeze(0))
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Rocm platform does not support flashinfer.")
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
batch_size: int, device: str):
"""
Test the flashinfer and nonflashinfer backend generate
the same output metrics.
"""
pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed "
"the ability to pass in uniform samples.")
torch.set_default_device(device)
torch.manual_seed(0)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
num_accepted_tokens = []
num_emitted_tokens = []
num_draft_tokens = []
def get_seeded_seqs():
return {
i: torch.Generator(device=device).manual_seed(i)
for i in range(batch_size)
}
for use_flashinfer in [True, False] if not current_platform.is_rocm() else [False]:
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
seeded_seqs = get_seeded_seqs()
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, seeded_seqs)
num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
num_draft_tokens.append(rejection_sampler.num_draft_tokens)
assert num_accepted_tokens[0] == num_accepted_tokens[1]
assert num_emitted_tokens[0] == num_emitted_tokens[1]
assert num_draft_tokens[0] == num_draft_tokens[1]
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False] if not current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str, device: str,
use_flashinfer: bool):
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
strict_mode=True)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
oob_token_ids = None
if which_token_ids == "bonus_token_ids":
oob_token_ids = bonus_token_ids
elif which_token_ids == "draft_token_ids":
oob_token_ids = draft_token_ids
else:
raise AssertionError()
if above_or_below_vocab_range == "above":
rogue_token_id = vocab_size + 1
elif above_or_below_vocab_range == "below":
rogue_token_id = -1
else:
raise AssertionError()
oob_token_ids[0][0] = rogue_token_id
with pytest.raises(AssertionError):
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@pytest.mark.parametrize("use_flashinfer", [True, False] if not current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
"""Verify rejection sampling approximates target distribution,
despite sampling from a potentially distinct draft distribution.
This is done by first creating a random target probability
distribution and a random draft probability distribution. We then
sample token ids from the rejection sampler using these draft
and target distributions. The samples are used to estimate
the output probability distribution, which we expect to approximate
the target distribution.
A basic distance metric is used to determine similarity between
distributions.
We expect that as we increase the number of samples,
the distance between the observed distribution and the target
distribution decreases. To measure this, we compare the distance
of the observed distribution against both the target distribution
and a uniform random distribution. We expect the distance between
the observed distribution and the target distribution to improve
much more than the distance improvement between the observed
distribution and the random distribution.
When draft_and_target_probs_equal=True, the draft and target
probabilities are exactly equal. Rejection sampling should
still work without any NaNs or exceptions.
"""
torch.set_default_device("cpu")
set_random_seed(seed)
helper = _CorrectnessTestHelper(
vocab_size=10,
rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
)
draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
draft_and_target_probs_equal)
sample_sizes = [10, 100, 1_000, 10_000, 100_000]
distance_wrt_reference: list[float] = []
distance_wrt_target: list[float] = []
for num_samples in sample_sizes:
(reference_vs_rejsample_dist,
target_vs_rejsample_dist) = helper.run_and_compare_distributions(
draft_probs,
target_probs,
reference_probs,
num_samples,
)
distance_wrt_reference.append(reference_vs_rejsample_dist)
distance_wrt_target.append(target_vs_rejsample_dist)
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
distance_wrt_target)
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
distance_wrt_reference)
print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
f"{reference_vs_rejsample_dist=:.05f}")
print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
f"{relative_change_in_distance_wrt_reference=:.02f}")
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
distance_wrt_target)
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
distance_wrt_reference)
expected_improvement_multiplier = 20
assert (relative_change_in_distance_wrt_target
> relative_change_in_distance_wrt_reference *
expected_improvement_multiplier)
def get_ratio_first_to_last(elements: list[float]) -> float:
return elements[0] / elements[-1]
class _CorrectnessTestHelper:
"""Class that packages together logic required for the unit-level
rejection sampling correctness test.
"""
def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
self.rejection_sampler = rejection_sampler
self.vocab_size = vocab_size
self.vocab_range = (0, vocab_size)
self.rejection_sampler.init_gpu_tensors(device=0)
# Keep test simple, use k=1
self.k = 1
# Bonus tokens not used, but rejection sampler requires
# correct shape.
self.num_bonus_tokens = 1
def generate_probs_for_test(
self, draft_and_target_probs_equal: bool
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
draft_probs, target_probs = (F.softmax(
torch.rand(self.vocab_size, dtype=torch.float32),
dim=-1,
) for _ in range(2))
num_reference_probs = 100
reference_probs = F.softmax(
torch.rand(num_reference_probs,
self.vocab_size,
dtype=torch.float32),
dim=-1,
)
if draft_and_target_probs_equal:
target_probs = draft_probs.clone()
return draft_probs, target_probs, reference_probs
def run_and_compare_distributions(self, draft_probs: torch.Tensor,
target_probs: torch.Tensor,
reference_probs: torch.Tensor,
num_samples: int) -> tuple[float, float]:
# Sample using rejection sampling.
rej_sample_probs = self._estimate_rejection_sampling_pdf(
draft_probs, target_probs, num_samples)
# Average distance from reference probs.
reference_vs_rejsample_dist = torch.dist(
reference_probs,
rej_sample_probs).item() / reference_probs.shape[0]
target_vs_rejsample_dist = torch.dist(target_probs,
rej_sample_probs).item()
return reference_vs_rejsample_dist, target_vs_rejsample_dist
def _estimate_rejection_sampling_pdf(
self,
draft_probs: torch.Tensor,
target_probs: torch.Tensor,
num_samples: int,
) -> torch.Tensor:
# Repeat draft probs num_samples times.
draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
num_samples, 1, 1)
# Repeat target probs num_samples * (k + 1) times.
# Rejection sampler requires bonus token probs, but they aren't used.
target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
num_samples, self.k + 1, 1)
# Randomly sample draft token ids from draft probs.
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
num_samples=1,
replacement=True).reshape(
num_samples, self.k)
# Bonus tokens not used but required.
bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
dtype=torch.int64,
device="cuda").repeat(num_samples, 1)
# Get output tokens via rejection sampling.
output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
bonus_token_ids.to("cuda"),
draft_probs.to("cuda"),
draft_token_ids.to("cuda"))
# Remove bonus tokens
output_token_ids = output_token_ids[:, :-1].flatten()
# Estimate probability density function
hist = torch.histogram(output_token_ids.to(dtype=torch.float,
device="cpu"),
bins=self.vocab_size,
range=self.vocab_range,
density=True)
return hist.hist
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import os
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
from ...utils import models_path_prefix
@pytest.mark.parametrize("common_llm_kwargs",
[{
"model": os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"),
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Speculative max model len > overridden max model len should raise.
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"max_model_len": 129,
},
"max_model_len": 128,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"max_model_len": 2048 + 1,
},
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"max_model_len": 131072 + 1,
},
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
"""Verify that speculative decoding validates speculative_max_model_len.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, EAGLE would not break the
correctness for the target model outputs.
"""
import pytest
import os
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
import vllm.envs as envs
os.environ["LLAMA_NN"] = "0"
# main model
MAIN_MODEL = os.path.join(models_path_prefix, "JackFram/llama-68m")
# speculative model
SPEC_MODEL = os.path.join(models_path_prefix, "abhigoyal/vllm-eagle-llama-68m-random")
# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4
# precision
PRECISION = "float32" if envs.VLLM_USE_TRITON_FLASH_ATTN else "half"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
}, {
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
}])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": "float16",
# Main model
"model_name": os.path.join(models_path_prefix, "meta-llama/Llama-2-7b-chat-hf"),
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "yuhuili/EAGLE-llama2-chat-7B"),
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# 2 for small prompt, 256//16 for generated.
"num_gpu_blocks_override": 2 + 256 // 16,
"max_model_len": (2 + 256 // 16) * 16,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": "float16",
# Main model
"model_name": os.path.join(models_path_prefix, "meta-llama/Meta-Llama-3-8B-Instruct"),
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "yuhuili/EAGLE-LLaMA3-Instruct-8B"),
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# 2 for small prompt, 256//16 for generated.
"num_gpu_blocks_override": 2 + 256 // 16,
"max_model_len": (2 + 256 // 16) * 16,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": "float16",
# Main model
"model_name": os.path.join(models_path_prefix, "Qwen/Qwen2-7B-Instruct"),
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "yuhuili/EAGLE-Qwen2-7B-Instruct"),
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_qwen2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)
if __name__ == "__main__":
import pytest
pytest.main([__file__])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests which cover integration of the speculative decoding framework with
other features, e.g. cuda graphs.
"""
import pytest
import os
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
os.environ["LLAMA_NN"] = "0"
MAIN_MODEL = os.path.join(models_path_prefix, "JackFram/llama-68m")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix,"JackFram/llama-68m"),
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
# Explicitly specify draft model quantization
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"),
"num_speculative_tokens": 5,
"quantization": "gptq",
},
},
# Explicitly specify GPTQ-based draft model to use marlin quantization
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"),
"num_speculative_tokens": 5,
"quantization": "marlin",
},
},
# Not explicitly specify draft model quantization
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"),
"num_speculative_tokens": 5,
"quantization": None,
},
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int, seed: int):
"""Verify spec decode works well with draft model quantization configs.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""
import json
from typing import Optional
import pytest
import torch
import os
from vllm.platforms import current_platform
from .conftest import run_equality_correctness_test_tp
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
os.environ["LLAMA_NN"] = "0"
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[[
# Skip cuda graph recording for fast test.
"--enforce-eager",
"--tensor-parallel-size",
"2"
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("test_llm_kwargs", [
[
"--speculative_config",
json.dumps({
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 3,
}),
],
[
"--speculative_config",
json.dumps({
"model": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
}),
],
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if current_platform.is_rocm():
pytest.skip("hip is not well-supported yet")
run_equality_correctness_test_tp(os.path.join(models_path_prefix, "JackFram/llama-68m"),
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"tensor_parallel_size": 2,
# Precision
"dtype": "bfloat16",
# Main model
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
},
}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
seed: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, max_output_len=32, seed=seed,
temperature=0.0)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"tensor_parallel_size": 2,
# Precision
"dtype": "bfloat16",
# Main model
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[{
"enable_chunked_prefill": False,
"max_num_batched_tokens": 32,
"max_model_len": 32,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
},
}])
@pytest.mark.parametrize("logprobs", [None])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
logprobs: Optional[int],
batch_size: int, seed: int):
"""Verify spec decode works well with same and different TP size for
the draft model with chunked prefill.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, max_output_len=32, seed=seed,
temperature=0.0)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"tensor_parallel_size": 2,
# Precision
"dtype": "bfloat16",
# Main model
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[{
"enable_chunked_prefill": False,
"max_num_batched_tokens": 32,
"max_model_len": 32,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
},
}])
@pytest.mark.parametrize("logprobs", [2])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2_with_logprobs(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, logprobs: Optional[int],
batch_size: int, seed: int):
"""Verify spec decode works well with same and different TP size for
the draft model with chunked prefill.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0,
logprobs=logprobs)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""
import json
import openai
import pytest
import torch
import os
from .conftest import run_equality_correctness_test_tp
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
os.environ["LLAMA_NN"] = "0"
MAIN_MODEL = os.path.join(models_path_prefix, "JackFram/llama-68m")
SPEC_MODEL = os.path.join(models_path_prefix, "JackFram/llama-68m")
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"tensor_parallel_size": 4,
# Precision
"dtype": "bfloat16",
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
},
}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp4(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
seed: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, max_output_len=32, seed=seed,
temperature=0.0)
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"tensor_parallel_size": 4,
# Precision
"dtype": "bfloat16",
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": 5,
"max_model_len": 32,
},
}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify job failure with RuntimeError when all sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
TODO: fix it to pass without raising Error. (#5814)
"""
with pytest.raises(RuntimeError):
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, max_output_len=output_len, seed=seed,
temperature=0.0)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import cycle
import pytest
import os
from vllm import SamplingParams
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
os.environ["LLAMA_NN"] = "0"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}, {
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 3,
"disable_logprobs": True,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
7,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 8])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int, prefill_chunk_size: int):
"""Verify output logprobs are equal with and without speculative decoding,
as well as with and without chunked prefill.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}, {
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
"num_speculative_tokens": 6,
"disable_logprobs": False,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int, logprobs: int):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
"num_speculative_tokens": 3,
"disable_logprobs": False,
# Artificially limit the draft model max model len; this forces
# vLLM to skip speculation once the sequences grow beyond 32-k
# tokens.
"max_model_len": 32,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1])
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [6])
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Verify at least one logprob result has num_logprobs+1, which tests the
case where the sampled token is not in top-k logprobs.
Ideally, this test should validate equality with non-spec by getting
logprobs. This is left as future improvement.
"""
temperature = 1.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
logprobs=logprobs,
)
sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
with vllm_runner(**sd_args) as vllm_model:
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
num_returned_logprobs = [
len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
]
# Assert one of the returned logprobs has > num_logprobs (indicating the
# sampled token is not in top-k).
assert any(
[num_returned > logprobs for num_returned in num_returned_logprobs])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix,"JackFram/llama-160m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 3,
"disable_logprobs": True,
},
}])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("logprobs", [0])
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, Medusa would not break the
correctness for the target model outputs.
"""
import pytest
import os
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
os.environ["LLAMA_NN"] = "0"
# main model
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
# OOM in CI pipeline, so using a smaller model.
MAIN_MODEL = os.path.join(models_path_prefix, "JackFram/llama-68m")
# speculative model
SPEC_MODEL = os.path.join(models_path_prefix, "abhigoyal/vllm-medusa-llama-68m-random")
# max number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5
# precision
PRECISION = "float16"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
8,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, logprobs: int,
prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 8])
def test_medusa_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int,
prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int, prefill_chunk_size: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
if __name__ == "__main__":
import pytest
pytest.main([__file__])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, MLPSpeculator would not break the
correctness for the target model outputs.
"""
from unittest.mock import patch
import pytest
import os
from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
# main model
MAIN_MODEL = os.path.join(models_path_prefix, "JackFram/llama-160m")
# speculative model
SPEC_MODEL = os.path.join(models_path_prefix, "ibm-fms/llama-160m-accelerator")
# max. number of speculative tokens: this corresponds to
# n_predict in the config.json of the speculator model.
MAX_SPEC_TOKENS = 3
# precision
PRECISION = "float16"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [4, 4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [8])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# NOTE Test is sensitive enough st if we don't enable chunked prefill
# scheduling on baseline too, we get slightly different logprobs, ending
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify acceptance rate with different batch size and large output
length."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=0.0,
seed=seed,
expected_acceptance_rate=0.48)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Speculative config
"speculative_config": {
"model": SPEC_MODEL,
},
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("temperature", [1.0])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
temperature: float,
prefill_chunk_size: int, seed: int):
"""Verify seeded runs produce the same output."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seed=seed)
# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seed=seed,
disable_seed=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_greedy_correctness_with_padding(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality when the vocab dimension is padded
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# Default pad_to is 64, test model has vocab_size of 32000
def patched_pad_vocab_size(vocab_size, pad_to=None):
return pad_vocab_size(vocab_size, pad_to=32064)
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
patched_pad_vocab_size):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
# Speculative decoding is disabled when sequences reach decoding and the batch
# consists of single-token requests. Hence we set `max_num_seqs`
# >= `speculative_disable_by_batch_size` to test feature interaction.
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int,
output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, prefill_chunk_size: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, mtp would not break the
correctness for the target model outputs.
"""
import os
import pytest
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
os.environ["VLLM_MLA_DISABLE"] = "1"
# main model
MAIN_MODEL = os.path.join(models_path_prefix, "luccafong/deepseek_mtp_main_random_bf16")
# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1
# precision
PRECISION = "bfloat16"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
output_len: int, seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
if __name__ == "__main__":
import pytest
pytest.main([__file__])
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""The tests in this file verify end-to-end speculative decoding correctness.
This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.
For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model). Similarly, for the
TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
test cases.
NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
same input. vLLM largely guarantees this.
@cadedaniel has seen cases where the output probabilities of a draft/target
model change slightly with certain batch sizes or prompts, even with Torch
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
determinism in on-device batched operations, a bug in vLLM's spec decode
implementation, or the "hardware numerics" limitations. Either way, rejection
sampling ensures the output distribution matches the target model, but it breaks
greedy-equality tests for those batch sizes/prompts.
"""
from itertools import cycle
import pytest
import os
from transformers import AutoTokenizer
from vllm import SamplingParams
from ...utils import create_new_process_for_each_test
from .conftest import (get_output_from_llm_generator,
run_equality_correctness_test)
from ...utils import models_path_prefix
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
{
# Verify the detokenizer assertions in the test work when spec
# decode is disabled.
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
batch_size: int):
"""Run generation with speculative decoding on a batch. Verify the engine
generates the correct number of tokens (via ignore_eos=True), and that the
detokenization matches HF transformers.
"""
output_len = 32
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
batch_tokens, batch_token_ids, _ = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
# Expect a generation for each prompt in the batch.
assert len(batch_token_ids) == len(prompts)
# Expect each generation to have expected number of tokens (note ignore_eos
# is True).
assert [len(token_ids)
for token_ids in batch_token_ids] == ([output_len] * batch_size)
# Expect detokenized string to match.
tok = AutoTokenizer.from_pretrained(os.path.join(models_path_prefix, "JackFram/llama-68m"))
for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
expected_tokens = tok.decode(actual_token_ids)
print(f"{actual_token_ids=}")
assert actual_tokens.strip() == expected_tokens.strip()
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
},
{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"disable_logprobs": False,
},
"enable_chunked_prefill": False,
}, {
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
}])
@pytest.mark.parametrize(
"output_len",
[
# Use long output len for the small model test.
10,
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a tiny model with batch size of one.
Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
When the draft model is the same as the target model, we further check
whether all speculative tokens are accepted.
"""
ensure_all_accepted = per_test_common_llm_kwargs.get(
"model_name") == test_llm_kwargs.get("speculative_config")["model"]
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
prompt_logprobs=2,
logprobs=2,
disable_logprobs=False,
temperature=0.0,
ensure_all_accepted=ensure_all_accepted)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
},
{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [64])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a tiny model and large batch size.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
},
{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("max_output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
max_output_len: int, seed: int):
"""Verify greedy equality on a tiny model, with a large batch size, and when
sampling respects the EOS token.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len,
seed=seed,
temperature=0.0,
ignore_eos=False)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model_name": os.path.join(models_path_prefix, "meta-llama/Llama-2-7b-chat-hf"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
[
# Use decently long output len for a high quality test.
256,
])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a "real" model and batch size of 1. This is
separate from large BS tests to make identifying the source of bugs easier.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model_name": os.path.join(models_path_prefix, "meta-llama/Llama-2-7b-chat-hf"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with a "real" model on a nontrivial batch size.
This is the closest test to a real production workload.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_spec_decode_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# https://github.com/triton-lang/triton/issues/2266 tl.dot
# doesn't support embedding < 16
{
"block_size": 16,
},
{
"block_size": 32,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality over different block sizes.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": False,
},
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_skip_speculation(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality when some (or all) sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"disable_by_batch_size": 2,
},
"enable_chunked_prefill": False,
},
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": 5,
"disable_by_batch_size": 2,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_disable_speculation(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality when all sequences disable speculation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": k,
},
"enable_chunked_prefill": False,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
] + [{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": k,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with many different values of k.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": k,
"acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": False
}
# Try a range of common k.
for k in [1, 2, 3]
] + [{
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
"num_speculative_tokens": k,
"acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
} for k in [1, 2, 3]])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@create_new_process_for_each_test()
def test_typical_acceptance_sampling(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
Since there is no model is needed for generate the proposal, we could make
the testcase much simpler than drafter multi-step one.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes
With those tests, we can say at least, ngram spec would not break the
correctness for the target model outputs.
"""
import pytest
import os
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
os.environ["LLAMA_NN"] = "0"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": False,
},
},
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality on a tiny model with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
8,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
},
"enable_chunked_prefill": False,
},
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=0,
seed=seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 3,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
] + [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 1,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4
},
}, {
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
# GPU memory utilization
"gpu_memory_utilization": 0.6
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import os
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
# main model
MAIN_MODEL = os.path.join(models_path_prefix, "JackFram/llama-68m")
# speculative model
SPEC_MODEL = os.path.join(models_path_prefix, "JackFram/llama-160m")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# speculative config
"speculative_config": {
"model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
"num_speculative_tokens": 3,
},
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("batch_size", [1, 8, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
20,
])
def test_seeded_consistency(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
temperature: float, output_len: int):
"""Verify outputs are consistent across multiple runs with same seed
"""
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
disable_seed=False,
)
# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
disable_seed=True,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from unittest.mock import MagicMock
import pytest
import torch
from vllm.spec_decode.metrics import AsyncMetricsCollector
def test_initial_call_returns_none():
"""Expect first call to get metrics to return None.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0
collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert maybe_metrics is None
def test_second_call_returns_metrics():
"""Expect second call to not return None.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
]
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
@pytest.mark.parametrize("rank", [1, 2, 3, 4])
def test_nonzero_rank_noop(rank):
"""Verify nonzero ranks don't collect metrics.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0
collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=rank)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None
def test_noop_until_time():
"""Verify metrics aren't collected until enough time passes.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0, collect_interval_s - 0.1, collect_interval_s - 0.1,
collect_interval_s + 0.1, collect_interval_s + 0.1
]
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
def test_timer_is_reset():
"""Verify that the internal timer inside AsyncMetricsCollector
is reset after collection.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0,
collect_interval_s + 0.1,
collect_interval_s + 0.1,
collect_interval_s + 0.2,
collect_interval_s + 0.2,
2 * collect_interval_s + 0.1,
2 * collect_interval_s + 0.1,
]
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
"""Test correctness of metrics data.
"""
if has_data:
num_accepted_tokens = 103
num_emitted_tokens = 104
num_draft_tokens = 105
else:
num_accepted_tokens = 0
num_emitted_tokens = 0
num_draft_tokens = 0
k = 5
max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
num_draft_tokens, k)
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = num_draft_tokens
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
]
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k)
metrics = collector.maybe_collect_rejsample_metrics(k)
assert metrics.num_spec_tokens == k
assert metrics.accepted_tokens == num_accepted_tokens
assert metrics.draft_tokens == num_draft_tokens
assert metrics.emitted_tokens == num_emitted_tokens
if has_data:
assert (metrics.draft_acceptance_rate == num_accepted_tokens /
num_draft_tokens)
assert (metrics.system_efficiency == num_emitted_tokens /
max_num_emitted_tokens)
else:
assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from unittest.mock import MagicMock
import pytest
import torch
import os
from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
get_all_seq_ids)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
from .utils import (assert_logprobs_dict_allclose, create_batch,
create_seq_group_metadata_from_prompts, create_worker,
patch_execute_model_with_seeds, zero_kv_cache)
from ..utils import models_path_prefix
@pytest.mark.parametrize('num_steps', list(range(1, 17)))
def test_assert_enough_kv_space(num_steps: int):
"""Test that the multi step worker checks for sufficient space in the KV
cache. It should throw if it cannot run all the steps.
"""
block_size = 16
num_gpu_blocks = 2048 // block_size
prompts = [
list(range(block_size * 3)),
list(range(block_size * 2)),
]
prev_output_tokens = [
list(range(block_size * 1)),
list(range(block_size * 2)),
]
final_prompt_lens = [
len(prompt + output) + num_steps
for prompt, output in zip(prompts, prev_output_tokens)
]
inputs = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens,
continuations=prev_output_tokens)
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
worker = MagicMock()
worker.model_runner.block_size = block_size
for seq_group_metadata in inputs:
original_block_tables = seq_group_metadata.block_tables
# No exception.
assert_enough_kv_space(worker, inputs, num_steps)
seq_group_metadata.block_tables = {
seq_id: []
for seq_id, physical_blocks in original_block_tables.items()
}
# Expect exception.
with pytest.raises(ValueError,
match='times but found insufficient KV space for'):
assert_enough_kv_space(worker, inputs, num_steps)
seq_group_metadata.block_tables = original_block_tables
@torch.inference_mode()
def test_same_output_for_single_step():
"""Verify the multi step worker produces the same output as the normal
worker for num_steps=1.
"""
seed = 100
model_name = os.path.join(models_path_prefix, 'JackFram/llama-68m')
block_size = 32
num_gpu_blocks = 2048 // block_size
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(
Worker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
# multi_step_worker.model_runner = worker.model_runner
# multi_step_worker.cache_engine = worker.cache_engine
num_steps = 1
prompts = [
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
multi_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
actual_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=multi_step_seq_group),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
assert len(actual_output) == num_steps
actual_output = actual_output[0]
single_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
zero_kv_cache(worker.cache_engine)
set_random_seed(seed)
expected_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=single_step_seq_group))[0]
actual_token_ids = [
output.samples[0].output_token for output in actual_output
]
actual_logprobs = [output.samples[0].logprobs for output in actual_output]
expected_token_ids = [
output.samples[0].output_token for output in expected_output
]
expected_logprobs = [
output.samples[0].logprobs for output in expected_output
]
assert actual_token_ids == expected_token_ids
print(f'{actual_logprobs=}')
print(f'{expected_logprobs=}')
assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)
@torch.inference_mode()
def test_same_output_for_multi_step():
"""Verify the multi-step worker produces the same output as the normal
worker when num_steps > 1. This test runs the multi-step worker once, and
then runs the worker num_steps times, and compares the output.
"""
seed = 100
model_name = os.path.join(models_path_prefix, 'JackFram/llama-68m')
block_size = 16
num_gpu_blocks = 2048 // block_size
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
worker = create_worker(
Worker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
# Make sure we go over the block boundary.
num_steps = block_size + 1
random.seed(seed)
prompts = [[
random.randint(0, 1000) for _ in range(random.randint(10, 20))
] for _ in range(10)]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
continuations = [[1] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
single_step_output: list[SamplerOutput] = []
continuations = [[1] for _ in prompts]
set_random_seed(seed)
for _ in multi_step_output:
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison.
multi_step_output_logprobs: list[list[dict[int,
Logprob]]] = [[]
for _ in prompts]
single_step_output_logprobs: list[list[dict[int,
Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
single_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output,
single_step_output):
multi_step_output_token_ids[i].append(
multi_step[i].samples[0].output_token)
single_step_output_token_ids[i].append(
single_step[i].samples[0].output_token)
multi_step_output_logprobs[i].append(
multi_step[i].samples[0].logprobs)
single_step_output_logprobs[i].append(
single_step[i].samples[0].logprobs)
# Print per-sequence token ids
for i, (multi_step_tokens, single_step_tokens) in enumerate(
zip(multi_step_output_token_ids, single_step_output_token_ids)):
print(f'{i=} {multi_step_tokens=}')
print(f'{i=} {single_step_tokens=}')
print(f'{i=} equal {multi_step_tokens == single_step_tokens}')
# Assert token ids are equal.
for multi_step_tokens, single_step_tokens in zip(
multi_step_output_token_ids, single_step_output_token_ids):
assert multi_step_tokens == single_step_tokens
# Assert logprobs are equal.
for multi_step_logprobs, single_step_logprobs in zip(
multi_step_output_logprobs, single_step_output_logprobs):
assert_logprobs_dict_allclose(multi_step_logprobs,
single_step_logprobs)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_correct_output():
"""
In this test we verify that the MultiStepWorker is able to handle bonus
tokens correctly. The test verifies that if a sequence has a
bonus token then the MultiStepWorker is able to expand the batch by adding
new sequences corresponding to the sequences with bonus tokens. The
expanded batch is then used for predicting the next tokens.
"""
seed = 100
model_name = os.path.join(models_path_prefix, 'JackFram/llama-68m')
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 128
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
Worker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
random.seed(seed)
prompts = [[0] for _ in range(batch_size)]
num_steps = 2
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
# Create the test continuations
continuations = [[random.randint(0, 1000)] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: list[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations = []
for continuation in continuations:
multi_step_continuations.append(continuation[:2])
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step and verify that the third token prediction is accurate
# for all sequences.
zero_kv_cache(multi_step_worker.cache_engine)
all_seq_ids = {i for i in range(batch_size)}
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=1,
seq_ids_with_bonus_token_in_last_step=all_seq_ids)
for index, output in enumerate(multi_step_output[-1].outputs):
assert (continuations[index][-1] == output.samples[0].output_token)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_incorrect_output():
"""
Tests the MultiStepWorker's ability to handle batch expansion with bonus
tokens in a negative case scenario. This test provides the MultiStepWorker
with a batch containing sequences with bonus tokens but specifies the
sequence IDs with bonus tokens incorrectly. The test verifies that the
MultiStepWorker generates correct tokens for the sequences where the
sequence ID is specified correctly and incorrect tokens for those where
the sequence ID is specified incorrectly.
"""
seed = 100
model_name = os.path.join(models_path_prefix, 'JackFram/llama-68m')
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 128
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
Worker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
random.seed(seed)
prompts = [[0] for _ in range(batch_size)]
num_steps = 2
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
# Create the test continuations
continuations = [[random.randint(0, 1000)] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: list[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations = []
for continuation in continuations:
multi_step_continuations.append(continuation[:2])
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step. In this run INCORRECTLY specify that only the odd number
# sequences have bonus tokens. Verify that with this setting the third token
# prediction is accurate only for the odd numbered sequences. Also verify
# that the prediction might be wrong for some of the even numbered
# sequences.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=1,
seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
num_mismatch = 0
for index, output in enumerate(multi_step_output[-1].outputs):
if (index % 2) != 0:
assert (continuations[index][-1] == output.samples[0].output_token)
elif (continuations[index][-1] != output.samples[0].output_token):
num_mismatch += 1
# The prediction is accurate for some of the sequences even without proper
# handling of the bonus tokens. Hence verify that the number of sequences
# for which there is a mismatch is > 0.
assert (num_mismatch > 0)
# @torch.inference_mode()
# @pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
# # The choice of backends forces the multi_step_worker to choose between
# # the vanilla model_runner and TP1DraftModelRunner and that we can test
# # both code paths.
# @pytest.mark.parametrize('attn_backend',
# [_Backend.XFORMERS, _Backend.FLASH_ATTN])
# def test_multi_step_correct_kvcache(num_steps, attn_backend):
# """Verify that the KV cache of the draft model
# is correctly updated for sequences with bonus token.
# """
# seed = 100
# model_name = "JackFram/llama-68m"
# block_size = 16
# num_gpu_blocks = 2048 // block_size
# batch_size = 1
# with global_force_attn_backend_context_manager(attn_backend):
# dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32'
# multi_step_worker = create_worker(MultiStepWorker,
# model_name,
# block_size,
# num_gpu_blocks,
# seed,
# model_runner_cls=TP1DraftModelRunner,
# dtype=dtype)
# multi_step_worker.set_include_gpu_probs_tensor()
# worker = create_worker(Worker,
# model_name,
# block_size,
# num_gpu_blocks,
# seed,
# dtype=dtype)
# prompts = [[0] for _ in range(batch_size)]
# # Already generate two tokens for the sequence
# # so that we can simulate the bonus token case
# multi_step_continuations = [[
# random.randint(0, 1000),
# random.randint(0, 1000)
# ] for _ in prompts]
# final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]
# seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
# seq_group_metadata_list = create_seq_group_metadata_from_prompts(
# prompts,
# num_gpu_blocks,
# block_size,
# continuations=multi_step_continuations,
# final_prompt_lens=final_prompt_lens)
# # Run multi-step.
# zero_kv_cache(multi_step_worker.cache_engine)
# multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
# seq_group_metadata_list=seq_group_metadata_list),
# sample_len=num_steps,
# seq_ids_with_bonus_token_in_last_step=
# seq_ids_with_bonus_token_in_last_step)
# # Run single-step repeatedly.
# zero_kv_cache(worker.cache_engine)
# # Generate the kv cache for the bonus token first
# single_step_continuations = [c[:1] for c in multi_step_continuations]
# seq_group_metadata_list = create_seq_group_metadata_from_prompts(
# prompts,
# num_gpu_blocks,
# block_size,
# continuations=single_step_continuations,
# final_prompt_lens=final_prompt_lens)
# single_step_output = worker.execute_model(
# execute_model_req=ExecuteModelRequest(
# seq_group_metadata_list=seq_group_metadata_list))
# for _ in range(num_steps):
# seq_group_metadata_list = create_seq_group_metadata_from_prompts(
# prompts,
# num_gpu_blocks,
# block_size,
# continuations=multi_step_continuations,
# final_prompt_lens=final_prompt_lens)
# single_step_output = worker.execute_model(
# execute_model_req=ExecuteModelRequest(
# seq_group_metadata_list=seq_group_metadata_list))
# for i, seq_group_output in enumerate(single_step_output[-1]):
# multi_step_continuations[i].append(
# seq_group_output.samples[0].output_token)
# # Verify that the KV cache of the single-step and
# # multi-step workers are the same.
# single_step_gpu_cache = worker.cache_engine[0].gpu_cache
# multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
# num_layers = len(single_step_gpu_cache)
# allclose = lambda a, b: torch.allclose(
# a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2)
# for i in range(num_layers):
# assert allclose(single_step_gpu_cache[i][0],
# multi_step_gpu_cache[i][0])
# assert allclose(single_step_gpu_cache[i][1],
# multi_step_gpu_cache[i][1])
@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
"""Verify Top1Proposer correctly handles case where all sequences
can speculate.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'cuda:0'
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=2048,
)
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
logprobs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(low=0,
high=vocab_size,
size=(batch_size, ),
device=device,
dtype=torch.long),
) for _ in range(k)
], True
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]
@torch.inference_mode()
def test_draft_proposals_no_speculations():
"""Verify Top1Proposer correctly handles case where no sequences
can speculate.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'cuda:0'
prompt_len = 10
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=prompt_len + k - 1,
)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prompt_len=prompt_len)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
@torch.inference_mode()
def test_draft_proposals_mixed_k():
"""Verify Top1Proposer correctly handles case some sequences can
speculate and some can't.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'cuda:0'
small_prompt_len = 5
long_prompt_len = 10
prev_output_token_len = 20
expected_num_proposal_seqs = 6
expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs
prompt_len = [
small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
] + [long_prompt_len
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
)
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
logprobs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(
low=0,
high=vocab_size,
size=(expected_num_proposal_seqs, ),
device=device,
dtype=torch.long),
) for _ in range(k)
], True
seq_group_metadata_list, _, _ = create_batch(
batch_size,
k,
prompt_len=prompt_len,
prev_output_token_len=prev_output_token_len,
)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [
k for _ in range(expected_num_proposal_seqs - 1)
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
"""Verify that draft model runner triggers advance step
when applicable.
"""
seed = 100
model_name = os.path.join(models_path_prefix, 'JackFram/llama-68m')
k = 5
batch_size = 32
block_size = 32
num_gpu_blocks = 2048 // block_size
worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
# Mock "_gpu_advance_step" to raise an exception when called.
exception_secret = "artificial stop"
worker.model_runner._gpu_advance_step = MagicMock()
worker.model_runner._gpu_advance_step.side_effect = ValueError(
exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
# Fallback (should not call) when num_steps=1.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=1)
worker.execute_model(execute_model_req=execute_model_req)
# Expect exception if _gpu_advance_step is called.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=k)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
assert len(call_args_list) == 1
@torch.inference_mode()
def test_expand_execute_model_request_sync_with_expand_hidden_states():
"""
In this test we verify that the logic for expanding the
seq_group_metadata_list remains in sync with the expansion logic of
the HiddenStates in _expand_execute_model_request.
"""
k = 5
batch_size = 16
seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_request = ExecuteModelRequest(
seq_group_metadata_list,
previous_hidden_states=HiddenStates(
torch.arange(batch_size), seq_group_metadata_list,
torch.arange(batch_size, 2 * batch_size)))
expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\
_expand_execute_model_request(execute_model_request,
seq_with_bonus_token_in_last_step)
all_seq_ids = torch.tensor(
get_all_seq_ids(
expanded_execute_model_request.seq_group_metadata_list))
ref_expanded_hidden_states = all_seq_ids + batch_size
ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
assert (ref_expanded_hidden_states == expanded_execute_model_request.
previous_hidden_states.hidden_states).all().item()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment