Commit e7c1b7f3 authored by zhuwenwen
Browse files

Merge branch 'v0.5.4-dtk24.04.1'

parents 7462218e 04c62b93
import pytest
from .conftest import run_equality_correctness_test
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Eager execution: no cuda graph recording, so the test stays fast.
        "enforce_eager": True,

        # Spec decode requires the v2 block manager.
        "use_v2_block_manager": True,

        # Model used to produce the speculative tokens.
        "speculative_model": "JackFram/llama-160m",

        # How many tokens are speculated.
        "num_speculative_tokens": 3,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("batch_size", [1, 8, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output length keeps the test fast.
        20,
    ])
@pytest.mark.parametrize("seed", [None])
def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
                            batch_size: int, temperature: float,
                            output_len: int):
    """Check that seeded requests give matching outputs across two LLMs.

    The baseline and test generators are configured with different
    engine-level seeds (1 and 5). The equality helper is expected to pass
    when called with ``seeded=True`` and to raise ``AssertionError`` when
    called with ``seeded=False``.
    """

    def compare_outputs(seeded: bool) -> None:
        # Both comparisons share every argument except ``seeded``.
        run_equality_correctness_test(baseline_llm_generator,
                                      test_llm_generator,
                                      batch_size,
                                      max_output_len=output_len,
                                      temperature=temperature,
                                      seeded=seeded,
                                      force_output_len=True)

    compare_outputs(seeded=True)

    # Sanity check: the very same comparison must fail once per-request
    # seeds are left out.
    with pytest.raises(AssertionError):
        compare_outputs(seeded=False)
from typing import List
import pytest import pytest
import torch import torch
...@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int): ...@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int):
device='cuda', device='cuda',
) )
expected_output = [ expected_output: List[List[int]] = [
[], [],
] ]
for i in range(proposal_token_ids.shape[0]): for i in range(proposal_token_ids.shape[0]):
expected_output.append(proposal_token_ids[:i + 1].tolist()) expected_output.append(proposal_token_ids[:i + 1].tolist())
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000) scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
actual_output = scorer._get_token_ids_to_score(proposal_token_ids) # pylint: disable=protected-access actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
actual_output = [ actual_output = [
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
...@@ -84,14 +86,15 @@ def test_create_single_target_seq_group_metadata(k: int): ...@@ -84,14 +86,15 @@ def test_create_single_target_seq_group_metadata(k: int):
input_seq_id, input_seq_id,
target_seq_id, target_seq_id,
token_ids, token_ids,
input_seq_group_metadata.sampling_params,
) )
assert output.request_id == input_seq_group_metadata.request_id assert output.request_id == input_seq_group_metadata.request_id
assert len(output.seq_data) == 1 assert len(output.seq_data) == 1
assert output.seq_data[target_seq_id].get_prompt_token_ids( assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
) == prompt_tokens prompt_tokens)
assert output.seq_data[target_seq_id].get_output_token_ids( assert output.seq_data[target_seq_id].get_output_token_ids() == tuple(
) == prev_output_tokens + token_ids prev_output_tokens + token_ids)
assert len(output.block_tables) == 1 assert len(output.block_tables) == 1
assert output.block_tables[ assert output.block_tables[
......
...@@ -3,33 +3,36 @@ from unittest.mock import MagicMock, patch ...@@ -3,33 +3,36 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, mock_worker from .utils import create_batch, mock_worker
@pytest.mark.parametrize('queue_size', [4]) @pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1]) @pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
acceptance_sampler_method: str):
"""Verify that speculative tokens are disabled when the batch size """Verify that speculative tokens are disabled when the batch size
exceeds the threshold. exceeds the threshold.
""" """
disable_by_batch_size = 3 disable_by_batch_size = 3
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker, worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker, scorer_worker=target_worker,
rejection_sampler=rejection_sampler, spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector, metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size) disable_by_batch_size=disable_by_batch_size)
...@@ -68,14 +71,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): ...@@ -68,14 +71,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
if queue_size < disable_by_batch_size: if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model. # Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest( proposer.get_spec_proposals(
seq_group_metadata_list=seq_group_metadata_list, execute_model_req=ExecuteModelRequest(
num_lookahead_slots=k), ) seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
else: else:
# Should not execute the draft model because spec decode is disabled # Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0. # for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), ) num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert proposals.proposal_lens.tolist() == [0] * batch_size assert proposals.proposal_lens.tolist() == [0] * batch_size
...@@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector ...@@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
def test_initial_call_returns_none(): def test_initial_call_returns_none():
"""Expect first call to get metrics to return None. """Expect first call to get metrics to return None.
""" """
rej_sampler = MagicMock() spec_decode_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0, spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0, spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_draft_tokens = 0 spec_decode_sampler.num_draft_tokens = 0
collector = AsyncMetricsCollector(rej_sampler) collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5) maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert maybe_metrics is None assert maybe_metrics is None
...@@ -28,14 +28,14 @@ def test_initial_call_returns_none(): ...@@ -28,14 +28,14 @@ def test_initial_call_returns_none():
def test_second_call_returns_metrics(): def test_second_call_returns_metrics():
"""Expect second call to not return None. """Expect second call to not return None.
""" """
rej_sampler = MagicMock() spec_decode_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0, spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0, spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_draft_tokens = 0 spec_decode_sampler.num_draft_tokens = 0
collect_interval_s = 5.0 collect_interval_s = 5.0
timer = MagicMock() timer = MagicMock()
...@@ -43,7 +43,7 @@ def test_second_call_returns_metrics(): ...@@ -43,7 +43,7 @@ def test_second_call_returns_metrics():
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
] ]
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer, timer=timer,
collect_interval_s=collect_interval_s) collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
...@@ -56,16 +56,16 @@ def test_second_call_returns_metrics(): ...@@ -56,16 +56,16 @@ def test_second_call_returns_metrics():
def test_nonzero_rank_noop(rank): def test_nonzero_rank_noop(rank):
"""Verify nonzero ranks don't collect metrics. """Verify nonzero ranks don't collect metrics.
""" """
rej_sampler = MagicMock() spec_decode_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0, spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0, spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_draft_tokens = 0 spec_decode_sampler.num_draft_tokens = 0
collector = AsyncMetricsCollector(rej_sampler) collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=rank) collector.init_gpu_tensors(rank=rank)
_ = collector.maybe_collect_rejsample_metrics(k=5) _ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5) metrics = collector.maybe_collect_rejsample_metrics(k=5)
...@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank): ...@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank):
def test_noop_until_time(): def test_noop_until_time():
"""Verify metrics aren't collected until enough time passes. """Verify metrics aren't collected until enough time passes.
""" """
rej_sampler = MagicMock() spec_decode_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0, spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0, spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_draft_tokens = 0 spec_decode_sampler.num_draft_tokens = 0
collect_interval_s = 5.0 collect_interval_s = 5.0
timer = MagicMock() timer = MagicMock()
...@@ -91,7 +91,7 @@ def test_noop_until_time(): ...@@ -91,7 +91,7 @@ def test_noop_until_time():
collect_interval_s + 0.1, collect_interval_s + 0.1 collect_interval_s + 0.1, collect_interval_s + 0.1
] ]
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer, timer=timer,
collect_interval_s=collect_interval_s) collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
...@@ -105,6 +105,49 @@ def test_noop_until_time(): ...@@ -105,6 +105,49 @@ def test_noop_until_time():
assert metrics is not None assert metrics is not None
def test_timer_is_reset():
    """Verify that AsyncMetricsCollector resets its internal timer
    after a collection.

    The collector is driven by a scripted clock and polled in three rounds
    of two calls each. The second call of each round must yield metrics,
    then no metrics, then metrics again.
    """
    collect_interval_s = 5.0

    def zero_counter() -> torch.Tensor:
        # Scalar int64 counter on the GPU, initialised to zero.
        return torch.tensor(0, dtype=torch.long, device='cuda')

    spec_decode_sampler = MagicMock()
    spec_decode_sampler.num_accepted_tokens = zero_counter()
    spec_decode_sampler.num_emitted_tokens = zero_counter()
    spec_decode_sampler.num_draft_tokens = 0

    # Scripted clock: every timer() call returns the next entry.
    just_past_first_interval = collect_interval_s + 0.1
    shortly_after = collect_interval_s + 0.2
    just_past_second_interval = 2 * collect_interval_s + 0.1
    timer = MagicMock(side_effect=[
        0.0,
        just_past_first_interval,
        just_past_first_interval,
        shortly_after,
        shortly_after,
        just_past_second_interval,
        just_past_second_interval,
    ])

    collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)

    # Each round makes two calls and inspects only the second result.
    for metrics_expected in (True, False, True):
        _ = collector.maybe_collect_rejsample_metrics(k=5)
        metrics = collector.maybe_collect_rejsample_metrics(k=5)
        assert (metrics is not None) == metrics_expected
@pytest.mark.parametrize("has_data", [True, False]) @pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool): def test_initial_metrics_has_correct_values(has_data: bool):
"""Test correctness of metrics data. """Test correctness of metrics data.
...@@ -122,14 +165,14 @@ def test_initial_metrics_has_correct_values(has_data: bool): ...@@ -122,14 +165,14 @@ def test_initial_metrics_has_correct_values(has_data: bool):
max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens( max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
num_draft_tokens, k) num_draft_tokens, k)
rej_sampler = MagicMock() spec_decode_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens, spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens, spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_draft_tokens = num_draft_tokens spec_decode_sampler.num_draft_tokens = num_draft_tokens
collect_interval_s = 5.0 collect_interval_s = 5.0
timer = MagicMock() timer = MagicMock()
...@@ -137,7 +180,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): ...@@ -137,7 +180,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
] ]
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer, timer=timer,
collect_interval_s=collect_interval_s) collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
......
import random import random
from typing import Dict, List
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -84,6 +86,7 @@ def test_same_output_for_single_step(): ...@@ -84,6 +86,7 @@ def test_same_output_for_single_step():
block_size, block_size,
num_gpu_blocks, num_gpu_blocks,
seed, seed,
model_runner_cls=TP1DraftModelRunner,
) )
worker = create_worker( worker = create_worker(
Worker, Worker,
...@@ -115,7 +118,8 @@ def test_same_output_for_single_step(): ...@@ -115,7 +118,8 @@ def test_same_output_for_single_step():
actual_output, _ = multi_step_worker.sampler_output( actual_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=multi_step_seq_group), seq_group_metadata_list=multi_step_seq_group),
sample_len=num_steps) sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
assert len(actual_output) == num_steps assert len(actual_output) == num_steps
actual_output = actual_output[0] actual_output = actual_output[0]
...@@ -167,6 +171,7 @@ def test_same_output_for_multi_step(): ...@@ -167,6 +171,7 @@ def test_same_output_for_multi_step():
block_size, block_size,
num_gpu_blocks, num_gpu_blocks,
seed, seed,
model_runner_cls=TP1DraftModelRunner,
) )
worker = create_worker( worker = create_worker(
...@@ -206,11 +211,12 @@ def test_same_output_for_multi_step(): ...@@ -206,11 +211,12 @@ def test_same_output_for_multi_step():
multi_step_output, _ = multi_step_worker.sampler_output( multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list), seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps) sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
# Run single-step repeatedly. # Run single-step repeatedly.
zero_kv_cache(worker.cache_engine) zero_kv_cache(worker.cache_engine)
single_step_output = [] single_step_output: List[SamplerOutput] = []
continuations = [[1] for _ in prompts] continuations = [[1] for _ in prompts]
set_random_seed(seed) set_random_seed(seed)
...@@ -232,11 +238,15 @@ def test_same_output_for_multi_step(): ...@@ -232,11 +238,15 @@ def test_same_output_for_multi_step():
continuations[i].append(seq_group_output.samples[0].output_token) continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison. # Get token ids and logprobs for comparison.
multi_step_output_logprobs = [[] for _ in prompts] multi_step_output_logprobs: List[List[Dict[int,
single_step_output_logprobs = [[] for _ in prompts] Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids = [[] for _ in prompts] single_step_output_logprobs: List[List[Dict[int,
single_step_output_token_ids = [[] for _ in prompts] Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts): for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output, for multi_step, single_step in zip(multi_step_output,
single_step_output): single_step_output):
...@@ -269,6 +279,203 @@ def test_same_output_for_multi_step(): ...@@ -269,6 +279,203 @@ def test_same_output_for_multi_step():
single_step_logprobs) single_step_logprobs)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_correct_output():
    """
    Check that the MultiStepWorker handles bonus tokens correctly.

    When a sequence carries a bonus token, the MultiStepWorker is expected
    to expand the batch with extra sequences for those bonus-token sequences
    and use the expanded batch to predict the next tokens. Here every
    sequence is declared to have a bonus token, and the token predicted by
    the MultiStepWorker must equal the one produced by a plain Worker run
    step by step.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 128

    multi_step_worker = create_worker(MultiStepWorker,
                                      model_name,
                                      block_size,
                                      num_gpu_blocks,
                                      seed,
                                      model_runner_cls=TP1DraftModelRunner)
    worker = create_worker(Worker, model_name, block_size, num_gpu_blocks,
                           seed)

    random.seed(seed)
    prompts = [[0] for _ in range(batch_size)]
    num_steps = 2
    final_prompt_lens = [num_steps + 1] * len(prompts)

    # Both workers are patched with the same seed list.
    rand_seeds = [random.randint(0, 100) for _ in range(num_steps)]
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    def build_seq_group_metadata(token_continuations):
        # All metadata in this test differs only in the continuations used.
        return create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=token_continuations,
            final_prompt_lens=final_prompt_lens)

    # Create the test continuations: one random starting token per prompt.
    continuations = [[random.randint(0, 1000)] for _ in prompts]
    seq_group_metadata_list = build_seq_group_metadata(continuations)

    # Run single-step twice to generate 2 tokens. This simulates the bonus
    # token case, the second generated token being the bonus token.
    zero_kv_cache(worker.cache_engine)
    single_step_output: List[SamplerOutput] = []
    set_random_seed(seed)
    for _ in range(num_steps):
        seq_group_metadata_list = build_seq_group_metadata(continuations)
        request = ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list)
        single_step_output.extend(
            worker.execute_model(execute_model_req=request))

        # Feed the freshly sampled token of each sequence back in.
        for seq_index, seq_group_output in enumerate(single_step_output[-1]):
            sampled_token = seq_group_output.samples[0].output_token
            continuations[seq_index].append(sampled_token)

    # The MultiStepWorker sees only the first 2 tokens of each continuation,
    # which simulates the bonus token case.
    multi_step_continuations = [
        continuation[:2] for continuation in continuations
    ]
    seq_group_metadata_list = build_seq_group_metadata(
        multi_step_continuations)

    # Run multi-step and verify that the third token prediction is accurate
    # for all sequences.
    zero_kv_cache(multi_step_worker.cache_engine)
    all_seq_ids = set(range(batch_size))
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list),
        sample_len=1,
        seq_ids_with_bonus_token_in_last_step=all_seq_ids)
    for seq_index, output in enumerate(multi_step_output[-1].outputs):
        expected_token = continuations[seq_index][-1]
        assert expected_token == output.samples[0].output_token
@torch.inference_mode()
def test_multi_step_with_batch_expansion_incorrect_output():
    """
    Negative case for the MultiStepWorker's bonus-token batch expansion.

    The MultiStepWorker receives a batch in which every sequence has a bonus
    token, but only the odd sequence IDs are reported as having one. The
    test verifies that the tokens are correct for the sequences whose ID was
    reported correctly, and that at least one of the sequences whose ID was
    reported incorrectly gets a wrong token.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 128

    multi_step_worker = create_worker(MultiStepWorker,
                                      model_name,
                                      block_size,
                                      num_gpu_blocks,
                                      seed,
                                      model_runner_cls=TP1DraftModelRunner)
    worker = create_worker(Worker, model_name, block_size, num_gpu_blocks,
                           seed)

    random.seed(seed)
    prompts = [[0] for _ in range(batch_size)]
    num_steps = 2
    final_prompt_lens = [num_steps + 1] * len(prompts)

    # Both workers are patched with the same seed list.
    rand_seeds = [random.randint(0, 100) for _ in range(num_steps)]
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    def build_seq_group_metadata(token_continuations):
        # All metadata in this test differs only in the continuations used.
        return create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=token_continuations,
            final_prompt_lens=final_prompt_lens)

    # Create the test continuations: one random starting token per prompt.
    continuations = [[random.randint(0, 1000)] for _ in prompts]
    seq_group_metadata_list = build_seq_group_metadata(continuations)

    # Run single-step twice to generate 2 tokens. This simulates the bonus
    # token case, the second generated token being the bonus token.
    zero_kv_cache(worker.cache_engine)
    single_step_output: List[SamplerOutput] = []
    set_random_seed(seed)
    for _ in range(num_steps):
        seq_group_metadata_list = build_seq_group_metadata(continuations)
        request = ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list)
        single_step_output.extend(
            worker.execute_model(execute_model_req=request))

        # Feed the freshly sampled token of each sequence back in.
        for seq_index, seq_group_output in enumerate(single_step_output[-1]):
            sampled_token = seq_group_output.samples[0].output_token
            continuations[seq_index].append(sampled_token)

    # The MultiStepWorker sees only the first 2 tokens of each continuation,
    # which simulates the bonus token case.
    multi_step_continuations = [
        continuation[:2] for continuation in continuations
    ]
    seq_group_metadata_list = build_seq_group_metadata(
        multi_step_continuations)

    # Run multi-step while INCORRECTLY claiming that only the odd numbered
    # sequences have bonus tokens. The third token prediction must then be
    # accurate for the odd numbered sequences, and may be wrong for some of
    # the even numbered ones.
    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    odd_seq_ids = set(range(1, batch_size, 2))
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list),
        sample_len=1,
        seq_ids_with_bonus_token_in_last_step=odd_seq_ids)

    num_mismatch = 0
    for seq_index, output in enumerate(multi_step_output[-1].outputs):
        expected_token = continuations[seq_index][-1]
        actual_token = output.samples[0].output_token
        if seq_index % 2 != 0:
            assert expected_token == actual_token
        elif expected_token != actual_token:
            num_mismatch += 1

    # Some even numbered sequences can still be predicted accurately without
    # proper bonus-token handling, so only require at least one mismatch.
    assert num_mismatch > 0
@torch.inference_mode() @torch.inference_mode()
def test_draft_proposals_full_speculation_len(): def test_draft_proposals_full_speculation_len():
"""Verify Top1Proposer correctly handles case where all sequences """Verify Top1Proposer correctly handles case where all sequences
...@@ -310,7 +517,8 @@ def test_draft_proposals_full_speculation_len(): ...@@ -310,7 +517,8 @@ def test_draft_proposals_full_speculation_len():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), ) num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -348,7 +556,8 @@ def test_draft_proposals_no_speculations(): ...@@ -348,7 +556,8 @@ def test_draft_proposals_no_speculations():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), ) num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -420,7 +629,8 @@ def test_draft_proposals_mixed_k(): ...@@ -420,7 +629,8 @@ def test_draft_proposals_mixed_k():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), ) num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -432,3 +642,51 @@ def test_draft_proposals_mixed_k(): ...@@ -432,3 +642,51 @@ def test_draft_proposals_mixed_k():
assert proposals.proposal_lens.tolist() == [ assert proposals.proposal_lens.tolist() == [
k for _ in range(expected_num_proposal_seqs - 1) k for _ in range(expected_num_proposal_seqs - 1)
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k] ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
    """Verify that the draft model runner triggers the advance step
    when applicable.

    ``_gpu_advance_step`` is replaced by a mock that raises. A request with
    ``num_steps=1`` must complete without raising, while a request with
    ``num_steps=k`` must raise, with the mock called exactly once.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    k = 5
    batch_size = 32
    block_size = 32
    num_gpu_blocks = 2048 // block_size

    worker = create_worker(MultiStepWorker,
                           model_name,
                           block_size,
                           num_gpu_blocks,
                           seed,
                           model_runner_cls=TP1DraftModelRunner)

    # Any call to "_gpu_advance_step" raises a recognisable exception.
    exception_secret = "artificial stop"
    worker.model_runner._gpu_advance_step = MagicMock(
        side_effect=ValueError(exception_secret))

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    def make_request(num_steps: int) -> ExecuteModelRequest:
        # The two requests below differ only in their step count.
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k,
            num_steps=num_steps)

    # Fallback path: with num_steps=1 the mock must not be called.
    worker.execute_model(execute_model_req=make_request(1))

    # With num_steps=k the mock is called, so its exception is expected.
    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=make_request(k))

    advance_step_calls = worker.model_runner._gpu_advance_step.call_args_list
    assert len(advance_step_calls) == 1
...@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match(): ...@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), ) num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): ...@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), ) num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all(): ...@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), ) num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
......
import random import random
from collections import defaultdict
from types import SimpleNamespace from types import SimpleNamespace
from typing import Dict, List, Set
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector, from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics) SpecDecodeWorkerMetrics)
...@@ -15,23 +16,29 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker ...@@ -15,23 +16,29 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
split_num_cache_blocks_evenly) split_num_cache_blocks_evenly)
from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, create_sampler_output_list, mock_worker from .utils import create_batch, create_sampler_output_list, mock_worker
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int): def test_correctly_calls_draft_model(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the draft worker with correct """Verify SpecDecodeWorker calls the draft worker with correct
inputs. Everything else is mocked out. inputs. Everything else is mocked out.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) draft_worker,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop' exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
...@@ -52,15 +59,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): ...@@ -52,15 +59,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int): def test_correctly_calls_target_model(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the target model with correct """Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out. inputs. Everything else is mocked out.
""" """
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False) target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
...@@ -68,8 +76,12 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -68,8 +76,12 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) draft_worker,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
vocab_size = 32_000 vocab_size = 32_000
...@@ -103,7 +115,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -103,7 +115,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k)) num_lookahead_slots=k))
seen_contexts = [] seen_contexts: List[List[int]] = []
call_args_list = target_worker.execute_model.call_args_list call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1 assert len(call_args_list) == 1
...@@ -116,7 +128,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -116,7 +128,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
for seq_data in seq_group_metadata.seq_data.values(): for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids()) seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts = [] expected_seen_contexts: List[List[int]] = []
for prompt, prev_generated, draft_tokens in zip( for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()): prompts, prev_output_tokens, proposal_token_ids.tolist()):
...@@ -132,8 +144,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -132,8 +144,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_correctly_calls_rejection_sampler(k: int, batch_size: int): def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the rejection sampler with """Verify SpecDecodeWorker calls the rejection sampler with
correct inputs. Everything else is mocked out. correct inputs. Everything else is mocked out.
""" """
...@@ -143,16 +158,18 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -143,16 +158,18 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
vocab_size=vocab_size, vocab_size=vocab_size,
use_spec=False) use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker,
metrics_collector) target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
proposal_token_ids = torch.randint(low=0, proposal_token_ids = torch.randint(low=0,
...@@ -198,15 +215,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -198,15 +215,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
target_worker.execute_model.return_value = [target_output[0]] target_worker.execute_model.return_value = [target_output[0]]
exception_secret = 'artificial stop' exception_secret = 'artificial stop'
rejection_sampler.side_effect = ValueError(exception_secret)
spec_decode_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=ExecuteModelRequest( worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k)) num_lookahead_slots=k))
assert len(rejection_sampler.call_args_list) == 1 assert len(spec_decode_sampler.call_args_list) == 1
_, kwargs = rejection_sampler.call_args_list[0] _, kwargs = spec_decode_sampler.call_args_list[0]
actual = SimpleNamespace(**kwargs) actual = SimpleNamespace(**kwargs)
assert torch.equal(actual.bonus_token_ids, assert torch.equal(actual.bonus_token_ids,
...@@ -220,8 +238,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -220,8 +238,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int): def test_correctly_formats_output(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker formats sampler output correctly. """Verify SpecDecodeWorker formats sampler output correctly.
Everything else is mocked out. Everything else is mocked out.
""" """
...@@ -231,16 +252,17 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -231,16 +252,17 @@ def test_correctly_formats_output(k: int, batch_size: int):
vocab_size=vocab_size, vocab_size=vocab_size,
use_spec=False) use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
set_random_seed(1) set_random_seed(1)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker,
metrics_collector) target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
proposal_token_ids = torch.randint(low=0, proposal_token_ids = torch.randint(low=0,
...@@ -285,24 +307,23 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -285,24 +307,23 @@ def test_correctly_formats_output(k: int, batch_size: int):
target_worker.execute_model.return_value = [target_output[0]] target_worker.execute_model.return_value = [target_output[0]]
rejection_sampler_output = torch.randint(low=0, spec_decode_sampler_output = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, k + 1), size=(batch_size, k + 1),
dtype=torch.int64, dtype=torch.int64,
device='cuda') device='cuda')
for i in range(batch_size): for i in range(batch_size):
minimum_accepted_tokens = 1 minimum_accepted_tokens = 1
rejection_sampler_output[i][ spec_decode_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1 -random.randint(minimum_accepted_tokens, k + 1):] = -1
rejection_sampler.return_value = rejection_sampler_output spec_decode_sampler.return_value = spec_decode_sampler_output
output = worker.execute_model(execute_model_req=ExecuteModelRequest( output = worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k)) num_lookahead_slots=k))
expected_output = create_sampler_output_list( expected_output = create_sampler_output_list(
token_ids=rejection_sampler_output.transpose(0, 1), token_ids=spec_decode_sampler_output.transpose(0, 1),
probs=[None for _ in range(k + 1)], probs=[None for _ in range(k + 1)],
logprobs=[None for _ in range(k + 1)]) logprobs=[None for _ in range(k + 1)])
...@@ -310,8 +331,14 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -310,8 +331,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
next(iter(seq_group_metadata.seq_data.keys())) next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in seq_group_metadata_list for seq_group_metadata in seq_group_metadata_list
] ]
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids} actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids} seq_id: []
for seq_id in seq_ids
}
expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
for step in output: for step in output:
for seq_group in step: for seq_group in step:
...@@ -343,8 +370,11 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -343,8 +370,11 @@ def test_correctly_formats_output(k: int, batch_size: int):
@pytest.mark.parametrize('k', [1, 2]) @pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False]) @pytest.mark.parametrize('returns_metrics', [True, False])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker collects metrics. """Verify SpecDecodeWorker collects metrics.
""" """
vocab_size = 32_000 vocab_size = 32_000
...@@ -353,16 +383,18 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -353,16 +383,18 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
vocab_size=vocab_size, vocab_size=vocab_size,
use_spec=False) use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker,
metrics_collector) target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
proposal_token_ids = torch.randint(low=0, proposal_token_ids = torch.randint(low=0,
...@@ -407,17 +439,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -407,17 +439,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
target_worker.execute_model.return_value = [target_output[0]] target_worker.execute_model.return_value = [target_output[0]]
rejection_sampler_output = torch.randint(low=0, spec_decode_sampler_output = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, k + 1), size=(batch_size, k + 1),
dtype=torch.int64, dtype=torch.int64,
device='cuda') device='cuda')
for i in range(batch_size): for i in range(batch_size):
minimum_accepted_tokens = 1 minimum_accepted_tokens = 1
rejection_sampler_output[i][ spec_decode_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1 -random.randint(minimum_accepted_tokens, k + 1):] = -1
spec_decode_sampler.return_value = spec_decode_sampler_output
rejection_sampler.return_value = rejection_sampler_output
mock_rejsample_metrics = MagicMock( mock_rejsample_metrics = MagicMock(
spec=SpecDecodeWorkerMetrics) if returns_metrics else None spec=SpecDecodeWorkerMetrics) if returns_metrics else None
...@@ -438,26 +469,35 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -438,26 +469,35 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
@pytest.mark.parametrize('k', [0]) @pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int): def test_k_equals_zero(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify that the SpecDecodeWorker calls the draft and target workers """Verify that the SpecDecodeWorker calls the draft and target workers
when k is zero. This happens during prefill. when k is zero. This happens during prefill.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size, seq_group_metadata_list, _, _ = create_batch(batch_size,
k, k,
...@@ -468,9 +508,10 @@ def test_k_equals_zero(k: int, batch_size: int): ...@@ -468,9 +508,10 @@ def test_k_equals_zero(k: int, batch_size: int):
out = worker.execute_model(execute_model_req=execute_model_req) out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}" assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None" assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[ assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None" 0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req) draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req) target_worker.execute_model.assert_called_once_with(execute_model_req)
...@@ -478,27 +519,36 @@ def test_k_equals_zero(k: int, batch_size: int): ...@@ -478,27 +519,36 @@ def test_k_equals_zero(k: int, batch_size: int):
@pytest.mark.parametrize('k', [0, 5]) @pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0]) @pytest.mark.parametrize('batch_size', [0])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int): def test_empty_input_batch(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify that the SpecDecodeWorker calls the draft and target workers """Verify that the SpecDecodeWorker calls the draft and target workers
when the input batch is empty. This can happen if the engine communicates when the input batch is empty. This can happen if the engine communicates
to the workers information without scheduling a batch. to the workers information without scheduling a batch.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size, seq_group_metadata_list, _, _ = create_batch(batch_size,
k, k,
...@@ -509,28 +559,34 @@ def test_empty_input_batch(k: int, batch_size: int): ...@@ -509,28 +559,34 @@ def test_empty_input_batch(k: int, batch_size: int):
out = worker.execute_model(execute_model_req=execute_model_req) out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}" assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None" assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[ assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None" 0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req) draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req) target_worker.execute_model.assert_called_once_with(execute_model_req)
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_init_device(): def test_init_device(acceptance_sampler_method: str):
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization. well as other GPU initialization.
""" """
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False) target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector,
)
worker.init_device() worker.init_device()
draft_worker.init_device.assert_called_once() draft_worker.init_device.assert_called_once()
...@@ -538,22 +594,25 @@ def test_init_device(): ...@@ -538,22 +594,25 @@ def test_init_device():
target_worker.init_device.assert_called_once() target_worker.init_device.assert_called_once()
metrics_collector.init_gpu_tensors.assert_called_once() metrics_collector.init_gpu_tensors.assert_called_once()
rejection_sampler.init_gpu_tensors.assert_called_once() spec_decode_sampler.init_gpu_tensors.assert_called_once()
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_initialize_cache(): def test_initialize_cache(acceptance_sampler_method):
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers. workers.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(proposer_worker=draft_worker,
metrics_collector) scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
metrics_collector=metrics_collector)
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs) worker.initialize_cache(**kwargs)
...@@ -566,19 +625,20 @@ def test_initialize_cache(): ...@@ -566,19 +625,20 @@ def test_initialize_cache():
@pytest.mark.parametrize('available_cpu_blocks', [500]) @pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int, def test_determine_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int, available_cpu_blocks: int,
target_cache_block_size_bytes: int, target_cache_block_size_bytes: int,
draft_kv_size_bytes: int): draft_kv_size_bytes: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks. """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker. split the blocks between proposer and scorer worker.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.determine_num_available_blocks.return_value = ( target_worker.determine_num_available_blocks.return_value = (
...@@ -587,8 +647,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int, ...@@ -587,8 +647,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int,
target_cache_block_size_bytes) target_cache_block_size_bytes)
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
...@@ -618,3 +679,142 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int, ...@@ -618,3 +679,142 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
assert (num_blocks * target_cache_block_size_bytes) + ( assert (num_blocks * target_cache_block_size_bytes) + (
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks * num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
target_cache_block_size_bytes) target_cache_block_size_bytes)
@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
    """
    Verify that a call to _create_output_sampler_list correctly updates
    _seq_with_bonus_token_in_last_step and _request_id_seq_id_mapping.

    _seq_with_bonus_token_in_last_step is an internal data structure in
    SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
    tokens by the target model in their last forward pass. This state is
    maintained only for models relying on the KV cache, such as those using
    the MultiStepWorker.
    """
    batch_size = 10
    k = 5
    vocab_size = 10000
    num_sequences_with_bonus_tokens = 5
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
    target_worker.device = 'cuda'

    set_random_seed(1)
    draft_worker = mock_worker(cls=MultiStepWorker)
    draft_worker.device = 'cuda'
    # The sequence_ids attached to each sequence in the batch.
    # The sequence at index i has seq_id assigned_seq_ids[i]
    assigned_seq_ids = list(range(batch_size))
    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 seq_ids=assigned_seq_ids,
                                                 prev_output_token_len=10)
    target_token_logprobs = torch.rand(batch_size, (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    accepted_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, (k + 1)),
                                       dtype=torch.int64,
                                       device='cuda')
    # Build the request_id -> {seq_id, ...} mapping the worker is expected to
    # record for this batch.
    expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for seq_group_metadata in seq_group_metadata_list:
        for seq_id in seq_group_metadata.seq_data:
            expected_request_id_seq_ids_mapping[
                seq_group_metadata.request_id].add(seq_id)
    # Generate a random sample of sequence indexes with bonus tokens
    seq_indexes_with_bonus_tokens = random.sample(
        range(batch_size), num_sequences_with_bonus_tokens)
    # Create a mask that is True for every index that is NOT in
    # seq_indexes_with_bonus_tokens (i.e. the sequences without a bonus
    # token).
    mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
    mask[seq_indexes_with_bonus_tokens] = False
    # Set the last token ID to -1 for all indices not in
    # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
    # those indices.
    accepted_token_ids[mask, -1:] = -1
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
    # This set includes all sequence IDs in the batch as well as an additional
    # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
    # the range [0, batch_size + num_extra_sequence_ids).
    num_extra_sequence_ids = 10
    worker._seq_with_bonus_token_in_last_step = set(
        range(batch_size + num_extra_sequence_ids))
    worker._create_output_sampler_list(
        seq_group_metadata_list=seq_group_metadata_list,
        accepted_token_ids=accepted_token_ids,
        target_logprobs=target_token_logprobs,
        k=k,
        stage_times=(0, 0, 0))
    # Verify that _seq_with_bonus_token_in_last_step contains the following:
    #    1. Sequence IDs that were already present in
    #       _seq_with_bonus_token_in_last_step but were not part of the current
    #       batch are retained.
    #    2. Of the sequence IDs present in the current batch, only those with a
    #       bonus token are retained in _seq_with_bonus_token_in_last_step.
    #       Sequence IDs that are present in the current batch but do not have
    #       bonus tokens are removed from _seq_with_bonus_token_in_last_step.
    expected_seq_ids_with_bonus_tokens = {
        assigned_seq_ids[i]
        for i in seq_indexes_with_bonus_tokens
    }
    additional_sequence_ids = set(
        range(batch_size, batch_size + num_extra_sequence_ids))
    assert worker._seq_with_bonus_token_in_last_step == (
        expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids))
    assert worker._request_id_seq_id_mapping == (
        expected_request_id_seq_ids_mapping)
@torch.inference_mode()
def test_handle_finished_requests():
    """
    Test to verify that finished request IDs are appropriately processed to
    update the internal state of the SpecDecodeWorker.

    This test initializes the SpecDecodeWorker with mock data, marks certain
    requests as finished, and ensures that the corresponding sequence IDs are
    correctly removed from the internal mappings.

    The draft worker is rigged to raise from get_spec_proposals, so
    execute_model is expected to abort with that error; the assertions then
    inspect the bookkeeping the worker left behind.
    """
    batch_size = 32
    k = 3
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    # Pass the trailing arguments by keyword, exactly like the other
    # SpecDecodeWorker constructions in this file, so the mock collector
    # cannot be bound to a different positional parameter by accident.
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    # Initialize the request_id_seq_id_mapping mapping dict with a few fake
    # request ids and corresponding sequence ids.
    worker._request_id_seq_id_mapping = \
        {'request-1': {1,2,3}, 'request-2': {4,5,6,7},
         'request-3': {8,9}, 'request-4': {10,11}}
    # Initialize seq_with_bonus_token_in_last_step with a few fake
    # sequence ids.
    worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
    # Make the proposal step fail with a recognizable error so the test can
    # stop execute_model at a known point.
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    # Mark requests with ids request-1 and request-3 as finished.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        finished_requests_ids=['request-1', 'request-3'])

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)
    # Verify that request-1 and request-3 are removed from
    # request_id_seq_id_mapping
    assert worker._request_id_seq_id_mapping == \
        {'request-2': {4,5,6,7}, 'request-4': {10,11}}
    # Verify that all sequence ids corresponding to 'request-1'
    # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
    assert worker._seq_with_bonus_token_in_last_step == \
        {4,5,10}
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import torch
from vllm.sequence import SequenceGroupMetadata from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.spec_decode.util import get_all_seq_ids, split_batch_by_proposal_len from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import split_batch_by_proposal_len
def test_get_all_seq_ids(): def test_get_all_seq_ids():
...@@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): ...@@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
assert filtered_groups == [] assert filtered_groups == []
assert indices == [] assert indices == []
def mock_spec_decode_sampler(acceptance_sampler_method):
    """
    Build a MagicMock standing in for a speculative-decoding sampler.

    'rejection_sampler' yields a mock spec'd on RejectionSampler and
    'typical_acceptance_sampler' one spec'd on TypicalAcceptanceSampler.
    In both cases the mock's token_id_dtype is set to torch.int64.
    Any other name raises ValueError.
    """
    known_methods = ("rejection_sampler", "typical_acceptance_sampler")
    if acceptance_sampler_method not in known_methods:
        raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")

    if acceptance_sampler_method == "rejection_sampler":
        sampler_cls = RejectionSampler
    else:
        sampler_cls = TypicalAcceptanceSampler

    mocked_sampler = MagicMock(spec=sampler_cls)
    mocked_sampler.token_id_dtype = torch.int64
    return mocked_sampler
from itertools import count from itertools import count
from typing import Dict, Iterable, List, Optional, Union from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import TypeVar, Union
from unittest.mock import MagicMock from unittest.mock import MagicMock
import torch import torch
...@@ -12,8 +14,11 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, ...@@ -12,8 +14,11 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput) SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
T = TypeVar("T", bound=Worker)
def round_up_to_next_block(seq_len: int, block_size: int) -> int: def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size return (seq_len + block_size - 1) // block_size
...@@ -49,20 +54,21 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]): ...@@ -49,20 +54,21 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
return new_execute_model return new_execute_model
def zero_kv_cache(cache_engine: CacheEngine): def zero_kv_cache(cache_engine: List[CacheEngine]):
assert cache_engine.gpu_cache assert cache_engine[0].gpu_cache
for key_blocks, value_blocks in cache_engine.gpu_cache: for key_blocks, value_blocks in cache_engine[0].gpu_cache:
key_blocks.zero_() key_blocks.zero_()
value_blocks.zero_() value_blocks.zero_()
def create_worker(cls: type, def create_worker(cls: Callable[..., T],
model_name: str, model_name: str,
block_size: int, block_size: int,
num_gpu_blocks: int, num_gpu_blocks: int,
seed: int, seed: int,
is_driver_worker: bool = True, is_driver_worker: bool = True,
enforce_eager: bool = True): enforce_eager: bool = True,
model_runner_cls: Optional[ModelRunner] = None) -> T:
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
seed=seed, seed=seed,
...@@ -85,6 +91,7 @@ def create_worker(cls: type, ...@@ -85,6 +91,7 @@ def create_worker(cls: type,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
) )
worker.init_device() worker.init_device()
...@@ -159,8 +166,8 @@ def assert_logprobs_dict_allclose( ...@@ -159,8 +166,8 @@ def assert_logprobs_dict_allclose(
def create_sampler_output_list( def create_sampler_output_list(
token_ids: torch.Tensor, token_ids: torch.Tensor,
probs: Iterable[Optional[torch.Tensor]], probs: GenericSequence[Optional[torch.Tensor]],
logprobs: Iterable[Optional[torch.Tensor]], logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
num_steps, batch_size = token_ids.shape num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist() token_ids_by_step = token_ids.tolist()
......
import contextlib
import functools
import gc
import pytest
import ray
import torch
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@pytest.fixture(autouse=True)
def cleanup():
    """Tear down distributed state and free memory around every test.

    Declared autouse, so it applies to each test without being requested.
    The body contains no yield, so all of it runs when the fixture is set up.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # Best effort: an AssertionError from this call is deliberately
        # ignored so the remaining cleanup still runs.
        pass
    ray.shutdown()
    gc.collect()
    torch.cuda.empty_cache()
def retry_until_skip(n):
    """Decorator factory: retry a test up to ``n`` times, then skip it.

    The wrapped callable is invoked until it returns without raising
    AssertionError, and that return value is passed through. After every
    failed attempt garbage is collected and the CUDA cache is emptied. When
    the ``n``-th attempt also fails, the test is skipped via ``pytest.skip``
    rather than reported as a failure. Exceptions other than AssertionError
    propagate immediately.
    """

    def decorator_retry(func):

        @functools.wraps(func)
        def wrapper_retry(*args, **kwargs):
            for attempt in range(1, n + 1):
                try:
                    outcome = func(*args, **kwargs)
                except AssertionError:
                    # Release memory before the next attempt (or the skip).
                    gc.collect()
                    torch.cuda.empty_cache()
                    if attempt == n:
                        pytest.skip("Skipping test after attempts..")
                else:
                    return outcome

        return wrapper_retry

    return decorator_retry
@pytest.fixture(autouse=True)
def tensorizer_config():
    """Return a TensorizerConfig whose tensorizer_uri is set to "vllm"."""
    return TensorizerConfig(tensorizer_uri="vllm")
import gc
import json import json
import os import os
import pathlib import pathlib
...@@ -6,7 +7,6 @@ from unittest.mock import MagicMock, patch ...@@ -6,7 +7,6 @@ from unittest.mock import MagicMock, patch
import openai import openai
import pytest import pytest
import ray
import torch import torch
from tensorizer import EncryptionParams from tensorizer import EncryptionParams
...@@ -21,13 +21,13 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, ...@@ -21,13 +21,13 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
serialize_vllm_model, serialize_vllm_model,
tensorize_vllm_model) tensorize_vllm_model)
from ..conftest import VllmRunner, cleanup from ..conftest import VllmRunner
from ..utils import RemoteOpenAIServer from ..utils import RemoteOpenAIServer
from .conftest import retry_until_skip
# yapf conflicts with isort for this docstring # yapf conflicts with isort for this docstring
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
...@@ -49,14 +49,16 @@ def is_curl_installed(): ...@@ -49,14 +49,16 @@ def is_curl_installed():
except (subprocess.CalledProcessError, FileNotFoundError): except (subprocess.CalledProcessError, FileNotFoundError):
return False return False
def get_torch_model(vllm_runner: VllmRunner): def get_torch_model(vllm_runner: VllmRunner):
return vllm_runner \ return vllm_runner \
.model \ .model \
.llm_engine \ .llm_engine \
.model_executor \ .model_executor \
.driver_worker \ .driver_worker \
.model_runner \ .model_runner \
.model .model
def write_keyfile(keyfile_path: str): def write_keyfile(keyfile_path: str):
encryption_params = EncryptionParams.random() encryption_params = EncryptionParams.random()
...@@ -64,11 +66,6 @@ def write_keyfile(keyfile_path: str): ...@@ -64,11 +66,6 @@ def write_keyfile(keyfile_path: str):
with open(keyfile_path, 'wb') as f: with open(keyfile_path, 'wb') as f:
f.write(encryption_params.key) f.write(encryption_params.key)
@pytest.fixture(autouse=True)
def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm")
return config
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent') @patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config): def test_load_with_tensorizer(mock_agent, tensorizer_config):
...@@ -91,14 +88,15 @@ def test_can_deserialize_s3(vllm_runner): ...@@ -91,14 +88,15 @@ def test_can_deserialize_s3(vllm_runner):
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig( model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path, tensorizer_uri=tensorized_path,
num_readers=1, num_readers=1,
s3_endpoint="object.ord1.coreweave.com", s3_endpoint="object.ord1.coreweave.com",
)) as loaded_hf_model: )) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate(prompts,
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) # noqa: E501 sampling_params)
# noqa: E501
assert deserialized_outputs assert deserialized_outputs
...@@ -118,18 +116,19 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( ...@@ -118,18 +116,19 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
encryption_keyfile=key_path encryption_keyfile=key_path
) )
serialize_vllm_model(get_torch_model(vllm_model), serialize_vllm_model(get_torch_model(vllm_model),
config_for_serializing) config_for_serializing)
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path) encryption_keyfile=key_path)
with vllm_runner( with vllm_runner(
model_ref, model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501 model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 deserialized_outputs = loaded_vllm_model.generate(prompts,
sampling_params)
# noqa: E501
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
...@@ -145,12 +144,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, ...@@ -145,12 +144,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
serializer.write_module(hf_model.model) serializer.write_module(hf_model.model)
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig( model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path, tensorizer_uri=model_path,
num_readers=1, num_readers=1,
)) as loaded_hf_model: )) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate_greedy( deserialized_outputs = loaded_hf_model.generate_greedy(
prompts, max_tokens=max_tokens) prompts, max_tokens=max_tokens)
...@@ -172,21 +170,21 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): ...@@ -172,21 +170,21 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(get_torch_model(vllm_model), serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path)) TensorizerConfig(tensorizer_uri=model_path))
with vllm_runner( with vllm_runner(
model_ref, model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig( model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path, tensorizer_uri=model_path,
num_readers=1, num_readers=1,
), ),
enable_lora=True, enable_lora=True,
max_loras=1, max_loras=1,
max_lora_rank=8, max_lora_rank=8,
max_cpu_loras=2, max_cpu_loras=2,
max_num_seqs=50, max_num_seqs=50,
max_model_len=1000, max_model_len=1000,
) as loaded_vllm_model: ) as loaded_vllm_model:
process_requests(loaded_vllm_model.model.llm_engine, test_prompts) process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
...@@ -194,10 +192,14 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): ...@@ -194,10 +192,14 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
def test_load_without_tensorizer_load_format(vllm_runner): def test_load_without_tensorizer_load_format(vllm_runner):
model = None
with pytest.raises(ValueError): with pytest.raises(ValueError):
vllm_runner( model = vllm_runner(
model_ref, model_ref,
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
del model
gc.collect()
torch.cuda.empty_cache()
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
...@@ -207,7 +209,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ...@@ -207,7 +209,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(get_torch_model(vllm_model), serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path)) TensorizerConfig(tensorizer_uri=model_path))
model_loader_extra_config = { model_loader_extra_config = {
"tensorizer_uri": str(model_path), "tensorizer_uri": str(model_path),
...@@ -215,34 +217,38 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ...@@ -215,34 +217,38 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
## Start OpenAI API server ## Start OpenAI API server
openai_args = [ openai_args = [
"--model", model_ref, "--dtype", "float16", "--load-format", "--dtype", "float16", "--load-format",
"tensorizer", "--model-loader-extra-config", "tensorizer", "--model-loader-extra-config",
json.dumps(model_loader_extra_config), json.dumps(model_loader_extra_config),
] ]
server = RemoteOpenAIServer(openai_args) with RemoteOpenAIServer(model_ref, openai_args) as server:
print("Server ready.") print("Server ready.")
client = server.get_client() client = server.get_client()
completion = client.completions.create(model=model_ref, completion = client.completions.create(model=model_ref,
prompt="Hello, my name is", prompt="Hello, my name is",
max_tokens=5, max_tokens=5,
temperature=0.0) temperature=0.0)
assert completion.id is not None assert completion.id is not None
assert len(completion.choices) == 1 assert len(completion.choices) == 1
assert len(completion.choices[0].text) >= 5 assert len(completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length" assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage( assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11) completion_tokens=5, prompt_tokens=6, total_tokens=11)
def test_raise_value_error_on_invalid_load_format(vllm_runner): def test_raise_value_error_on_invalid_load_format(vllm_runner):
model = None
with pytest.raises(ValueError): with pytest.raises(ValueError):
vllm_runner( model = vllm_runner(
model_ref, model_ref,
load_format="safetensors", load_format="safetensors",
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
del model
gc.collect()
torch.cuda.empty_cache()
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
...@@ -264,23 +270,20 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner): ...@@ -264,23 +270,20 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
) )
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires 2 GPUs") reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner, def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tmp_path): tmp_path):
model_ref = "EleutherAI/pythia-1.4b" model_ref = "EleutherAI/pythia-1.4b"
# record outputs from un-sharded un-tensorized model # record outputs from un-sharded un-tensorized model
base_model = vllm_runner( with vllm_runner(
model_ref, model_ref,
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
enforce_eager=True, enforce_eager=True,
) ) as base_model:
outputs = base_model.generate(prompts, sampling_params) outputs = base_model.generate(prompts, sampling_params)
base_model.model.llm_engine.model_executor.shutdown()
base_model.model.llm_engine.model_executor.shutdown()
del base_model
cleanup()
ray.shutdown()
# load model with two shards and serialize with encryption # load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors")) model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
...@@ -293,32 +296,34 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner, ...@@ -293,32 +296,34 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tensorize_vllm_model( tensorize_vllm_model(
engine_args=EngineArgs( engine_args=EngineArgs(
model=model_ref, model=model_ref,
tensor_parallel_size=2, tensor_parallel_size=2,
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
enforce_eager=True, enforce_eager=True,
), ),
tensorizer_config=tensorizer_config, tensorizer_config=tensorizer_config,
) )
assert os.path.isfile(model_path % 0), "Serialization subprocess failed" assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
assert os.path.isfile(model_path % 1), "Serialization subprocess failed" assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
cleanup()
ray.shutdown()
loaded_vllm_model = vllm_runner(
model_ref,
tensor_parallel_size=2,
load_format="tensorizer",
disable_custom_all_reduce=True,
enforce_eager=True,
model_loader_extra_config=tensorizer_config)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) with vllm_runner(
model_ref,
tensor_parallel_size=2,
load_format="tensorizer",
disable_custom_all_reduce=True,
enforce_eager=True,
model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
deserialized_outputs = loaded_vllm_model.generate(prompts,
sampling_params)
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
@retry_until_skip(3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
gc.collect()
torch.cuda.empty_cache()
model_ref = "facebook/opt-125m" model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path)) config = TensorizerConfig(tensorizer_uri=str(model_path))
...@@ -330,8 +335,10 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): ...@@ -330,8 +335,10 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
assert is_vllm_tensorized(config) assert is_vllm_tensorized(config)
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=config) as loaded_vllm_model: model_loader_extra_config=config) as loaded_vllm_model:
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 deserialized_outputs = loaded_vllm_model.generate(prompts,
sampling_params)
# noqa: E501
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
...@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, ...@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
max_input_length=None, max_input_length=None,
) )
hashes = [] hashes: List[List[List[int]]] = []
for prefix in prefixes: for prefix in prefixes:
for lora_int_id in concurrent_lora_int_ids: for lora_int_id in concurrent_lora_int_ids:
......
...@@ -104,8 +104,10 @@ def test_rope_customization(): ...@@ -104,8 +104,10 @@ def test_rope_customization():
dtype="float16", dtype="float16",
seed=0, seed=0,
) )
assert getattr(longchat_model_config.hf_config, "rope_scaling", # Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config
None) == LONGCHAT_ROPE_SCALING assert all(
longchat_model_config.hf_config.rope_scaling.get(key) == value
for key, value in LONGCHAT_ROPE_SCALING.items())
assert longchat_model_config.max_model_len == 16384 assert longchat_model_config.max_model_len == 16384
longchat_model_config = ModelConfig( longchat_model_config = ModelConfig(
......
import vllm
def test_embedded_commit_defined():
    """The package's __commit__ must be a real hash, not the placeholder."""
    embedded_commit = vllm.__commit__
    assert embedded_commit != "COMMIT_HASH_PLACEHOLDER"
    # 7 characters is the length of a short commit hash
    assert len(embedded_commit) >= 7
...@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration(): ...@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration():
assert not logger.propagate assert not logger.propagate
handler = logger.handlers[0] handler = logger.handlers[0]
assert isinstance(handler, logging.StreamHandler)
assert handler.stream == sys.stdout assert handler.stream == sys.stdout
assert handler.level == logging.INFO assert handler.level == logging.INFO
......
...@@ -83,7 +83,7 @@ def test_logits_processors(seed: int, device: str): ...@@ -83,7 +83,7 @@ def test_logits_processors(seed: int, device: str):
device=device, device=device,
pin_memory=is_pin_memory_available()) pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor( logits_processor_output = logits_processor(
embedding=None, lm_head=None,
hidden_states=input_tensor, hidden_states=input_tensor,
sampling_metadata=sampling_metadata) sampling_metadata=sampling_metadata)
......
import pytest
import torch
from vllm.scalar_type import scalar_types
@pytest.mark.parametrize("type_tuple", (
    (-8, 7, scalar_types.int4),
    (0, 15, scalar_types.uint4),
    (-8, 7, scalar_types.uint4b8),
    (-128, 127, scalar_types.uint8b128),
    (-28., 28., scalar_types.float6_e3m2f),
    (torch.int8, scalar_types.int8),
    (torch.uint8, scalar_types.uint8),
    (torch.float8_e5m2, scalar_types.float8_e5m2),
    (torch.float8_e4m3fn, scalar_types.float8_e4m3fn),
    (torch.bfloat16, scalar_types.float16_e8m7),
    (torch.float16, scalar_types.float16_e5m10),
),
                         ids=str)
def test_scalar_type_min_max(type_tuple):
    """Check that a scalar type reports the expected min() and max().

    Each case is either ``(expected_min, expected_max, scalar_type)`` or
    ``(torch_dtype, scalar_type)``; in the second form the expected bounds
    are read from ``torch.finfo`` / ``torch.iinfo`` of the torch dtype.
    """
    print(type_tuple)
    if len(type_tuple) == 3:
        expected_min, expected_max, t = type_tuple
    else:
        torch_type, t = type_tuple
        # Floating-point dtypes are described by finfo, integer ones by iinfo.
        if torch_type.is_floating_point:
            limits = torch.finfo(torch_type)
        else:
            limits = torch.iinfo(torch_type)
        expected_min = limits.min
        expected_max = limits.max

    print(t, expected_min, expected_max, t.min(), t.max())
    assert expected_min == t.min()
    assert expected_max == t.max()
...@@ -7,7 +7,8 @@ from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol, ...@@ -7,7 +7,8 @@ from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
import pytest import pytest
from vllm.utils import deprecate_kwargs, get_open_port, merge_async_iterators from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
get_open_port, merge_async_iterators)
from .utils import error_on_warning from .utils import error_on_warning
...@@ -130,3 +131,61 @@ def test_get_open_port(): ...@@ -130,3 +131,61 @@ def test_get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
s3.bind(("localhost", get_open_port())) s3.bind(("localhost", get_open_port()))
os.environ.pop("VLLM_PORT") os.environ.pop("VLLM_PORT")
# Tests for FlexibleArgumentParser
@pytest.fixture
def parser():
    """Build a FlexibleArgumentParser with one option of each tested kind.

    Registers a choices-restricted option, a plain string option, an int
    option and a store_true flag, all declared with dashed names.
    """
    arg_parser = FlexibleArgumentParser()
    image_input_choices = ['pixel_values', 'image_features']
    arg_parser.add_argument('--image-input-type', choices=image_input_choices)
    arg_parser.add_argument('--model-name')
    arg_parser.add_argument('--batch-size', type=int)
    arg_parser.add_argument('--enable-feature', action='store_true')
    return arg_parser
def test_underscore_to_dash(parser):
    """An option spelled with underscores is parsed into the dashed flag."""
    argv = ['--image_input_type', 'pixel_values']
    namespace = parser.parse_args(argv)
    assert namespace.image_input_type == 'pixel_values'
def test_mixed_usage(parser):
    """Underscore and dash spellings can be mixed on one command line."""
    argv = [
        '--image_input_type',
        'image_features',
        '--model-name',
        'facebook/opt-125m',
    ]
    namespace = parser.parse_args(argv)
    assert namespace.image_input_type == 'image_features'
    assert namespace.model_name == 'facebook/opt-125m'
def test_with_equals_sign(parser):
    """The ``--option=value`` form works for both spellings."""
    argv = [
        '--image_input_type=pixel_values',
        '--model-name=facebook/opt-125m',
    ]
    namespace = parser.parse_args(argv)
    assert namespace.image_input_type == 'pixel_values'
    assert namespace.model_name == 'facebook/opt-125m'
def test_with_int_value(parser):
    """An int-typed option is converted under either spelling of its name."""
    for flag in ('--batch_size', '--batch-size'):
        namespace = parser.parse_args([flag, '32'])
        assert namespace.batch_size == 32
def test_with_bool_flag(parser):
    """A store_true flag is set under either spelling of its name."""
    for flag in ('--enable_feature', '--enable-feature'):
        namespace = parser.parse_args([flag])
        assert namespace.enable_feature is True
def test_invalid_choice(parser):
    """A value outside the declared choices makes parse_args exit."""
    bad_argv = ['--image_input_type', 'invalid_choice']
    with pytest.raises(SystemExit):
        parser.parse_args(bad_argv)
def test_missing_required_argument(parser):
    """Omitting an option registered as required makes parse_args exit."""
    parser.add_argument('--required-arg', required=True)
    empty_argv = []
    with pytest.raises(SystemExit):
        parser.parse_args(empty_argv)
from typing import Dict, List from typing import Any, Dict, List, Optional
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -139,6 +139,15 @@ def create_dummy_logprobs( ...@@ -139,6 +139,15 @@ def create_dummy_logprobs(
} for token_id in complete_sequence_token_ids] } for token_id in complete_sequence_token_ids]
def create_dummy_prompt_logprobs(
    complete_sequence_token_ids: List[int]
) -> List[Optional[Dict[int, Any]]]:
    """Return dummy prompt logprobs whose first entry is None.

    The remaining entries are the output of create_dummy_logprobs for the
    same token ids, minus its first element.
    """
    per_token_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
    # logprob for the first prompt token is None.
    prompt_logprobs: List[Optional[Dict[int, Any]]] = [None]
    prompt_logprobs += per_token_logprobs[1:]
    return prompt_logprobs
@pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False]) @pytest.mark.parametrize("skip_special_tokens", [True, False])
...@@ -153,8 +162,8 @@ def test_decode_sequence_logprobs(complete_sequence: str, ...@@ -153,8 +162,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
# Run sequentially. # Run sequentially.
seq = create_sequence() seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token = [] sequential_logprobs_text_chosen_token: List[str] = []
sequential_logprobs_text_other_token = [] sequential_logprobs_text_other_token: List[str] = []
for new_token, logprobs in zip(complete_sequence_token_ids, for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs): dummy_logprobs):
seq.append_token_id(new_token, logprobs) seq.append_token_id(new_token, logprobs)
...@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str, ...@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True]) def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
def test_decode_prompt_logprobs(complete_sequence: str, detokenizer: Detokenizer):
complete_sequence_token_ids: List[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes prompt logprobs correctly.""" """Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, sampling_params = SamplingParams(skip_special_tokens=True,
prompt_logprobs=1) prompt_logprobs=1)
# Run sequentially. # Run sequentially.
...@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str, ...@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
seqs=[seq], seqs=[seq],
sampling_params=sampling_params, sampling_params=sampling_params,
arrival_time=0.0) arrival_time=0.0)
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs) detokenizer.decode_prompt_logprobs_inplace(seq_group,
decoded_prompt_logprobs = dummy_logprobs dummy_logprobs,
position_offset=0)
# First logprob is None.
decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
1:] # type: ignore
if skip_special_tokens: # decoded_prompt_logprobs doesn't contain the first token.
# Text for logprobs for the chosen token should be the same as the token_ids = complete_sequence_token_ids
# prompt text. Note that this will only be true if we skip tokenzier = detokenizer.get_tokenizer_for_seq(seq)
# special tokens. text_full = tokenzier.decode(token_ids, skip_special_tokens=True)
assert complete_sequence == "".join([ text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True)
logprobs[token_id].decoded_token for token_id, logprobs in zip( text = text_full[len(text_first):]
complete_sequence_token_ids, decoded_prompt_logprobs)
]) # Text for logprobs for the chosen token should be the same as the
assert complete_sequence != "".join([ # prompt text. Note that the first logprob is None.
logprobs[token_id + 1].decoded_token for token_id, logprobs in zip( assert text == "".join([
complete_sequence_token_ids, decoded_prompt_logprobs) logprobs[token_id].decoded_token
]) for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
])
assert text != "".join([
logprobs[token_id + 1].decoded_token
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
])
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
    vllm_runner,
    model,
    chunked_prefill_token_size: int,
    example_prompts,
):
    """Check that decoded prompt logprobs reassemble the original prompt.

    A ``chunked_prefill_token_size`` of -1 leaves chunked prefill disabled.
    Any other value enables it, is passed as ``max_num_batched_tokens``,
    and caps ``max_num_seqs`` (which otherwise stays at 256).
    """
    default_seq_limit = 256
    if chunked_prefill_token_size == -1:
        use_chunked_prefill = False
        token_budget = None
        seq_limit = default_seq_limit
    else:
        use_chunked_prefill = True
        token_budget = chunked_prefill_token_size
        seq_limit = min(chunked_prefill_token_size, default_seq_limit)

    with vllm_runner(model,
                     dtype="half",
                     max_logprobs=5,
                     gpu_memory_utilization=0.5,
                     enable_chunked_prefill=use_chunked_prefill,
                     max_num_batched_tokens=token_budget,
                     max_num_seqs=seq_limit) as vllm_model:
        params = SamplingParams(max_tokens=10,
                                logprobs=5,
                                prompt_logprobs=5,
                                temperature=0.0)
        outputs = vllm_model.model.generate(example_prompts,
                                            sampling_params=params)

        for prompt_idx, output in enumerate(outputs):
            assert output.prompt_logprobs is not None
            # The first prompt position carries no logprob entry.
            assert output.prompt_logprobs[0] is None

            # Each remaining entry is a dict keyed by token id. Picking the
            # entry for the actual prompt token and concatenating the
            # decoded pieces should give back the prompt text.
            token_and_logprobs = zip(output.prompt_token_ids[1:],
                                     output.prompt_logprobs[1:])
            rebuilt_prompt = "".join(
                candidates[token_id].decoded_token
                for token_id, candidates in token_and_logprobs)

            assert rebuilt_prompt == example_prompts[prompt_idx], (
                "Detokenized prompt logprobs do not match original prompt")
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
:meth:`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
def test_get_llama3_eos_token():
    """Llama-3 Instruct: compare tokenizer and generation-config EOS ids.

    The tokenizer is expected to report a single EOS id, while the
    generation config is expected to report a list of two.
    """
    repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"

    llama_tokenizer = get_tokenizer(repo_id)
    assert llama_tokenizer.eos_token_id == 128009

    gen_cfg = try_get_generation_config(repo_id, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == [128001, 128009]
def test_get_blip2_eos_token():
    """BLIP-2 OPT: compare tokenizer and generation-config EOS ids.

    The tokenizer is expected to report EOS id 2, while the generation
    config is expected to report 50118.
    """
    repo_id = "Salesforce/blip2-opt-2.7b"

    blip2_tokenizer = get_tokenizer(repo_id)
    assert blip2_tokenizer.eos_token_id == 2

    gen_cfg = try_get_generation_config(repo_id, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == 50118
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment