Commit 705f6a35 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1

parents af837396 4cf256ae
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'
def do_sample(engine):
prompts = [
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3), None),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("complain", 3, pa_path, 8)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_multi_prompt_adapters():
engine_args = EngineArgs(model=MODEL_PATH,
max_prompt_adapters=3,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
expected_output = {
' quot;I', 'hate speech', 'no complaint', 'not hate speech'
}
assert do_sample(engine) == expected_output
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def do_sample(engine):
prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501
# first prompt with a prompt adapter and second without adapter
prompts = [
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]),
PromptAdapterRequest("hate_speech", 1, pa_path,
8), LoRARequest("sql_test", 1, lora_path)),
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]), None,
LoRARequest("sql_test", 1, lora_path)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request,
lora_request=lora_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_lora_prompt_adapter():
engine_args = EngineArgs(model=MODEL_PATH,
enable_prompt_adapter=True,
enable_lora=True,
max_num_seqs=60,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
result = do_sample(engine)
expected_output = {
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
}
assert result == expected_output
...@@ -5,16 +5,12 @@ Run `pytest tests/quantization/test_bitsandbytes.py`. ...@@ -5,16 +5,12 @@ Run `pytest tests/quantization/test_bitsandbytes.py`.
import pytest import pytest
import torch import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import SamplingParams from vllm import SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
@pytest.mark.skipif( reason='bitsandbytes is not supported on this GPU type.')
capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(),
reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None: def test_load_bnb_model(vllm_runner) -> None:
with vllm_runner('huggyllama/llama-7b', with vllm_runner('huggyllama/llama-7b',
quantization='bitsandbytes', quantization='bitsandbytes',
......
...@@ -3,16 +3,25 @@ ...@@ -3,16 +3,25 @@
Run `pytest tests/quantization/test_compressed_tensors.py`. Run `pytest tests/quantization/test_compressed_tensors.py`.
""" """
import pytest
import torch import torch
from vllm import SamplingParams
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken, CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8StaticTensor) CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
def test_compressed_tensors_w8a8_static_setup(vllm_runner): QuantizationType)
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
QuantizationType.INT, 2560),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
QuantizationType.INT, 2560),
])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy, quant_type, shape_0 = model_args
with vllm_runner(model_path, enforce_eager=True) as llm: with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0] layer = model.model.layers[0]
...@@ -28,35 +37,116 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner): ...@@ -28,35 +37,116 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
CompressedTensorsLinearMethod) CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method, assert isinstance(down_proj.quant_method,
CompressedTensorsLinearMethod) CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.is_static_input_scheme
expected_type = torch.int8
assert qkv_proj.weight.dtype is expected_type
assert o_proj.weight.dtype is expected_type
assert gate_up_proj.weight.dtype is expected_type
if qkv_proj.scheme.strategy == "tensor":
# Make sure it is a channelwise buffer
# After running process_weights_after_loading
assert len(qkv_proj.weight_scale.shape) == 2
assert qkv_proj.weight_scale.shape[0] == shape_0
assert qkv_proj.weight_scale.shape[1] == 1
assert qkv_proj.weight_scale.dtype is torch.float32
assert qkv_proj.input_scale.dtype is torch.float32
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
def test_compressed_tensors_no_enforce_eager(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
with vllm_runner(model_path) as llm:
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
])
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
model_path, strategy = model_args
with vllm_runner(model_path, dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert not qkv_proj.scheme.is_static_input_scheme
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8 assert qkv_proj.weight.dtype is torch.int8
assert o_proj.weight.dtype is torch.int8
assert gate_up_proj.weight.dtype is torch.int8
assert qkv_proj.weight_scale.shard_splitter is not None output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert qkv_proj.weight_scale.logical_widths is not None assert output
assert qkv_proj.input_scale.dtype is torch.float32
def test_compressed_tensors_no_enforce_eager(vllm_runner): @pytest.mark.parametrize(
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2" "wNa16_args",
[("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
model, strategy, group, pack_factor = wNa16_args
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.group_size == (-1 if group is None else group)
assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.weight_packed.pack_factor == pack_factor
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
with vllm_runner(model_path) as llm: with vllm_runner(model_path) as llm:
sampling_params = SamplingParams() model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
output = llm.generate("Hello world!", sampling_params=sampling_params) layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
assert qkv_proj.weight_packed.dtype is torch.int32
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output assert output
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner): def test_compressed_tensors_fp8(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2" model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
with vllm_runner(model_path, enforce_eager=True, with vllm_runner(model_path) as llm:
dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0] layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken) assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert qkv_proj.weight.dtype is torch.int8 assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert qkv_proj.input_scale.dtype is torch.float32
assert qkv_proj.weight_scale.dtype is torch.float32
# should be scalars after processing
assert len(qkv_proj.input_scale.shape) == 0
assert len(qkv_proj.weight_scale.shape) == 0
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
...@@ -4,6 +4,7 @@ Run `pytest tests/quantization/test_configs.py --forked`. ...@@ -4,6 +4,7 @@ Run `pytest tests/quantization/test_configs.py --forked`.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple
import pytest import pytest
...@@ -51,7 +52,7 @@ MODEL_ARG_EXPTYPES = [ ...@@ -51,7 +52,7 @@ MODEL_ARG_EXPTYPES = [
@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES) @pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
def test_auto_gptq(model_arg_exptype: str) -> None: def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None:
model_path, quantization_arg, expected_type = model_arg_exptype model_path, quantization_arg, expected_type = model_arg_exptype
try: try:
......
...@@ -5,20 +5,88 @@ Run `pytest tests/quantization/test_fp8.py --forked`. ...@@ -5,20 +5,88 @@ Run `pytest tests/quantization/test_fp8.py --forked`.
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
capability = torch.cuda.get_device_capability() MODELS = [
capability = capability[0] * 10 + capability[1] "neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
"nm-testing/Phi-3-mini-128k-instruct-FP8",
]
@pytest.mark.skipif( @pytest.mark.skipif(not is_quant_method_supported("fp8"),
capability < QUANTIZATION_METHODS["fp8"].get_min_capability(), reason="FP8 is not supported on this GPU type.")
reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model", MODELS)
def test_model_load_and_run(vllm_runner, model: str):
with vllm_runner(model) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)
print(outputs[0][1])
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
def test_load_fp16_model(vllm_runner) -> None: def test_load_fp16_model(vllm_runner) -> None:
with vllm_runner("facebook/opt-125m", quantization="fp8") as llm: with vllm_runner("facebook/opt-125m", quantization="fp8") as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
fc1 = model.model.decoder.layers[0].fc1 fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, Fp8LinearMethod) assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability >= 89:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert fc1.weight.dtype == torch.int32
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:
def quantize_ref(tensor, inv_scale):
# The reference implementation that fully aligns to
# the kernel being tested.
finfo = torch.finfo(torch.float8_e4m3fn)
scale = inv_scale.reciprocal()
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
max=finfo.max)
qweight = qweight.to(torch.float8_e4m3fn)
return qweight
def per_tensor_dequantize(tensor, inv_scale, dtype):
fake_qweight = tensor.to(dtype)
dq_weight = fake_qweight * inv_scale
return dq_weight
# Note that we use a shape % 4 != 0 to cover edge cases,
# because scaled_fp8_quant is vectorized by 4.
x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
# Dynamic quantization
ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)
# Reference dynamic quantizaton
y = quantize_ref(x, inv_scale)
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
# Static quantization
y, _ = ops.scaled_fp8_quant(x, inv_scale)
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
# Padding
y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
assert y.shape[0] == 17
assert torch.allclose(
ref_y,
per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
dtype))
"""Tests whether gptq models with quantized lm_head can be loaded.
Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
"""
from typing import Tuple
import pytest
import torch
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
PROMPT = "On the surface of Mars, we found"
MODELS_QUANT = [(
"LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]
@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
def test_lm_head(
vllm_runner,
model_lm_head_quant: Tuple[str, bool],
) -> None:
model, lm_head_quantized = model_lm_head_quant
vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)
lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model.lm_head)
if lm_head_quantized:
assert isinstance(
lm_head_layer.linear_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
else:
assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod)
print(
vllm_model.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)[0][1])
del vllm_model
import torch
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform
def is_quant_method_supported(quant_method: str) -> bool:
# Currently, all quantization methods require Nvidia or AMD GPUs
if not torch.cuda.is_available():
return False
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
return (capability >=
QUANTIZATION_METHODS[quant_method].get_min_capability())
from typing import List
import pytest import pytest
import torch import torch
...@@ -9,7 +11,8 @@ MODELS = ["facebook/opt-125m"] ...@@ -9,7 +11,8 @@ MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype",
["float"]) # needed for comparing logprobs with HF
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size @pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
@pytest.mark.parametrize("detokenize", [True, False]) @pytest.mark.parametrize("detokenize", [True, False])
...@@ -62,21 +65,22 @@ def test_get_prompt_logprobs( ...@@ -62,21 +65,22 @@ def test_get_prompt_logprobs(
for logprobs in result.outputs[0].logprobs: for logprobs in result.outputs[0].logprobs:
assert len(logprobs) == num_top_logprobs assert len(logprobs) == num_top_logprobs
output_text = result.outputs[0].text output_text = result.outputs[0].text
output_string_from_most_likely_tokens = [] output_string_from_most_likely_tokens_lst: List[str] = []
for top_logprobs in result.outputs[0].logprobs: for top_logprobs in result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values())) top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens.append( output_string_from_most_likely_tokens_lst.append(
top_logprob.decoded_token) top_logprob.decoded_token)
if detokenize: if detokenize:
output_string_from_most_likely_tokens = "".join( output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens) output_string_from_most_likely_tokens_lst)
assert output_text == output_string_from_most_likely_tokens, ( assert output_text == output_string_from_most_likely_tokens, (
"The output text from the top logprob for each token position " "The output text from the top logprob for each token position "
"should be the same as the output text in the result.") "should be the same as the output text in the result.")
else: else:
assert output_text == '' assert output_text == ''
assert output_string_from_most_likely_tokens == [None] * max_tokens assert output_string_from_most_likely_tokens_lst == ([None] *
max_tokens)
# The first prompt logprob is always None # The first prompt logprob is always None
assert result.prompt_logprobs[0] is None assert result.prompt_logprobs[0] is None
......
...@@ -246,8 +246,8 @@ def test_rejection_sampling_approximates_target_distribution( ...@@ -246,8 +246,8 @@ def test_rejection_sampling_approximates_target_distribution(
draft_and_target_probs_equal) draft_and_target_probs_equal)
sample_sizes = [10, 100, 1_000, 10_000, 100_000] sample_sizes = [10, 100, 1_000, 10_000, 100_000]
distance_wrt_reference = [] distance_wrt_reference: List[float] = []
distance_wrt_target = [] distance_wrt_target: List[float] = []
for num_samples in sample_sizes: for num_samples in sample_sizes:
(reference_vs_rejsample_dist, (reference_vs_rejsample_dist,
......
import itertools import itertools
import random import random
from typing import List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
...@@ -49,8 +49,8 @@ def _do_sample( ...@@ -49,8 +49,8 @@ def _do_sample(
sampling_params: SamplingParams, sampling_params: SamplingParams,
device: str, device: str,
): ):
seq_group_metadata_list = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens = [] seq_lens: List[int] = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -212,7 +212,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -212,7 +212,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
batch_size = random.randint(1, 128) batch_size = random.randint(1, 128)
expected_penalization = [] expected_penalization = []
sequence_metadata_list = [] sequence_metadata_list: List[SequenceGroupMetadata] = []
# 20% chance to generate seq group metadata list with all prompts # 20% chance to generate seq group metadata list with all prompts
is_prompt = random.random() < 0.2 is_prompt = random.random() < 0.2
while batch_size > 0: while batch_size > 0:
...@@ -232,8 +232,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -232,8 +232,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids) stop_token_ids=stop_token_ids)
seq_data = {} seq_data: Dict[int, SequenceData] = {}
seq_group_penalization = [] seq_group_penalization: List[bool] = []
for _ in range(num_seqs): for _ in range(num_seqs):
num_input = random.randint(1, 100) num_input = random.randint(1, 100)
num_generated = 0 if is_prompt else random.randint(1, 100) num_generated = 0 if is_prompt else random.randint(1, 100)
...@@ -392,17 +392,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -392,17 +392,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
else: else:
test_cases = [generate_test_case()] test_cases = [generate_test_case()]
def run_test_case(*, def run_test_case(*, expected_penalization: List[bool],
expected_penalization=None, seq_group_metadata_list: List[SequenceGroupMetadata]):
seq_group_metadata_list=None):
assert expected_penalization, \ assert expected_penalization, \
"Invalid test case, need expected_penalization" "Invalid test case, need expected_penalization"
assert seq_group_metadata_list, \ assert seq_group_metadata_list, \
"Invalid test case, need seq_group_metadata_list" "Invalid test case, need seq_group_metadata_list"
batch_size = 0 batch_size = 0
seq_lens = [] seq_lens: List[int] = []
sampling_params_per_row = [] sampling_params_per_row: List[SamplingParams] = []
for sgm in seq_group_metadata_list: for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params sampling_params = sgm.sampling_params
...@@ -472,15 +471,15 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -472,15 +471,15 @@ def test_sampler_mixed(seed: int, device: str):
batch_size = random.randint(1, 256) batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler = _prepare_test(batch_size) input_tensor, fake_logits, sampler = _prepare_test(batch_size)
seq_group_metadata_list = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
expected_tokens: List[Optional[List[int]]] = [] expected_tokens: List[Optional[List[int]]] = []
seq_lens = [] seq_lens: List[int] = []
for i in range(batch_size): for i in range(batch_size):
expected: Optional[List[int]] = None expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3) sampling_type = random.randint(0, 3)
if sampling_type == 0: if sampling_type == 0:
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
expected = [torch.argmax(fake_logits[i], dim=-1).item()] expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
elif sampling_type in (1, 2): elif sampling_type in (1, 2):
n = random.randint(1, 10) n = random.randint(1, 10)
sampling_params = SamplingParams( sampling_params = SamplingParams(
...@@ -536,15 +535,18 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -536,15 +535,18 @@ def test_sampler_mixed(seed: int, device: str):
] ]
continue continue
expected_tokens_item = expected_tokens[i]
assert expected_tokens_item is not None
for n, nth_output in enumerate(sequence_output.samples): for n, nth_output in enumerate(sequence_output.samples):
if (metadata.sampling_params.temperature == 0 if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None): or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed # Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens[i][n] assert nth_output.output_token == expected_tokens_item[n]
else: else:
# For non-seeded random check that one of the high-logit # For non-seeded random check that one of the high-logit
# tokens were chosen # tokens were chosen
assert nth_output.output_token in expected_tokens[i] assert nth_output.output_token in expected_tokens_item
# Test batch # Test batch
test_sampling() test_sampling()
...@@ -585,11 +587,11 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -585,11 +587,11 @@ def test_sampler_top_k_top_p(seed: int, device: str):
generation_config = GenerationConfig(top_k=top_k, generation_config = GenerationConfig(top_k=top_k,
top_p=top_p, top_p=top_p,
do_sample=True) do_sample=True)
warpers = generation_model._get_logits_warper(generation_config) warpers = generation_model._get_logits_warper(generation_config, device)
assert len(warpers) == 2 # top_p and top_k assert len(warpers) == 2 # top_p and top_k
seq_group_metadata_list = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens = [] seq_lens: List[int] = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -622,7 +624,79 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -622,7 +624,79 @@ def test_sampler_top_k_top_p(seed: int, device: str):
with patch("vllm.model_executor.layers.sampler._sample", mock_sample): with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata) sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
assert sample_probs is not None
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone()) hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5) assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
vocab_size = 8
def test_sampling_params(sampling_params: List[SamplingParams]):
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
for i in range(2):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=sampling_params[i],
block_tables={0: [1]},
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
fake_logits = torch.full((2, vocab_size),
1e-2,
device=device,
dtype=torch.float16)
fake_logits[:, 5] = 1.1e-2
fake_logits[:, 1] = 1.2e-2
sampler = MockLogitsSampler(fake_logits)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)
generated_tokens = []
for output in sampler_output:
generated_tokens.append(output.samples[0].output_token)
return generated_tokens
# one configuration is greedy with repetition_penalty
sampling_params_rep = SamplingParams(
temperature=0.0,
repetition_penalty=2.0,
)
# other configuration is sampling w/o repetition_penalty
sampling_params_sample = SamplingParams(
temperature=1.0,
top_k=1,
seed=42,
)
tokens1 = test_sampling_params(
[sampling_params_rep, sampling_params_sample])
tokens2 = test_sampling_params(
[sampling_params_sample, sampling_params_rep])
assert tokens1[0] == tokens2[1]
assert tokens1[1] == tokens2[0]
"""Tests for rejection sampling."""
import pytest
import torch
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.model_executor.utils import set_random_seed
CUDA_DEVICES = [f"cuda:{i}" for i in range(1)]
def get_zero_temperature_prob_dist(batch_size, k, vocab_size):
"""
Generates a fake temperature zero probability distribution.
Returns:
1. A fake temperature zero probability distribution of shape
[batch_size, k, vocab_size]
2. Tensor of shape [batch_size, k] containing the token ids
of the probability 1.0 tokens at each position.
"""
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
probs = torch.rand(batch_size, k, vocab_size)
_, zero_temperature_token_ids = torch.max(probs, dim=-1)
# set the probability of the tokens with ids in zero_temperature_token_ids
# to 1 and the rest to 0.
target_probs = torch.zeros_like(probs).scatter_(
-1, zero_temperature_token_ids.unsqueeze(-1), 1.0)
return target_probs, zero_temperature_token_ids
def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
token_ids_to_exclude: torch.Tensor):
"""
Returns a tensor of shape [batch_size, k] of fake draft token ids
drawn randomly from a vocab of size vocab_size. We however ensure
that token_ids from token_ids_to_exclude are excluded at the
corresponding positions.
"""
draft_token_ids = torch.empty(batch_size, k, dtype=torch.long)
for i in range(batch_size):
for j in range(k):
# Generate a random token ID excluding token_ids_to_exclude[i, j]
while True:
token_id = torch.randint(0, vocab_size, (1, )).item()
if token_id != token_ids_to_exclude[i, j]:
draft_token_ids[i, j] = token_id
break
return draft_token_ids
def get_acceptance_sampler(
posterior_threshold: float = 0.03,
posterior_alpha: float = 0.9,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
disable_bonus_tokens, strict_mode)
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str):
"""
Tests that the TypicalAcceptancSampler forward succeeds for
different combinations of k, vocab_size, batch_size and num devices.
"""
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str, device: str):
"""
Tests that we throw an exception of the token ids fall outside
the bound of the provided vocabulary.
"""
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
# Verify that appropriate exceptions are thrown for out
# of bound vocabs.
oob_token_ids = None
if which_token_ids == "bonus_token_ids":
oob_token_ids = bonus_token_ids
elif which_token_ids == "draft_token_ids":
oob_token_ids = draft_token_ids
else:
raise AssertionError()
if above_or_below_vocab_range == "above":
rogue_token_id = vocab_size + 1
elif above_or_below_vocab_range == "below":
rogue_token_id = -1
else:
raise AssertionError()
oob_token_ids[0][0] = rogue_token_id
with pytest.raises(AssertionError):
typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
seed: int, disable_bonus_tokens: bool, device: str):
"""
Test the TypicalAcceptanceSampler with a uniform target probability
distribution.
This test verifies that when provided with a uniform target probability
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
entropy of the uniform target distribution being high should lead to all
draft tokens being accepted. The test also ensures that the behavior
regarding bonus tokens is consistent with the `disable_bonus_tokens`
flag.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
# We are using a uniform target probability distribution.
# For a uniform distribution the entropy is very high and it
# should lead to all draft tokens being accepted. Verify that.
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
assert torch.all(output_token_ids[:, :k] == draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int,
disable_bonus_tokens: bool,
device: str):
"""
Test the TypicalAcceptanceSampler with a zero-temperature target
probability distribution.
This test verifies that when using a zero-temperature target probability
distribution, where only one token has a probability of 1.0, the
TypicalAcceptanceSampler correctly rejects all draft tokens that do not
match this probability. Additionally, it ensures that when all draft
tokens are rejected, the sampler falls back to greedy sampling to select a
single token from the target distribution.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k, vocab_size)
# Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
# The target probaility distribution is a temperature zero distribution
# with zero entroy. Since our draft token ids don't match the probability
# 1.0 tokens in the target distribution we will reject all of them and
# fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same.
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, -1] == -1)
assert torch.all(output_token_ids[:, 0] == zero_temperature_token_ids[:,
0])
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
device: str):
"""
Test the TypicalAcceptanceSampler with a mixed target probability
distribution.
This test ensures that the TypicalAcceptanceSampler handles a mixed
target probability distribution correctly. Specifically, it uses a
zero-temperature distribution for some sequences and a uniform
distribution for others. The test verifies that:
- For sequences with a zero-temperature distribution, only the token
with a probability of 1.0 is accepted, and all other tokens are rejected.
- For sequences with a uniform distribution, all draft tokens are
accepted.
- When `disable_bonus_tokens` is False, the bonus tokens are also accepted
for sequences with a uniform distribution.
"""
set_random_seed(seed)
k = 3
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
target_probs[[1, 3]] = uniform_probs
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
# verify the shape of output_token_ids
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
# For sequences 0 and 2 verify that only 1 token is accepted
# which is the token with probability 1.0 in the target distribution
# at position 0.
assert torch.all(output_token_ids[[0, 2], 1:] == -1)
assert (torch.all(output_token_ids[[0, 2],
0] == zero_temperature_token_ids[[0, 2],
0]))
# For sequences 1 and 3 verify that all tokens are accepted since the
# target probability distribution is uniform. In addition verify that
# if disable_bonus_tokens is false then we also accept the bonus tokens.
assert torch.all(
output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
if disable_bonus_tokens:
assert torch.all(output_token_ids[[1, 3], -1] == -1)
else:
assert torch.all(output_token_ids[[1, 3], -1] != -1)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
device: str):
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.
This test verifies that the TypicalAcceptanceSampler correctly accepts or
rejects draft tokens based on a zero-temperature target probability
distribution. Specifically, it ensures that:
- When all draft tokens match tokens with a probability of 1.0 in the
target distribution, all draft tokens are accepted.
- When only some draft tokens match tokens with a probability of 1.0 in
the target distribution, only those matching tokens are accepted, and the
rest are rejected.
"""
set_random_seed(seed)
k = 5
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
draft_token_ids = zero_temperature_token_ids
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the
# draft tokens and the rest as -1
draft_token_ids_to_replace = get_draft_token_ids(
batch_size, k, vocab_size, zero_temperature_token_ids)
draft_token_ids = torch.cat(
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
assert torch.all(output_token_ids[:, -3:] == -1)
@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int,
disable_bonus_tokens: bool,
device: str):
"""
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
thresholds and alpha values we can change the acceptance behavior of the
sampler.
"""
set_random_seed(seed)
k = 5
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
# id has probability 1.0 and others have a very low probability of
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0. Without any changes to the posterior thresholds
# none of the draft tokens are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_probs[target_probs == 0] = 0.00001
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 1:-1] == -1)
# Change the posterior threshold values to 0.0 so that we will
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True,
disable_bonus_tokens=disable_bonus_tokens,
posterior_threshold=0.0,
posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
device: str):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
This test verifies that the `_replacement_token_ids` method of the
TypicalAcceptanceSampler correctly identifies the token IDs to be used
as replacements based on the target probability distribution.
Specifically, it ensures that the method correctly identifies the
tokens with the highest probability for each sequence in the batch.
"""
set_random_seed(seed)
k = 10
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = -torch.ones(
(batch_size, k), dtype=torch.long)
expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
dim=1)
actual_replacement_tokens = (
typical_acceptance_sampler._replacement_token_ids(target_probs))
assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
import asyncio import asyncio
import time
from itertools import cycle from itertools import cycle
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
...@@ -7,25 +6,21 @@ import pytest ...@@ -7,25 +6,21 @@ import pytest
import ray import ray
import torch import torch
from vllm.utils import is_hip
if (not is_hip()):
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
nvmlInit)
from vllm import LLM from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalData from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid from vllm.utils import Counter, random_uuid
from ...conftest import cleanup from ...conftest import cleanup
from ...utils import wait_for_gpu_memory_to_clear
class AsyncLLM: class AsyncLLM:
...@@ -97,7 +92,8 @@ class AsyncLLM: ...@@ -97,7 +92,8 @@ class AsyncLLM:
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None, multi_modal_data: Optional[MultiModalDataDict] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> List[RequestOutput]: ) -> List[RequestOutput]:
if prompts is None: if prompts is None:
...@@ -118,16 +114,17 @@ class AsyncLLM: ...@@ -118,16 +114,17 @@ class AsyncLLM:
raise ValueError("The lengths of prompts and " raise ValueError("The lengths of prompts and "
"sampling_params must be the same.") "sampling_params must be the same.")
async def get_output(prompt, sampling_param) -> str: async def get_output(prompt, sampling_param) -> RequestOutput:
request_id = random_uuid() request_id = random_uuid()
results_generator = self.llm_engine.generate( results_generator = self.llm_engine.generate(
prompt, sampling_param, request_id) prompt, sampling_param, request_id)
final_output = None final_output = None
async for request_output in results_generator: async for request_output in results_generator:
final_output = request_output final_output = request_output
assert final_output is not None
return final_output return final_output
outputs = [] outputs: List[RequestOutput] = []
try: try:
for i in range(num_requests): for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None prompt = prompts[i] if prompts is not None else None
...@@ -208,8 +205,8 @@ def maybe_assert_ngram_worker(llm): ...@@ -208,8 +205,8 @@ def maybe_assert_ngram_worker(llm):
def get_output_from_llm_generator( def get_output_from_llm_generator(
llm_generator, prompts, llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]: sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = [] tokens: List[str] = []
token_ids = [] token_ids: List[List[int]] = []
for llm in llm_generator(): for llm in llm_generator():
maybe_assert_ngram_worker(llm) maybe_assert_ngram_worker(llm)
...@@ -290,38 +287,3 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, ...@@ -290,38 +287,3 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
print(f'{i=} {baseline_token_ids=}') print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}') print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids assert baseline_token_ids == spec_token_ids
def wait_for_gpu_memory_to_clear(devices: List[int],
threshold_bytes: int,
timeout_s: float = 120) -> None:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
nvmlInit()
start_time = time.time()
while True:
output = {}
output_raw = {}
for device in devices:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
gb_used = mem_info.used / 2**30
output_raw[device] = gb_used
output[device] = f'{gb_used:.02f}'
print('gpu memory used (GB): ', end='')
for k, v in output.items():
print(f'{k}={v}; ', end='')
print('')
dur_s = time.time() - start_time
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
print(f'Done waiting for free GPU memory on devices {devices=} '
f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
break
if dur_s >= timeout_s:
raise ValueError(f'Memory of devices {devices=} not free after '
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
time.sleep(5)
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""
import pytest
import torch
from vllm.utils import is_hip
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
pytest.skip("hip is not well-supported yet")
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs, test_llm_kwargs",
[
(
{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a
# tokenizer.
"model": "JackFram/llama-68m",
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
}),
({
"model": "ibm-granite/granite-3b-code-instruct",
}, {
"speculative_model":
"ibm-granite/granite-3b-code-instruct-accelerator",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
})
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
baseline_llm_generator,
batch_size: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)
...@@ -5,16 +5,16 @@ tensor parallelism. ...@@ -5,16 +5,16 @@ tensor parallelism.
import pytest import pytest
import torch import torch
from vllm.utils import is_hip
from .conftest import run_greedy_equality_correctness_test from .conftest import run_greedy_equality_correctness_test
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 2 GPUs to run the test.") reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m", "model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
...@@ -22,7 +22,7 @@ from .conftest import run_greedy_equality_correctness_test ...@@ -22,7 +22,7 @@ from .conftest import run_greedy_equality_correctness_test
# Required for spec decode. # Required for spec decode.
"use_v2_block_manager": True, "use_v2_block_manager": True,
"tensor_parallel_size": 2, "tensor_parallel_size": 4,
# Use AsyncLLM engine, so that the engine runs in its own process. # Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner # Otherwise, since vLLM does not follow true SPMD, the test runner
...@@ -31,35 +31,30 @@ from .conftest import run_greedy_equality_correctness_test ...@@ -31,35 +31,30 @@ from .conftest import run_greedy_equality_correctness_test
# second run of the test to fail with internal NCCL error. # second run of the test to fail with internal NCCL error.
"use_async": True, "use_async": True,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}, },
]) ])
@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "test_llm_kwargs",
[ [
# Use smaller output len for fast test. #TODO(wooyeon): add spec_draft_dp=2 case
32, {
"speculative_draft_tensor_parallel_size": 1,
},
]) ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator, def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
batch_size: int, output_len: int): baseline_llm_generator,
"""Verify greedy equality when tensor parallelism is used. batch_size: int):
"""Verify spec decode works well with smaller tp for draft models.
""" """
if is_hip():
pytest.skip("hip is not well-supported yet")
run_greedy_equality_correctness_test(baseline_llm_generator, run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator, test_llm_generator,
batch_size, batch_size,
max_output_len=output_len, max_output_len=32,
force_output_len=True) force_output_len=True)
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, Medusa would not break the
correctess for the target model outputs.
"""
import pytest
from .conftest import run_greedy_equality_correctness_test
# main model
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
# OOM in CI pipeline, so using a smaller model.
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5
# precision
PRECISION = "float32"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
if __name__ == "__main__":
import pytest
pytest.main([__file__])
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, MLPSpeculator would not break the
correctess for the target model outputs.
"""
import pytest
from .conftest import run_greedy_equality_correctness_test
# main model
MAIN_MODEL = "ibm-granite/granite-3b-code-instruct"
# speculative model
SPEC_MODEL = "ibm-granite/granite-3b-code-instruct-accelerator"
# max. number of speculative tokens: this corresponds to
# n_predict in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5
# precision
PRECISION = "float32"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
...@@ -11,9 +11,15 @@ distribution matches the target model's output distribution (up to hardware ...@@ -11,9 +11,15 @@ distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0. equality. This gives us good coverage of temp=0.
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.
For temp>0, we rely on unit tests on the rejection sampler to verify that the For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model). be prohibitively expensive to run with a real model). Similarly, for the
TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
test cases.
NOTE: Speculative decoding's distribution equality requires that the measured NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the distributions of the target model and proposal model be deterministic given the
...@@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, ...@@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size, batch_size,
max_output_len=output_len, max_output_len=output_len,
force_output_len=True) force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
}
# Try a range of common k.
for k in [1, 2, 3]
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_typical_acceptance_sampling(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
from typing import List
import pytest import pytest
import torch import torch
...@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int): ...@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int):
device='cuda', device='cuda',
) )
expected_output = [ expected_output: List[List[int]] = [
[], [],
] ]
for i in range(proposal_token_ids.shape[0]): for i in range(proposal_token_ids.shape[0]):
expected_output.append(proposal_token_ids[:i + 1].tolist()) expected_output.append(proposal_token_ids[:i + 1].tolist())
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000) scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
actual_output = scorer._get_token_ids_to_score(proposal_token_ids) # pylint: disable=protected-access actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
actual_output = [ actual_output = [
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
...@@ -88,10 +90,10 @@ def test_create_single_target_seq_group_metadata(k: int): ...@@ -88,10 +90,10 @@ def test_create_single_target_seq_group_metadata(k: int):
assert output.request_id == input_seq_group_metadata.request_id assert output.request_id == input_seq_group_metadata.request_id
assert len(output.seq_data) == 1 assert len(output.seq_data) == 1
assert output.seq_data[target_seq_id].get_prompt_token_ids( assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
) == prompt_tokens prompt_tokens)
assert output.seq_data[target_seq_id].get_output_token_ids( assert output.seq_data[target_seq_id].get_output_token_ids() == tuple(
) == prev_output_tokens + token_ids prev_output_tokens + token_ids)
assert len(output.block_tables) == 1 assert len(output.block_tables) == 1
assert output.block_tables[ assert output.block_tables[
......
...@@ -3,33 +3,35 @@ from unittest.mock import MagicMock, patch ...@@ -3,33 +3,35 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, mock_worker from .utils import create_batch, mock_worker
@pytest.mark.parametrize('queue_size', [4]) @pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1]) @pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
acceptance_sampler_method: str):
"""Verify that speculative tokens are disabled when the batch size """Verify that speculative tokens are disabled when the batch size
exceeds the threshold. exceeds the threshold.
""" """
disable_by_batch_size = 3 disable_by_batch_size = 3
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker, worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker, scorer_worker=target_worker,
rejection_sampler=rejection_sampler, spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
metrics_collector=metrics_collector, metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size) disable_by_batch_size=disable_by_batch_size)
...@@ -68,14 +70,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): ...@@ -68,14 +70,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
if queue_size < disable_by_batch_size: if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model. # Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest( proposer.get_spec_proposals(
seq_group_metadata_list=seq_group_metadata_list, execute_model_req=ExecuteModelRequest(
num_lookahead_slots=k), ) seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
else: else:
# Should not execute the draft model because spec decode is disabled # Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0. # for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), ) num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert proposals.proposal_lens.tolist() == [0] * batch_size assert proposals.proposal_lens.tolist() == [0] * batch_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment