Commit 705f6a35 authored by zhuwenwen
Browse files

Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1

parents af837396 4cf256ae
...@@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector ...@@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
def test_initial_call_returns_none(): def test_initial_call_returns_none():
"""Expect first call to get metrics to return None. """Expect first call to get metrics to return None.
""" """
rej_sampler = MagicMock() spec_decode_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0, spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0, spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_draft_tokens = 0 spec_decode_sampler.num_draft_tokens = 0
collector = AsyncMetricsCollector(rej_sampler) collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5) maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert maybe_metrics is None assert maybe_metrics is None
...@@ -28,14 +28,14 @@ def test_initial_call_returns_none(): ...@@ -28,14 +28,14 @@ def test_initial_call_returns_none():
def test_second_call_returns_metrics(): def test_second_call_returns_metrics():
"""Expect second call to not return None. """Expect second call to not return None.
""" """
rej_sampler = MagicMock() spec_decode_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0, spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0, spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_draft_tokens = 0 spec_decode_sampler.num_draft_tokens = 0
collect_interval_s = 5.0 collect_interval_s = 5.0
timer = MagicMock() timer = MagicMock()
...@@ -43,7 +43,7 @@ def test_second_call_returns_metrics(): ...@@ -43,7 +43,7 @@ def test_second_call_returns_metrics():
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
] ]
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer, timer=timer,
collect_interval_s=collect_interval_s) collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
...@@ -56,16 +56,16 @@ def test_second_call_returns_metrics(): ...@@ -56,16 +56,16 @@ def test_second_call_returns_metrics():
def test_nonzero_rank_noop(rank): def test_nonzero_rank_noop(rank):
"""Verify nonzero ranks don't collect metrics. """Verify nonzero ranks don't collect metrics.
""" """
rej_sampler = MagicMock() spec_decode_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0, spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0, spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_draft_tokens = 0 spec_decode_sampler.num_draft_tokens = 0
collector = AsyncMetricsCollector(rej_sampler) collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=rank) collector.init_gpu_tensors(rank=rank)
_ = collector.maybe_collect_rejsample_metrics(k=5) _ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5) metrics = collector.maybe_collect_rejsample_metrics(k=5)
...@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank): ...@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank):
def test_noop_until_time(): def test_noop_until_time():
"""Verify metrics aren't collected until enough time passes. """Verify metrics aren't collected until enough time passes.
""" """
rej_sampler = MagicMock() spec_decode_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0, spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0, spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_draft_tokens = 0 spec_decode_sampler.num_draft_tokens = 0
collect_interval_s = 5.0 collect_interval_s = 5.0
timer = MagicMock() timer = MagicMock()
...@@ -91,7 +91,7 @@ def test_noop_until_time(): ...@@ -91,7 +91,7 @@ def test_noop_until_time():
collect_interval_s + 0.1, collect_interval_s + 0.1 collect_interval_s + 0.1, collect_interval_s + 0.1
] ]
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer, timer=timer,
collect_interval_s=collect_interval_s) collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
...@@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool): ...@@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool):
max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens( max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
num_draft_tokens, k) num_draft_tokens, k)
rej_sampler = MagicMock() spec_decode_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens, spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens, spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
dtype=torch.long, dtype=torch.long,
device='cuda') device='cuda')
rej_sampler.num_draft_tokens = num_draft_tokens spec_decode_sampler.num_draft_tokens = num_draft_tokens
collect_interval_s = 5.0 collect_interval_s = 5.0
timer = MagicMock() timer = MagicMock()
...@@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): ...@@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
] ]
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer, timer=timer,
collect_interval_s=collect_interval_s) collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
......
import random import random
from typing import Dict, List
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -84,6 +86,7 @@ def test_same_output_for_single_step(): ...@@ -84,6 +86,7 @@ def test_same_output_for_single_step():
block_size, block_size,
num_gpu_blocks, num_gpu_blocks,
seed, seed,
model_runner_cls=TP1DraftModelRunner,
) )
worker = create_worker( worker = create_worker(
Worker, Worker,
...@@ -115,7 +118,8 @@ def test_same_output_for_single_step(): ...@@ -115,7 +118,8 @@ def test_same_output_for_single_step():
actual_output, _ = multi_step_worker.sampler_output( actual_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=multi_step_seq_group), seq_group_metadata_list=multi_step_seq_group),
sample_len=num_steps) sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
assert len(actual_output) == num_steps assert len(actual_output) == num_steps
actual_output = actual_output[0] actual_output = actual_output[0]
...@@ -167,6 +171,7 @@ def test_same_output_for_multi_step(): ...@@ -167,6 +171,7 @@ def test_same_output_for_multi_step():
block_size, block_size,
num_gpu_blocks, num_gpu_blocks,
seed, seed,
model_runner_cls=TP1DraftModelRunner,
) )
worker = create_worker( worker = create_worker(
...@@ -206,11 +211,12 @@ def test_same_output_for_multi_step(): ...@@ -206,11 +211,12 @@ def test_same_output_for_multi_step():
multi_step_output, _ = multi_step_worker.sampler_output( multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list), seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps) sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
# Run single-step repeatedly. # Run single-step repeatedly.
zero_kv_cache(worker.cache_engine) zero_kv_cache(worker.cache_engine)
single_step_output = [] single_step_output: List[SamplerOutput] = []
continuations = [[1] for _ in prompts] continuations = [[1] for _ in prompts]
set_random_seed(seed) set_random_seed(seed)
...@@ -232,11 +238,15 @@ def test_same_output_for_multi_step(): ...@@ -232,11 +238,15 @@ def test_same_output_for_multi_step():
continuations[i].append(seq_group_output.samples[0].output_token) continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison. # Get token ids and logprobs for comparison.
multi_step_output_logprobs = [[] for _ in prompts] multi_step_output_logprobs: List[List[Dict[int,
single_step_output_logprobs = [[] for _ in prompts] Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids = [[] for _ in prompts] single_step_output_logprobs: List[List[Dict[int,
single_step_output_token_ids = [[] for _ in prompts] Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts): for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output, for multi_step, single_step in zip(multi_step_output,
single_step_output): single_step_output):
...@@ -269,6 +279,203 @@ def test_same_output_for_multi_step(): ...@@ -269,6 +279,203 @@ def test_same_output_for_multi_step():
single_step_logprobs) single_step_logprobs)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_correct_output():
    """Check that MultiStepWorker handles bonus tokens via batch expansion.

    A reference ``Worker`` is stepped ``num_steps`` times so that every
    sequence ends up with three continuation tokens. The ``MultiStepWorker``
    is then fed the first two continuation tokens of each sequence, with
    every sequence id flagged as having a bonus token in the last step, and
    asked for a single sample. That sample must equal the last token the
    reference worker produced for the same sequence.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'
    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 128

    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    random.seed(seed)
    prompts = [[0] for _ in range(batch_size)]
    num_steps = 2
    final_prompt_lens = [num_steps + 1] * len(prompts)

    # Both workers are patched with the same per-step seeds.
    rand_seeds = [random.randint(0, 100) for _ in range(num_steps)]
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    # Every sequence starts with one random continuation token.
    continuations = [[random.randint(0, 1000)] for _ in prompts]
    # Initial metadata; it is rebuilt on every reference step below.
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=continuations,
        final_prompt_lens=final_prompt_lens)

    # Step the reference worker twice to generate 2 tokens. The second
    # generated token plays the role of the bonus token.
    zero_kv_cache(worker.cache_engine)
    single_step_output: List[SamplerOutput] = []
    set_random_seed(seed)
    for _ in range(num_steps):
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_prompt_lens=final_prompt_lens)
        step_request = ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list)
        single_step_output.extend(
            worker.execute_model(execute_model_req=step_request))
        # Feed the freshly sampled token back into each continuation.
        for seq_idx, group_output in enumerate(single_step_output[-1]):
            continuations[seq_idx].append(group_output.samples[0].output_token)

    # The MultiStepWorker sees only the first 2 continuation tokens, which
    # simulates the bonus token case.
    multi_step_continuations = [tokens[:2] for tokens in continuations]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=multi_step_continuations,
        final_prompt_lens=final_prompt_lens)

    # Run multi-step with all sequences flagged and verify that the third
    # token prediction is accurate for every sequence.
    zero_kv_cache(multi_step_worker.cache_engine)
    all_seq_ids = set(range(batch_size))
    multi_step_request = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list)
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=multi_step_request,
        sample_len=1,
        seq_ids_with_bonus_token_in_last_step=all_seq_ids)
    for index, output in enumerate(multi_step_output[-1].outputs):
        assert continuations[index][-1] == output.samples[0].output_token
@torch.inference_mode()
def test_multi_step_with_batch_expansion_incorrect_output():
    """Negative case for MultiStepWorker batch expansion with bonus tokens.

    The batch is built exactly as in the positive case (every sequence
    carries a bonus token), but only the odd-numbered sequence ids are
    reported as having one. The test asserts that the sampled token matches
    the reference worker for every odd-numbered sequence, and that at least
    one even-numbered sequence produces a different token.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'
    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 128

    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    random.seed(seed)
    prompts = [[0] for _ in range(batch_size)]
    num_steps = 2
    final_prompt_lens = [num_steps + 1] * len(prompts)

    # Both workers are patched with the same per-step seeds.
    rand_seeds = [random.randint(0, 100) for _ in range(num_steps)]
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    # Every sequence starts with one random continuation token.
    continuations = [[random.randint(0, 1000)] for _ in prompts]
    # Initial metadata; it is rebuilt on every reference step below.
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=continuations,
        final_prompt_lens=final_prompt_lens)

    # Step the reference worker twice to generate 2 tokens. The second
    # generated token plays the role of the bonus token.
    zero_kv_cache(worker.cache_engine)
    single_step_output: List[SamplerOutput] = []
    set_random_seed(seed)
    for _ in range(num_steps):
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_prompt_lens=final_prompt_lens)
        step_request = ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list)
        single_step_output.extend(
            worker.execute_model(execute_model_req=step_request))
        # Feed the freshly sampled token back into each continuation.
        for seq_idx, group_output in enumerate(single_step_output[-1]):
            continuations[seq_idx].append(group_output.samples[0].output_token)

    # The MultiStepWorker sees only the first 2 continuation tokens, which
    # simulates the bonus token case.
    multi_step_continuations = [tokens[:2] for tokens in continuations]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=multi_step_continuations,
        final_prompt_lens=final_prompt_lens)

    # Run multi-step while INCORRECTLY declaring that only the odd-numbered
    # sequences have bonus tokens. The third token prediction must then be
    # accurate for the odd-numbered sequences, and may be wrong for some of
    # the even-numbered ones.
    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    odd_seq_ids = {seq_id for seq_id in range(batch_size) if seq_id % 2 != 0}
    multi_step_request = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list)
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=multi_step_request,
        sample_len=1,
        seq_ids_with_bonus_token_in_last_step=odd_seq_ids)

    num_mismatch = 0
    for index, output in enumerate(multi_step_output[-1].outputs):
        expected_token = continuations[index][-1]
        sampled_token = output.samples[0].output_token
        if index % 2 != 0:
            assert expected_token == sampled_token
        elif expected_token != sampled_token:
            num_mismatch += 1

    # Some even-numbered sequences are still predicted correctly without
    # proper bonus-token handling, so only require that at least one of them
    # mismatches.
    assert num_mismatch > 0
@torch.inference_mode() @torch.inference_mode()
def test_draft_proposals_full_speculation_len(): def test_draft_proposals_full_speculation_len():
"""Verify Top1Proposer correctly handles case where all sequences """Verify Top1Proposer correctly handles case where all sequences
...@@ -310,7 +517,8 @@ def test_draft_proposals_full_speculation_len(): ...@@ -310,7 +517,8 @@ def test_draft_proposals_full_speculation_len():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), ) num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -348,7 +556,8 @@ def test_draft_proposals_no_speculations(): ...@@ -348,7 +556,8 @@ def test_draft_proposals_no_speculations():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), ) num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -420,7 +629,8 @@ def test_draft_proposals_mixed_k(): ...@@ -420,7 +629,8 @@ def test_draft_proposals_mixed_k():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), ) num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
......
...@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match(): ...@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), ) num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): ...@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), ) num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all(): ...@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
proposals = proposer.get_spec_proposals( proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest( execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), ) num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
......
import random import random
from collections import defaultdict
from types import SimpleNamespace from types import SimpleNamespace
from typing import Dict, List, Set
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector, from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics) SpecDecodeWorkerMetrics)
...@@ -15,23 +16,26 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker ...@@ -15,23 +16,26 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
split_num_cache_blocks_evenly) split_num_cache_blocks_evenly)
from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, create_sampler_output_list, mock_worker from .utils import create_batch, create_sampler_output_list, mock_worker
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int): def test_correctly_calls_draft_model(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the draft worker with correct """Verify SpecDecodeWorker calls the draft worker with correct
inputs. Everything else is mocked out. inputs. Everything else is mocked out.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
exception_secret = 'artificial stop' exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
...@@ -52,15 +56,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): ...@@ -52,15 +56,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int): def test_correctly_calls_target_model(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the target model with correct """Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out. inputs. Everything else is mocked out.
""" """
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False) target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
...@@ -68,8 +73,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -68,8 +73,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
worker.init_device() worker.init_device()
vocab_size = 32_000 vocab_size = 32_000
...@@ -103,7 +109,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -103,7 +109,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k)) num_lookahead_slots=k))
seen_contexts = [] seen_contexts: List[List[int]] = []
call_args_list = target_worker.execute_model.call_args_list call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1 assert len(call_args_list) == 1
...@@ -116,7 +122,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -116,7 +122,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
for seq_data in seq_group_metadata.seq_data.values(): for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids()) seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts = [] expected_seen_contexts: List[List[int]] = []
for prompt, prev_generated, draft_tokens in zip( for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()): prompts, prev_output_tokens, proposal_token_ids.tolist()):
...@@ -132,8 +138,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -132,8 +138,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_correctly_calls_rejection_sampler(k: int, batch_size: int): def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the rejection sampler with """Verify SpecDecodeWorker calls the rejection sampler with
correct inputs. Everything else is mocked out. correct inputs. Everything else is mocked out.
""" """
...@@ -143,15 +152,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -143,15 +152,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
vocab_size=vocab_size, vocab_size=vocab_size,
use_spec=False) use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
metrics_collector) metrics_collector)
worker.init_device() worker.init_device()
...@@ -198,15 +206,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -198,15 +206,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
target_worker.execute_model.return_value = [target_output[0]] target_worker.execute_model.return_value = [target_output[0]]
exception_secret = 'artificial stop' exception_secret = 'artificial stop'
rejection_sampler.side_effect = ValueError(exception_secret)
spec_decode_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=ExecuteModelRequest( worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k)) num_lookahead_slots=k))
assert len(rejection_sampler.call_args_list) == 1 assert len(spec_decode_sampler.call_args_list) == 1
_, kwargs = rejection_sampler.call_args_list[0] _, kwargs = spec_decode_sampler.call_args_list[0]
actual = SimpleNamespace(**kwargs) actual = SimpleNamespace(**kwargs)
assert torch.equal(actual.bonus_token_ids, assert torch.equal(actual.bonus_token_ids,
...@@ -220,8 +229,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -220,8 +229,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int): def test_correctly_formats_output(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker formats sampler output correctly. """Verify SpecDecodeWorker formats sampler output correctly.
Everything else is mocked out. Everything else is mocked out.
""" """
...@@ -231,15 +243,13 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -231,15 +243,13 @@ def test_correctly_formats_output(k: int, batch_size: int):
vocab_size=vocab_size, vocab_size=vocab_size,
use_spec=False) use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
set_random_seed(1) set_random_seed(1)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
metrics_collector) metrics_collector)
worker.init_device() worker.init_device()
...@@ -285,24 +295,23 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -285,24 +295,23 @@ def test_correctly_formats_output(k: int, batch_size: int):
target_worker.execute_model.return_value = [target_output[0]] target_worker.execute_model.return_value = [target_output[0]]
rejection_sampler_output = torch.randint(low=0, spec_decode_sampler_output = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, k + 1), size=(batch_size, k + 1),
dtype=torch.int64, dtype=torch.int64,
device='cuda') device='cuda')
for i in range(batch_size): for i in range(batch_size):
minimum_accepted_tokens = 1 minimum_accepted_tokens = 1
rejection_sampler_output[i][ spec_decode_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1 -random.randint(minimum_accepted_tokens, k + 1):] = -1
rejection_sampler.return_value = rejection_sampler_output spec_decode_sampler.return_value = spec_decode_sampler_output
output = worker.execute_model(execute_model_req=ExecuteModelRequest( output = worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k)) num_lookahead_slots=k))
expected_output = create_sampler_output_list( expected_output = create_sampler_output_list(
token_ids=rejection_sampler_output.transpose(0, 1), token_ids=spec_decode_sampler_output.transpose(0, 1),
probs=[None for _ in range(k + 1)], probs=[None for _ in range(k + 1)],
logprobs=[None for _ in range(k + 1)]) logprobs=[None for _ in range(k + 1)])
...@@ -310,8 +319,14 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -310,8 +319,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
next(iter(seq_group_metadata.seq_data.keys())) next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in seq_group_metadata_list for seq_group_metadata in seq_group_metadata_list
] ]
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids} actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids} seq_id: []
for seq_id in seq_ids
}
expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
for step in output: for step in output:
for seq_group in step: for seq_group in step:
...@@ -343,8 +358,11 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -343,8 +358,11 @@ def test_correctly_formats_output(k: int, batch_size: int):
@pytest.mark.parametrize('k', [1, 2]) @pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False]) @pytest.mark.parametrize('returns_metrics', [True, False])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker collects metrics. """Verify SpecDecodeWorker collects metrics.
""" """
vocab_size = 32_000 vocab_size = 32_000
...@@ -353,16 +371,17 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -353,16 +371,17 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
vocab_size=vocab_size, vocab_size=vocab_size,
use_spec=False) use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker,
metrics_collector) target_worker,
spec_decode_sampler,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
proposal_token_ids = torch.randint(low=0, proposal_token_ids = torch.randint(low=0,
...@@ -407,17 +426,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -407,17 +426,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
target_worker.execute_model.return_value = [target_output[0]] target_worker.execute_model.return_value = [target_output[0]]
rejection_sampler_output = torch.randint(low=0, spec_decode_sampler_output = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, k + 1), size=(batch_size, k + 1),
dtype=torch.int64, dtype=torch.int64,
device='cuda') device='cuda')
for i in range(batch_size): for i in range(batch_size):
minimum_accepted_tokens = 1 minimum_accepted_tokens = 1
rejection_sampler_output[i][ spec_decode_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1 -random.randint(minimum_accepted_tokens, k + 1):] = -1
spec_decode_sampler.return_value = spec_decode_sampler_output
rejection_sampler.return_value = rejection_sampler_output
mock_rejsample_metrics = MagicMock( mock_rejsample_metrics = MagicMock(
spec=SpecDecodeWorkerMetrics) if returns_metrics else None spec=SpecDecodeWorkerMetrics) if returns_metrics else None
...@@ -438,26 +456,30 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -438,26 +456,30 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
@pytest.mark.parametrize('k', [0]) @pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int): def test_k_equals_zero(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify that the SpecDecodeWorker calls the draft and target workers """Verify that the SpecDecodeWorker calls the draft and target workers
when k is zero. This happens during prefill. when k is zero. This happens during prefill.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
seq_group_metadata_list, _, _ = create_batch(batch_size, seq_group_metadata_list, _, _ = create_batch(batch_size,
k, k,
...@@ -478,27 +500,31 @@ def test_k_equals_zero(k: int, batch_size: int): ...@@ -478,27 +500,31 @@ def test_k_equals_zero(k: int, batch_size: int):
@pytest.mark.parametrize('k', [0, 5]) @pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0]) @pytest.mark.parametrize('batch_size', [0])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int): def test_empty_input_batch(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify that the SpecDecodeWorker calls the draft and target workers """Verify that the SpecDecodeWorker calls the draft and target workers
when the input batch is empty. This can happen if the engine communicates when the input batch is empty. This can happen if the engine communicates
to the workers information without scheduling a batch. to the workers information without scheduling a batch.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
seq_group_metadata_list, _, _ = create_batch(batch_size, seq_group_metadata_list, _, _ = create_batch(batch_size,
k, k,
...@@ -517,20 +543,20 @@ def test_empty_input_batch(k: int, batch_size: int): ...@@ -517,20 +543,20 @@ def test_empty_input_batch(k: int, batch_size: int):
target_worker.execute_model.assert_called_once_with(execute_model_req) target_worker.execute_model.assert_called_once_with(execute_model_req)
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_init_device(): def test_init_device(acceptance_sampler_method: str):
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization. well as other GPU initialization.
""" """
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False) target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
metrics_collector) metrics_collector)
worker.init_device() worker.init_device()
draft_worker.init_device.assert_called_once() draft_worker.init_device.assert_called_once()
...@@ -538,22 +564,23 @@ def test_init_device(): ...@@ -538,22 +564,23 @@ def test_init_device():
target_worker.init_device.assert_called_once() target_worker.init_device.assert_called_once()
metrics_collector.init_gpu_tensors.assert_called_once() metrics_collector.init_gpu_tensors.assert_called_once()
rejection_sampler.init_gpu_tensors.assert_called_once() spec_decode_sampler.init_gpu_tensors.assert_called_once()
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_initialize_cache(): def test_initialize_cache(acceptance_sampler_method):
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers. workers.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs) worker.initialize_cache(**kwargs)
...@@ -566,19 +593,20 @@ def test_initialize_cache(): ...@@ -566,19 +593,20 @@ def test_initialize_cache():
@pytest.mark.parametrize('available_cpu_blocks', [500]) @pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int, def test_determine_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int, available_cpu_blocks: int,
target_cache_block_size_bytes: int, target_cache_block_size_bytes: int,
draft_kv_size_bytes: int): draft_kv_size_bytes: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks. """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker. split the blocks between proposer and scorer worker.
""" """
draft_worker = mock_worker(cls=MultiStepWorker) draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker() target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.determine_num_available_blocks.return_value = ( target_worker.determine_num_available_blocks.return_value = (
...@@ -587,8 +615,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int, ...@@ -587,8 +615,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int,
target_cache_block_size_bytes) target_cache_block_size_bytes)
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(
metrics_collector) draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
...@@ -618,3 +647,140 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int, ...@@ -618,3 +647,140 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
assert (num_blocks * target_cache_block_size_bytes) + ( assert (num_blocks * target_cache_block_size_bytes) + (
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks * num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
target_cache_block_size_bytes) target_cache_block_size_bytes)
@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
    """
    Verify that a call to _create_output_sampler_list correctly updates
    seq_with_bonus_token_in_last_step.

    seq_with_bonus_token_in_last_step is an internal data structure in
    SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
    tokens by the target model in their last forward pass. This state is
    maintained only for models relying on the KV cache, such as those using
    the MultiStepWorker.
    """
    batch_size = 10
    k = 5
    vocab_size = 10000
    num_sequences_with_bonus_tokens = 5
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
    target_worker.device = 'cuda'

    set_random_seed(1)
    draft_worker = mock_worker(cls=MultiStepWorker)
    draft_worker.device = 'cuda'

    # The sequence_ids attached to each sequence in the batch.
    # The sequence at index i has seq_id assigned_seq_ids[i]
    assigned_seq_ids = list(range(batch_size))
    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 seq_ids=assigned_seq_ids,
                                                 prev_output_token_len=10)
    target_token_logprobs = torch.rand(batch_size, (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    accepted_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, (k + 1)),
                                       dtype=torch.int64,
                                       device='cuda')

    # Record which sequence IDs belong to which request; the worker is
    # expected to end up with exactly this mapping.
    expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for seq_group_metadata in seq_group_metadata_list:
        for seq_id in seq_group_metadata.seq_data:
            expected_request_id_seq_ids_mapping[
                seq_group_metadata.request_id].add(seq_id)

    # Generate a random sample of sequence indexes with bonus tokens
    seq_indexes_with_bonus_tokens = random.sample(
        range(batch_size), num_sequences_with_bonus_tokens)

    # Create a mask that is True for every index NOT in
    # seq_indexes_with_bonus_tokens (it starts all-True and the bonus-token
    # indexes are cleared).
    mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
    mask[seq_indexes_with_bonus_tokens] = False

    # Set the last token ID to -1 for all indices not in
    # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
    # those indices.
    accepted_token_ids[mask, -1:] = -1

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              metrics_collector=metrics_collector)

    # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
    # This set includes all sequence IDs in the batch as well as an additional
    # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
    # the range [0, batch_size + num_extra_sequence_ids).
    num_extra_sequence_ids = 10
    worker._seq_with_bonus_token_in_last_step = set(
        range(batch_size + num_extra_sequence_ids))
    worker._create_output_sampler_list(
        seq_group_metadata_list=seq_group_metadata_list,
        accepted_token_ids=accepted_token_ids,
        target_logprobs=target_token_logprobs,
        k=k)

    # Verify that _seq_with_bonus_token_in_last_step contains the following:
    #    1. Sequence IDs that were already present in
    #       _seq_with_bonus_token_in_last_step but were not part of the current
    #       batch are retained.
    #    2. Of the sequence IDs present in the current batch, only those with a
    #       bonus token are retained in _seq_with_bonus_token_in_last_step.
    #       Sequence IDs that are present in the current batch but do not have
    #       bonus tokens are removed from _seq_with_bonus_token_in_last_step.
    expected_seq_ids_with_bonus_tokens = {
        assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens
    }
    additional_sequence_ids = set(
        range(batch_size, batch_size + num_extra_sequence_ids))
    assert worker._seq_with_bonus_token_in_last_step == \
        expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
    assert worker._request_id_seq_id_mapping == \
        expected_request_id_seq_ids_mapping
@torch.inference_mode()
def test_handle_finished_requests():
    """
    Check that finished request IDs carried by an ExecuteModelRequest are
    used to prune the SpecDecodeWorker's internal bookkeeping.

    The worker is seeded with a fake request-id -> sequence-ids mapping and a
    fake set of sequence IDs that had a bonus token. Two of the requests are
    then reported as finished; their entries, and every sequence ID that
    belonged to them, must be gone afterwards.
    """
    batch_size = 32
    k = 3
    proposer = mock_worker(cls=MultiStepWorker)
    scorer = mock_worker()
    collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(proposer, scorer,
                              mock_spec_decode_sampler("rejection_sampler"),
                              collector)

    # Seed the request-id -> sequence-ids mapping with a few fake requests.
    worker._request_id_seq_id_mapping = {
        'request-1': {1, 2, 3},
        'request-2': {4, 5, 6, 7},
        'request-3': {8, 9},
        'request-4': {10, 11},
    }
    # Seed the bonus-token set with a few fake sequence ids.
    worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}

    # Make the proposer raise so execute_model stops right after the
    # finished-request handling this test is interested in.
    stop_marker = 'artificial stop'
    proposer.get_spec_proposals.side_effect = ValueError(stop_marker)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    # Report request-1 and request-3 as finished.
    request = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        finished_requests_ids=['request-1', 'request-3'])

    with pytest.raises(ValueError, match=stop_marker):
        worker.execute_model(execute_model_req=request)

    # request-1 and request-3 must no longer appear in the mapping.
    surviving_requests = {'request-2': {4, 5, 6, 7}, 'request-4': {10, 11}}
    assert worker._request_id_seq_id_mapping == surviving_requests

    # Every sequence id that belonged to request-1 or request-3 must have
    # been dropped from the bonus-token set.
    surviving_bonus_seq_ids = {4, 5, 10}
    assert worker._seq_with_bonus_token_in_last_step == surviving_bonus_seq_ids
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import torch
from vllm.sequence import SequenceGroupMetadata from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.spec_decode.util import get_all_seq_ids, split_batch_by_proposal_len from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import split_batch_by_proposal_len
def test_get_all_seq_ids(): def test_get_all_seq_ids():
...@@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): ...@@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
assert filtered_groups == [] assert filtered_groups == []
assert indices == [] assert indices == []
def mock_spec_decode_sampler(acceptance_sampler_method):
    """
    Build a MagicMock standing in for a speculative-decoding sampler.

    The mock is spec'd on RejectionSampler when acceptance_sampler_method is
    'rejection_sampler' and on TypicalAcceptanceSampler when it is
    'typical_acceptance_sampler'. In both cases its token_id_dtype attribute
    is set to torch.int64.

    Raises:
        ValueError: if acceptance_sampler_method is neither of the two
            supported names.
    """
    if acceptance_sampler_method == "rejection_sampler":
        sampler_cls = RejectionSampler
    elif acceptance_sampler_method == "typical_acceptance_sampler":
        sampler_cls = TypicalAcceptanceSampler
    else:
        raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")

    # Both sampler kinds are mocked identically apart from the spec class.
    sampler = MagicMock(spec=sampler_cls)
    sampler.token_id_dtype = torch.int64
    return sampler
from itertools import count from itertools import count
from typing import Dict, Iterable, List, Optional, Union from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import TypeVar, Union
from unittest.mock import MagicMock from unittest.mock import MagicMock
import torch import torch
...@@ -12,8 +14,11 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, ...@@ -12,8 +14,11 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput) SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
T = TypeVar("T", bound=Worker)
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return how many blocks of ``block_size`` are needed to hold ``seq_len``.

    Despite the name, the result is a block count (the ceiling of
    ``seq_len / block_size``), not ``seq_len`` rounded up to a multiple of
    ``block_size``.
    """
    # Integer ceiling division: adding block_size - 1 bumps any partial block
    # up to a whole one before the floor division.
    return (seq_len + block_size - 1) // block_size
...@@ -49,20 +54,21 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]): ...@@ -49,20 +54,21 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
return new_execute_model return new_execute_model
def zero_kv_cache(cache_engine: List[CacheEngine]):
    """Zero, in place, the GPU key/value blocks of the first cache engine.

    Only ``cache_engine[0]`` is touched; any further entries in the list are
    left as they are.
    """
    gpu_cache = cache_engine[0].gpu_cache
    # Guard against a missing or empty GPU cache before iterating it.
    assert gpu_cache
    for key_blocks, value_blocks in gpu_cache:
        key_blocks.zero_()
        value_blocks.zero_()
def create_worker(cls: type, def create_worker(cls: Callable[..., T],
model_name: str, model_name: str,
block_size: int, block_size: int,
num_gpu_blocks: int, num_gpu_blocks: int,
seed: int, seed: int,
is_driver_worker: bool = True, is_driver_worker: bool = True,
enforce_eager: bool = True): enforce_eager: bool = True,
model_runner_cls: Optional[ModelRunner] = None) -> T:
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
seed=seed, seed=seed,
...@@ -85,6 +91,7 @@ def create_worker(cls: type, ...@@ -85,6 +91,7 @@ def create_worker(cls: type,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
) )
worker.init_device() worker.init_device()
...@@ -159,8 +166,8 @@ def assert_logprobs_dict_allclose( ...@@ -159,8 +166,8 @@ def assert_logprobs_dict_allclose(
def create_sampler_output_list( def create_sampler_output_list(
token_ids: torch.Tensor, token_ids: torch.Tensor,
probs: Iterable[Optional[torch.Tensor]], probs: GenericSequence[Optional[torch.Tensor]],
logprobs: Iterable[Optional[torch.Tensor]], logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
num_steps, batch_size = token_ids.shape num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist() token_ids_by_step = token_ids.tolist()
......
import json import json
import os import os
import pathlib
import subprocess import subprocess
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import openai import openai
import pytest import pytest
import ray import torch
from tensorizer import EncryptionParams
from vllm import SamplingParams from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
# yapf: disable # yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
TensorSerializer, TensorSerializer,
is_vllm_tensorized, is_vllm_tensorized,
load_with_tensorizer, load_with_tensorizer,
open_stream, open_stream,
serialize_vllm_model) serialize_vllm_model,
tensorize_vllm_model)
from ..utils import ServerRunner from ..conftest import VllmRunner, cleanup
from ..utils import RemoteOpenAIServer
# yapf conflicts with isort for this docstring # yapf conflicts with isort for this docstring
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
...@@ -42,6 +48,20 @@ def is_curl_installed(): ...@@ -42,6 +48,20 @@ def is_curl_installed():
except (subprocess.CalledProcessError, FileNotFoundError): except (subprocess.CalledProcessError, FileNotFoundError):
return False return False
def get_torch_model(vllm_runner: VllmRunner):
    """Return the model held by the runner's driver worker.

    Walks the attribute chain
    ``model -> llm_engine -> model_executor -> driver_worker -> model_runner``
    on ``vllm_runner`` and returns that model runner's ``model`` attribute.
    """
    engine = vllm_runner.model.llm_engine
    driver_worker = engine.model_executor.driver_worker
    return driver_worker.model_runner.model
def write_keyfile(keyfile_path: str):
    """Generate random encryption parameters and write their key to a file.

    The parent directory of ``keyfile_path`` is created if it does not exist,
    and the key is written in binary mode, replacing any existing file.
    """
    params = EncryptionParams.random()
    destination = pathlib.Path(keyfile_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(params.key)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def tensorizer_config(): def tensorizer_config():
...@@ -88,12 +108,17 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( ...@@ -88,12 +108,17 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
with vllm_runner(model_ref) as vllm_model: with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key") key_path = tmp_path / (model_ref + ".key")
write_keyfile(key_path)
outputs = vllm_model.generate(prompts, sampling_params) outputs = vllm_model.generate(prompts, sampling_params)
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path) config_for_serializing = TensorizerConfig(
serialize_vllm_model(vllm_model.model.llm_engine, tensorizer_uri=model_path,
config_for_serializing, encryption_keyfile=key_path
encryption_key_path=key_path) )
serialize_vllm_model(get_torch_model(vllm_model),
config_for_serializing)
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path) encryption_keyfile=key_path)
...@@ -145,7 +170,7 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): ...@@ -145,7 +170,7 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model: with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(vllm_model.model.llm_engine, serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path)) TensorizerConfig(tensorizer_uri=model_path))
with vllm_runner( with vllm_runner(
...@@ -180,7 +205,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ...@@ -180,7 +205,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model: with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(vllm_model.model.llm_engine, serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path)) TensorizerConfig(tensorizer_uri=model_path))
model_loader_extra_config = { model_loader_extra_config = {
...@@ -191,29 +216,24 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ...@@ -191,29 +216,24 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
openai_args = [ openai_args = [
"--model", model_ref, "--dtype", "float16", "--load-format", "--model", model_ref, "--dtype", "float16", "--load-format",
"tensorizer", "--model-loader-extra-config", "tensorizer", "--model-loader-extra-config",
json.dumps(model_loader_extra_config), "--port", "8000" json.dumps(model_loader_extra_config),
] ]
server = ServerRunner.remote(openai_args) with RemoteOpenAIServer(openai_args) as server:
print("Server ready.")
assert ray.get(server.ready.remote())
print("Server ready.")
client = openai.OpenAI( client = server.get_client()
base_url="http://localhost:8000/v1", completion = client.completions.create(model=model_ref,
api_key="token-abc123", prompt="Hello, my name is",
) max_tokens=5,
completion = client.completions.create(model=model_ref, temperature=0.0)
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None assert completion.id is not None
assert len(completion.choices) == 1 assert len(completion.choices) == 1
assert len(completion.choices[0].text) >= 5 assert len(completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length" assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage( assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11) completion_tokens=5, prompt_tokens=6, total_tokens=11)
def test_raise_value_error_on_invalid_load_format(vllm_runner): def test_raise_value_error_on_invalid_load_format(vllm_runner):
...@@ -224,7 +244,9 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner): ...@@ -224,7 +244,9 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
def test_tensorizer_with_tp(vllm_runner): @pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
with pytest.raises(ValueError): with pytest.raises(ValueError):
model_ref = "EleutherAI/pythia-1.4b" model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
...@@ -238,8 +260,60 @@ def test_tensorizer_with_tp(vllm_runner): ...@@ -238,8 +260,60 @@ def test_tensorizer_with_tp(vllm_runner):
s3_endpoint="object.ord1.coreweave.com", s3_endpoint="object.ord1.coreweave.com",
), ),
tensor_parallel_size=2, tensor_parallel_size=2,
disable_custom_all_reduce=True,
) )
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
                                                                    tmp_path):
    """Compare outputs of a plain model with its encrypted, sharded copy.

    The model is first run un-sharded, then serialized with tensor
    parallelism 2 and an encryption key file, and finally loaded back through
    the tensorizer load format; both runs must generate identical outputs.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    # Engine options shared by every engine created in this test.
    engine_kwargs = dict(disable_custom_all_reduce=True, enforce_eager=True)

    # record outputs from un-sharded un-tensorized model
    base_model = vllm_runner(model_ref, **engine_kwargs)
    outputs = base_model.generate(prompts, sampling_params)

    # Shut the reference engine down before the sharded engine is created.
    base_model.model.llm_engine.model_executor.shutdown()
    del base_model
    cleanup()

    # load model with two shards and serialize with encryption
    # The "%02d" placeholder is filled with 0 and 1 in the file checks below.
    model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
    key_path = tmp_path / (model_ref + ".key")
    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=key_path,
    )

    tensorize_vllm_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            **engine_kwargs,
        ),
        tensorizer_config=tensorizer_config,
    )
    for shard in (0, 1):
        assert os.path.isfile(
            model_path % shard), "Serialization subprocess failed"
    cleanup()

    loaded_vllm_model = vllm_runner(
        model_ref,
        tensor_parallel_size=2,
        load_format="tensorizer",
        model_loader_extra_config=tensorizer_config,
        **engine_kwargs,
    )
    deserialized_outputs = loaded_vllm_model.generate(prompts,
                                                      sampling_params)

    assert outputs == deserialized_outputs
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
model_ref = "facebook/opt-125m" model_ref = "facebook/opt-125m"
...@@ -248,7 +322,7 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): ...@@ -248,7 +322,7 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
with vllm_runner(model_ref) as vllm_model: with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params) outputs = vllm_model.generate(prompts, sampling_params)
serialize_vllm_model(vllm_model.model.llm_engine, config) serialize_vllm_model(get_torch_model(vllm_model), config)
assert is_vllm_tensorized(config) assert is_vllm_tensorized(config)
......
...@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, ...@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
max_input_length=None, max_input_length=None,
) )
hashes = [] hashes: List[List[List[int]]] = []
for prefix in prefixes: for prefix in prefixes:
for lora_int_id in concurrent_lora_int_ids: for lora_int_id in concurrent_lora_int_ids:
......
import vllm
def test_embedded_commit_defined():
    """``vllm.__commit__`` must hold a real hash, not the placeholder."""
    commit = vllm.__commit__
    assert commit != "COMMIT_HASH_PLACEHOLDER"
    # 7 characters is the length of a short commit hash
    assert len(commit) >= 7
...@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration(): ...@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration():
assert not logger.propagate assert not logger.propagate
handler = logger.handlers[0] handler = logger.handlers[0]
assert isinstance(handler, logging.StreamHandler)
assert handler.stream == sys.stdout assert handler.stream == sys.stdout
assert handler.level == logging.INFO assert handler.level == logging.INFO
......
...@@ -83,7 +83,7 @@ def test_logits_processors(seed: int, device: str): ...@@ -83,7 +83,7 @@ def test_logits_processors(seed: int, device: str):
device=device, device=device,
pin_memory=is_pin_memory_available()) pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor( logits_processor_output = logits_processor(
embedding=None, lm_head=None,
hidden_states=input_tensor, hidden_states=input_tensor,
sampling_metadata=sampling_metadata) sampling_metadata=sampling_metadata)
......
...@@ -39,7 +39,7 @@ def test_filter_subtensors(): ...@@ -39,7 +39,7 @@ def test_filter_subtensors():
filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
for key, tensor in filtered_state_dict.items(): for key, tensor in filtered_state_dict.items():
# NOTE: don't use `euqal` here, as the tensor might contain NaNs # NOTE: don't use `equal` here, as the tensor might contain NaNs
assert tensor is state_dict[key] assert tensor is state_dict[key]
......
...@@ -7,7 +7,8 @@ from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol, ...@@ -7,7 +7,8 @@ from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
import pytest import pytest
from vllm.utils import deprecate_kwargs, get_open_port, merge_async_iterators from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
get_open_port, merge_async_iterators)
from .utils import error_on_warning from .utils import error_on_warning
...@@ -130,3 +131,61 @@ def test_get_open_port(): ...@@ -130,3 +131,61 @@ def test_get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
s3.bind(("localhost", get_open_port())) s3.bind(("localhost", get_open_port()))
os.environ.pop("VLLM_PORT") os.environ.pop("VLLM_PORT")
# Tests for FlexibleArgumentParser
@pytest.fixture
def parser():
    """Provide a FlexibleArgumentParser with four dash-style options.

    The options cover a ``choices`` argument, a plain string, an ``int`` and
    a ``store_true`` flag.
    """
    arg_parser = FlexibleArgumentParser()
    arg_parser.add_argument('--image-input-type',
                            choices=['pixel_values', 'image_features'])
    arg_parser.add_argument('--model-name')
    arg_parser.add_argument('--batch-size', type=int)
    arg_parser.add_argument('--enable-feature', action='store_true')
    return arg_parser
def test_underscore_to_dash(parser):
    """An underscore-spelled flag is accepted for a dash-defined option."""
    parsed = parser.parse_args(['--image_input_type', 'pixel_values'])
    assert parsed.image_input_type == 'pixel_values'
def test_mixed_usage(parser):
    """Underscore and dash spellings can be mixed in one command line."""
    argv = [
        '--image_input_type', 'image_features',
        '--model-name', 'facebook/opt-125m',
    ]
    parsed = parser.parse_args(argv)
    assert parsed.image_input_type == 'image_features'
    assert parsed.model_name == 'facebook/opt-125m'
def test_with_equals_sign(parser):
    """``--flag=value`` works for both underscore and dash spellings."""
    argv = [
        '--image_input_type=pixel_values',
        '--model-name=facebook/opt-125m',
    ]
    parsed = parser.parse_args(argv)
    assert parsed.image_input_type == 'pixel_values'
    assert parsed.model_name == 'facebook/opt-125m'
def test_with_int_value(parser):
    """An ``int`` option is converted for either spelling of its flag."""
    for flag in ('--batch_size', '--batch-size'):
        parsed = parser.parse_args([flag, '32'])
        assert parsed.batch_size == 32
def test_with_bool_flag(parser):
    """A ``store_true`` flag is set by either spelling of its name."""
    for flag in ('--enable_feature', '--enable-feature'):
        parsed = parser.parse_args([flag])
        assert parsed.enable_feature is True
def test_invalid_choice(parser):
    """A value outside ``choices`` makes the parser exit."""
    bad_argv = ['--image_input_type', 'invalid_choice']
    with pytest.raises(SystemExit):
        parser.parse_args(bad_argv)
def test_missing_required_argument(parser):
    """Omitting an argument declared ``required=True`` makes the parser exit."""
    parser.add_argument('--required-arg', required=True)
    no_args = []
    with pytest.raises(SystemExit):
        parser.parse_args(no_args)
from typing import Dict, List from typing import Any, Dict, List, Optional
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -139,6 +139,15 @@ def create_dummy_logprobs( ...@@ -139,6 +139,15 @@ def create_dummy_logprobs(
} for token_id in complete_sequence_token_ids] } for token_id in complete_sequence_token_ids]
def create_dummy_prompt_logprobs(
        complete_sequence_token_ids: List[int]
) -> List[Optional[Dict[int, Any]]]:
    """Build dummy prompt logprobs for ``complete_sequence_token_ids``.

    The result is ``None`` followed by the entries of
    ``create_dummy_logprobs`` for every token after the first one.
    """
    sample_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
    # logprob for the first prompt token is None.
    prompt_logprobs: List[Optional[Dict[int, Any]]] = [
        None, *sample_logprobs[1:]
    ]
    return prompt_logprobs
@pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False]) @pytest.mark.parametrize("skip_special_tokens", [True, False])
...@@ -153,8 +162,8 @@ def test_decode_sequence_logprobs(complete_sequence: str, ...@@ -153,8 +162,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
# Run sequentially. # Run sequentially.
seq = create_sequence() seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token = [] sequential_logprobs_text_chosen_token: List[str] = []
sequential_logprobs_text_other_token = [] sequential_logprobs_text_other_token: List[str] = []
for new_token, logprobs in zip(complete_sequence_token_ids, for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs): dummy_logprobs):
seq.append_token_id(new_token, logprobs) seq.append_token_id(new_token, logprobs)
...@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str, ...@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True]) def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
def test_decode_prompt_logprobs(complete_sequence: str, detokenizer: Detokenizer):
complete_sequence_token_ids: List[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes prompt logprobs correctly.""" """Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, sampling_params = SamplingParams(skip_special_tokens=True,
prompt_logprobs=1) prompt_logprobs=1)
# Run sequentially. # Run sequentially.
...@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str, ...@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
seqs=[seq], seqs=[seq],
sampling_params=sampling_params, sampling_params=sampling_params,
arrival_time=0.0) arrival_time=0.0)
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs) detokenizer.decode_prompt_logprobs_inplace(seq_group,
decoded_prompt_logprobs = dummy_logprobs dummy_logprobs,
position_offset=0)
# First logprob is None.
decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
1:] # type: ignore
if skip_special_tokens: # decoded_prompt_logprobs doesn't contain the first token.
# Text for logprobs for the chosen token should be the same as the token_ids = complete_sequence_token_ids
# prompt text. Note that this will only be true if we skip tokenzier = detokenizer.get_tokenizer_for_seq(seq)
# special tokens. text_full = tokenzier.decode(token_ids, skip_special_tokens=True)
assert complete_sequence == "".join([ text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True)
logprobs[token_id].decoded_token for token_id, logprobs in zip( text = text_full[len(text_first):]
complete_sequence_token_ids, decoded_prompt_logprobs)
]) # Text for logprobs for the chosen token should be the same as the
assert complete_sequence != "".join([ # prompt text. Note that the first logprob is None.
logprobs[token_id + 1].decoded_token for token_id, logprobs in zip( assert text == "".join([
complete_sequence_token_ids, decoded_prompt_logprobs) logprobs[token_id].decoded_token
]) for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
])
assert text != "".join([
logprobs[token_id + 1].decoded_token
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
])
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
    vllm_runner,
    model,
    chunked_prefill_token_size: int,
    example_prompts,
):
    """Check that detokenized prompt logprobs rebuild the original prompt.

    A ``chunked_prefill_token_size`` of -1 runs without chunked prefill; any
    other value enables it and is used as the batched-token budget and as an
    upper bound on the number of sequences.
    """
    enable_chunked_prefill = chunked_prefill_token_size != -1
    max_num_seqs = (min(chunked_prefill_token_size, 256)
                    if enable_chunked_prefill else 256)
    max_num_batched_tokens = (chunked_prefill_token_size
                              if enable_chunked_prefill else None)

    with vllm_runner(model,
                     dtype="half",
                     max_logprobs=5,
                     gpu_memory_utilization=0.5,
                     enable_chunked_prefill=enable_chunked_prefill,
                     max_num_batched_tokens=max_num_batched_tokens,
                     max_num_seqs=max_num_seqs) as vllm_model:

        vllm_sampling_params = SamplingParams(max_tokens=10,
                                              logprobs=5,
                                              prompt_logprobs=5,
                                              temperature=0.0)
        vllm_results = vllm_model.model.generate(
            example_prompts, sampling_params=vllm_sampling_params)

        for idx, result in enumerate(vllm_results):
            assert result.prompt_logprobs is not None
            # The first prompt token has no logprob entry.
            assert result.prompt_logprobs[0] is None

            # Compare the detokenized prompt ids to the original prompt.
            # Each prompt_logprobs entry is a dict of token_id: logprob; the
            # entry keyed by the actual prompt token carries the decoded
            # text for that position.
            generated_string = "".join(
                prompt_logprobs[prompt_token].decoded_token
                for prompt_token, prompt_logprobs in zip(
                    result.prompt_token_ids[1:], result.prompt_logprobs[1:]))

            assert generated_string == example_prompts[idx], (
                "Detokenized prompt logprobs do not match original prompt")
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
:meth:`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
def test_get_llama3_eos_token():
    """Llama 3: tokenizer EOS id is one of two ids in the generation config."""
    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

    assert get_tokenizer(model_name).eos_token_id == 128009

    generation_config = try_get_generation_config(model_name,
                                                  trust_remote_code=False)
    assert generation_config is not None
    # The generation config lists two EOS ids, not just the tokenizer's one.
    assert generation_config.eos_token_id == [128001, 128009]
def test_get_blip2_eos_token():
    """BLIP-2: tokenizer and generation config report different EOS ids."""
    model_name = "Salesforce/blip2-opt-2.7b"

    assert get_tokenizer(model_name).eos_token_id == 2

    generation_config = try_get_generation_config(model_name,
                                                  trust_remote_code=False)
    assert generation_config is not None
    # Differs from the tokenizer's EOS id asserted above.
    assert generation_config.eos_token_id == 50118
import pytest
from transformers.image_processing_utils import BaseImageProcessor
from vllm.transformers_utils.image_processor import get_image_processor
# Processor names that test_image_processor_revision is parametrized over;
# each one is passed to get_image_processor.
IMAGE_PROCESSOR_NAMES = [
    "llava-hf/llava-1.5-7b-hf",
    "llava-hf/llava-v1.6-34b-hf",
]
@pytest.mark.parametrize("processor_name", IMAGE_PROCESSOR_NAMES)
def test_image_processor_revision(processor_name: str):
    """An existing revision loads a processor; an unknown one raises OSError."""
    # Assume that the "main" branch always exists
    main_processor = get_image_processor(processor_name, revision="main")
    assert isinstance(main_processor, BaseImageProcessor)

    # Assume that a branch named "never" does not exist
    missing_revision = "never"
    with pytest.raises(OSError, match='not a valid git identifier'):
        get_image_processor(processor_name, revision=missing_revision)
import asyncio import asyncio
import os import os
import sys
from typing import List, Optional
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
...@@ -100,3 +102,100 @@ async def test_tokenizer_group_ray_pool_env_var_propagation( ...@@ -100,3 +102,100 @@ async def test_tokenizer_group_ray_pool_env_var_propagation(
max_num_seqs=1, max_num_seqs=1,
max_input_length=None) max_input_length=None)
tokenizer_pool.ping() tokenizer_pool.ping()
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
    """Test that Ray tokenizer pool group can recover from failures and
    if that's not possible, mark itself as unhealthy."""

    class FailingTokenizerGroup(TokenizerGroup):
        """TokenizerGroup that exits the process on chosen ``encode`` calls."""

        def __init__(self,
                     *args,
                     fail_at: Optional[List[int]] = None,
                     **kwargs):
            super().__init__(*args, **kwargs)
            # Number of encode() calls made on this instance so far.
            self.i = 0
            self.fail_at = fail_at or []

        def encode(self, *args, **kwargs):
            self.i += 1
            if self.i in self.fail_at:
                sys.exit(1)
            return super().encode(*args, **kwargs)

    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = FailingTokenizerGroup

    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)

    def make_pool(fail_at: List[int],
                  max_input_length: Optional[int] = None):
        # Same pool settings for every scenario; only these two values vary.
        return FailingRayTokenizerGroupPool.from_config(
            tokenizer_pool_config,
            tokenizer_id="gpt2",
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=max_input_length,
            fail_at=fail_at)

    async def encode(pool, prompt: str = "prompt"):
        return await pool.encode_async(request_id="1",
                                       prompt=prompt,
                                       lora_request=None)

    # Fail at first iteration
    fail_at = [1]
    tokenizer_group_pool = make_pool(fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Modify fail at to not fail at all (will be re-read when actor is
    # re-initialized).
    fail_at[0] = 1000

    # We should recover successfully.
    await encode(tokenizer_group_pool)
    await encode(tokenizer_group_pool)

    # Check that we have a new actor
    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors

    # Fail at first iteration
    tokenizer_group_pool = make_pool([1])

    # We should fail after re-initialization.
    with pytest.raises(RuntimeError):
        await encode(tokenizer_group_pool)

    # check_health should raise the same thing
    with pytest.raises(RuntimeError):
        tokenizer_group_pool.check_health()

    # Ensure that non-ActorDiedErrors are still propagated correctly and do
    # not cause a re-initialization.
    tokenizer_group_pool = make_pool([], max_input_length=2)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Prompt too long error
    with pytest.raises(ValueError):
        await encode(tokenizer_group_pool, prompt="prompt" * 100)

    await encode(tokenizer_group_pool)

    # Actors should stay the same.
    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
import os
import threading
from concurrent import futures
from typing import Callable, Dict, Iterable, Literal
import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
ExportTraceServiceResponse)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
TraceServiceServicer, add_TraceServiceServicer_to_server)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_TRACES_INSECURE)
from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes
# Address the fake gRPC trace service listens on; the same value is given to
# the LLM as its OTLP traces endpoint.
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"

# Literal type naming AnyValue fields (bool, string, int, double and array).
FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
                    'array_value']
def decode_value(value: AnyValue):
    """Convert an OTLP ``AnyValue`` into the equivalent plain Python value.

    Scalar members (bool, string, int, double, bytes) are returned as-is,
    ``array_value`` becomes a list and ``kvlist_value`` becomes a dict, with
    nested values decoded recursively.

    Raises:
        ValueError: if none of the known fields is set on ``value``.
    """
    # Members whose decoded form is simply the attribute itself.
    scalar_fields = ("bool_value", "string_value", "int_value",
                     "double_value", "bytes_value")
    for field in scalar_fields:
        if value.HasField(field):
            return getattr(value, field)

    if value.HasField("array_value"):
        return [decode_value(item) for item in value.array_value.values]

    if value.HasField("kvlist_value"):
        return {
            kv.key: decode_value(kv.value)
            for kv in value.kvlist_value.values
        }

    raise ValueError(f"Couldn't decode value: {value}")
def decode_attributes(attributes: Iterable[KeyValue]):
    """Return a dict mapping each attribute key to its decoded value."""
    decoded = {}
    for attribute in attributes:
        decoded[attribute.key] = decode_value(attribute.value)
    return decoded
class FakeTraceService(TraceServiceServicer):
    """Trace servicer that stores the last export request it receives."""

    def __init__(self):
        # Set by Export(); lets a caller wait until a request has arrived.
        self.evt = threading.Event()
        # Most recent export request, or None before the first Export().
        self.request = None

    def Export(self, request, context):
        """Record ``request``, signal ``evt`` and return an empty response."""
        self.request = request
        self.evt.set()
        return ExportTraceServiceResponse()
@pytest.fixture
def trace_service():
    """Fixture to set up a fake gRPC trace service"""
    service = FakeTraceService()
    executor = futures.ThreadPoolExecutor(max_workers=1)
    server = grpc.server(executor)
    add_TraceServiceServicer_to_server(service, server)
    server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
    server.start()

    yield service

    # Teardown: stop the server with no grace period argument.
    server.stop(None)
def test_traces(trace_service):
    """Generate one completion and check the span exported for it.

    The LLM is pointed at the fake trace service; the attributes of the first
    exported span must match the request's sampling parameters, token counts
    and latency metrics.

    Raises:
        TimeoutError: if no trace reaches the fake service within 5 seconds.
    """
    # Remember the previous setting so this test does not leak the variable
    # into tests that run later in the same process.
    previous_insecure = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_INSECURE)
    os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true"
    try:
        sampling_params = SamplingParams(temperature=0.01,
                                         top_p=0.1,
                                         max_tokens=256)
        model = "facebook/opt-125m"
        llm = LLM(
            model=model,
            otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
        )
        prompts = ["This is a short prompt"]
        outputs = llm.generate(prompts, sampling_params=sampling_params)

        timeout = 5
        if not trace_service.evt.wait(timeout):
            raise TimeoutError(
                f"The fake trace service didn't receive a trace within "
                f"the {timeout} seconds timeout")
    finally:
        if previous_insecure is None:
            os.environ.pop(OTEL_EXPORTER_OTLP_TRACES_INSECURE, None)
        else:
            os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = previous_insecure

    # Only the first span of the first scope/resource is inspected.
    first_span = trace_service.request.resource_spans[0].scope_spans[0].spans[0]
    attributes = decode_attributes(first_span.attributes)

    assert attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) == model
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_ID) == outputs[0].request_id
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_TEMPERATURE) == sampling_params.temperature
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
    assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
    assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
        outputs[0].prompt_token_ids)

    completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
    assert attributes.get(
        SpanAttributes.LLM_USAGE_COMPLETION_TOKENS) == completion_tokens

    metrics = outputs[0].metrics
    assert attributes.get(
        SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue
    ttft = metrics.first_token_time - metrics.arrival_time
    assert attributes.get(
        SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
    e2e_time = metrics.finished_time - metrics.arrival_time
    assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
...@@ -4,57 +4,120 @@ import sys ...@@ -4,57 +4,120 @@ import sys
import time import time
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List
import openai
import ray import ray
import requests import requests
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.utils import get_open_port from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
if is_hip():
    from amdsmi import (amdsmi_get_gpu_vram_usage,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down)

    @contextmanager
    def _nvml():
        """Initialise amdsmi for the duration of the block, then shut down."""
        # Initialise before entering the try block: if initialisation fails
        # there is nothing to shut down, and a shutdown attempt could raise
        # and hide the original error.
        amdsmi_init()
        try:
            yield
        finally:
            amdsmi_shut_down()
else:
    from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                        nvmlInit, nvmlShutdown)

    @contextmanager
    def _nvml():
        """Initialise NVML for the duration of the block, then shut down."""
        # Initialise before entering the try block: if initialisation fails
        # there is nothing to shut down, and a shutdown attempt could raise
        # and hide the original error.
        nvmlInit()
        try:
            yield
        finally:
            nvmlShutdown()
VLLM_PATH = Path(__file__).parent.parent
"""Path to root of the vLLM repository."""
class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 600 seconds
# Path to root of repository so that utilities can be imported by ray workers def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:
VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)) if auto_port:
if "-p" in cli_args or "--port" in cli_args:
raise ValueError("You have manually specified the port"
"when `auto_port=True`.")
cli_args = cli_args + ["--port", str(get_open_port())]
@ray.remote(num_gpus=1) parser = FlexibleArgumentParser(
class ServerRunner: description="vLLM's remote OpenAI server.")
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds parser = make_arg_parser(parser)
args = parser.parse_args(cli_args)
self.host = str(args.host or 'localhost')
self.port = int(args.port)
def __init__(self, args):
env = os.environ.copy() env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1" # the current process might initialize cuda,
# to be safe, we should use spawn method
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
self.proc = subprocess.Popen( self.proc = subprocess.Popen(
[sys.executable, "-m", "vllm.entrypoints.openai.api_server"] + [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
args, cli_args,
env=env, env=env,
stdout=sys.stdout, stdout=sys.stdout,
stderr=sys.stderr, stderr=sys.stderr)
) self._wait_for_server(url=self.url_for("health"),
self._wait_for_server() timeout=self.MAX_SERVER_START_WAIT_S)
def ready(self): def __enter__(self):
return True return self
def _wait_for_server(self): def __exit__(self, exc_type, exc_value, traceback):
self.proc.terminate()
def _wait_for_server(self, *, url: str, timeout: float):
# run health check # run health check
start = time.time() start = time.time()
while True: while True:
try: try:
if requests.get( if requests.get(url).status_code == 200:
"http://localhost:8000/health").status_code == 200:
break break
except Exception as err: except Exception as err:
if self.proc.poll() is not None: result = self.proc.poll()
if result is not None and result != 0:
raise RuntimeError("Server exited unexpectedly.") from err raise RuntimeError("Server exited unexpectedly.") from err
time.sleep(0.5) time.sleep(0.5)
if time.time() - start > self.MAX_SERVER_START_WAIT_S: if time.time() - start > timeout:
raise RuntimeError( raise RuntimeError(
"Server failed to start in time.") from err "Server failed to start in time.") from err
def __del__(self): @property
if hasattr(self, "proc"): def url_root(self) -> str:
self.proc.terminate() return f"http://{self.host}:{self.port}"
def url_for(self, *parts: str) -> str:
return self.url_root + "/" + "/".join(parts)
def get_client(self):
return openai.OpenAI(
base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY,
)
def get_async_client(self):
return openai.AsyncOpenAI(
base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY,
)
def init_test_distributed_environment( def init_test_distributed_environment(
...@@ -73,13 +136,15 @@ def init_test_distributed_environment( ...@@ -73,13 +136,15 @@ def init_test_distributed_environment(
ensure_model_parallel_initialized(tp_size, pp_size) ensure_model_parallel_initialized(tp_size, pp_size)
def multi_process_tensor_parallel( def multi_process_parallel(
tp_size: int, tp_size: int,
pp_size: int, pp_size: int,
test_target, test_target: Any,
) -> None: ) -> None:
# Using ray helps debugging the error when it failed # Using ray helps debugging the error when it failed
# as compared to multiprocessing. # as compared to multiprocessing.
# NOTE: We need to set working_dir for distributed tests,
# otherwise we may get import errors on ray workers
ray.init(runtime_env={"working_dir": VLLM_PATH}) ray.init(runtime_env={"working_dir": VLLM_PATH})
distributed_init_port = get_open_port() distributed_init_port = get_open_port()
...@@ -102,3 +167,43 @@ def error_on_warning(): ...@@ -102,3 +167,43 @@ def error_on_warning():
warnings.simplefilter("error") warnings.simplefilter("error")
yield yield
@_nvml()
def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    """Block until used memory on every listed device is at or below a limit.

    Usage is polled every 5 seconds and printed on each poll.

    Args:
        devices: device indices to check.
        threshold_bytes: maximum allowed used memory per device, in bytes.
        timeout_s: how long to keep polling before giving up.

    Raises:
        ValueError: if some device is still above the threshold once
            ``timeout_s`` seconds have elapsed.
    """
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    start_time = time.time()
    threshold_gb = threshold_bytes / 2**30
    while True:
        output_raw: Dict[int, float] = {}
        for device in devices:
            if is_hip():
                dev_handle = amdsmi_get_processor_handles()[device]
                mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
                # NOTE(review): presumably "vram_used" is reported in MB, so
                # dividing by 2**10 gives GB -- confirm against amdsmi docs.
                gb_used = mem_info["vram_used"] / 2**10
            else:
                dev_handle = nvmlDeviceGetHandleByIndex(device)
                mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
                gb_used = mem_info.used / 2**30
            output_raw[device] = gb_used

        usage_report = ''.join(f'{dev}={used:.02f}; '
                               for dev, used in output_raw.items())
        print('gpu memory used (GB): ' + usage_report)

        dur_s = time.time() - start_time
        if all(used <= threshold_gb for used in output_raw.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            return

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment