Unverified Commit 8437bae6 authored by Cade Daniel's avatar Cade Daniel Committed by GitHub
Browse files

[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling (#3103)

parent f48c6791
...@@ -28,7 +28,7 @@ steps: ...@@ -28,7 +28,7 @@ steps:
num_gpus: 2 # only support 1 or 2 for now. num_gpus: 2 # only support 1 or 2 for now.
- label: Engine Test - label: Engine Test
command: pytest -v -s engine command: pytest -v -s engine test_sequence.py
- label: Entrypoints Test - label: Entrypoints Test
command: pytest -v -s entrypoints command: pytest -v -s entrypoints
...@@ -52,6 +52,9 @@ steps: ...@@ -52,6 +52,9 @@ steps:
- label: Worker Test - label: Worker Test
command: pytest -v -s worker command: pytest -v -s worker
- label: Speculative decoding tests
command: pytest -v -s spec_decode
- label: LoRA Test - label: LoRA Test
command: pytest -v -s lora --forked command: pytest -v -s lora --forked
......
import torch
import pytest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from .utils import mock_worker, create_seq_group_metadata_from_prompts
@pytest.mark.parametrize('num_target_seq_ids', [100])
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
    """Check that every generated target seq id exceeds all input seq ids."""
    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)

    input_seq_id_cases = (
        [1, 3, 5, 7],
        [*range(100), 0],
        [100],
    )

    for input_seq_ids in input_seq_id_cases:
        largest_input_id = max(input_seq_ids)
        target_ids = scorer._create_target_seq_id_iterator(input_seq_ids)  # pylint: disable=protected-access

        # Draw a fixed number of ids; each one must be strictly larger than
        # the largest id that was handed in.
        for _ in range(num_target_seq_ids):
            assert next(target_ids) > largest_input_id
@pytest.mark.parametrize('k', [1, 2, 6])
def test_get_token_ids_to_score(k: int):
    """Check which token prefixes the scorer selects for scoring.

    For proposal tokens ``0..k-1`` the expected result is the empty list
    followed by every non-empty prefix of the proposal, shortest first.
    """
    proposal_token_ids = torch.tensor(list(range(k)),
                                      dtype=torch.int64,
                                      device='cuda')

    num_proposal_tokens = proposal_token_ids.shape[0]
    expected_output = [[]] + [
        proposal_token_ids[:end].tolist()
        for end in range(1, num_proposal_tokens + 1)
    ]

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    raw_output = scorer._get_token_ids_to_score(proposal_token_ids)  # pylint: disable=protected-access

    # Entries may be tensors or plain lists; normalise to lists to compare.
    actual_output = [
        item.tolist() if isinstance(item, torch.Tensor) else item
        for item in raw_output
    ]

    assert actual_output == expected_output
@pytest.mark.parametrize('k', [1, 2, 6])
def test_create_single_target_seq_group_metadata(k: int):
    """Verify correct creation of a batch-expanded seq group metadata.

    The expanded metadata must keep the request id and prompt of the input
    sequence, carry the previous output tokens followed by the ``k`` tokens
    to score, and reuse the input sequence's block table under the new
    target sequence id.
    """
    prompt_tokens = [1, 2, 3]
    prev_output_tokens = [4, 5, 6]
    token_ids = list(range(k))

    final_seq_len = (len(prompt_tokens) + len(prev_output_tokens) +
                     len(token_ids))

    block_size = 32
    # Pass the continuation by keyword. The helper no longer takes a
    # ``num_tokens_processed`` argument, so a sixth positional value would be
    # bound to ``seq_ids`` instead.
    input_seq_group_metadata = create_seq_group_metadata_from_prompts(
        [prompt_tokens],
        2048 // block_size,
        block_size,
        [final_seq_len],
        continuations=[prev_output_tokens],
    )[0]

    input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
    target_seq_id = 100

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    output = scorer._create_single_target_seq_group_metadata(  # pylint: disable=protected-access
        input_seq_group_metadata,
        input_seq_id,
        target_seq_id,
        token_ids,
    )

    assert output.request_id == input_seq_group_metadata.request_id

    assert len(output.seq_data) == 1
    target_seq_data = output.seq_data[target_seq_id]
    assert target_seq_data.get_prompt_token_ids() == prompt_tokens
    assert target_seq_data.get_output_token_ids(
    ) == prev_output_tokens + token_ids

    assert len(output.block_tables) == 1
    assert output.block_tables[
        target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]
import torch
import math
import pytest
from unittest.mock import MagicMock
from vllm.spec_decode.metrics import AsyncMetricsCollector
def test_initial_call_returns_none():
    """The very first metrics request must yield ``None``."""
    counter_kwargs = dict(dtype=torch.long, device='cuda')

    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0, **counter_kwargs)
    rej_sampler.num_emitted_tokens = torch.tensor(0, **counter_kwargs)
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=0)

    assert collector.maybe_collect_rejsample_metrics(k=5) is None
def test_second_call_returns_metrics():
    """With a timer that jumps past the interval, the second call must
    return metrics rather than ``None``.
    """
    counter_kwargs = dict(dtype=torch.long, device='cuda')

    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0, **counter_kwargs)
    rej_sampler.num_emitted_tokens = torch.tensor(0, **counter_kwargs)
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    # Fake clock: successive timer() calls return these values in order.
    fake_times = [0.0, collect_interval_s + 0.1, collect_interval_s + 0.2]
    timer = MagicMock()
    timer.side_effect = fake_times

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)

    # First call is discarded; only the second result is checked.
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    second_result = collector.maybe_collect_rejsample_metrics(k=5)
    assert second_result is not None
@pytest.mark.parametrize("rank", [1, 2, 3, 4])
def test_nonzero_rank_noop(rank):
    """A collector initialised with a nonzero rank must not return metrics."""
    counter_kwargs = dict(dtype=torch.long, device='cuda')

    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0, **counter_kwargs)
    rej_sampler.num_emitted_tokens = torch.tensor(0, **counter_kwargs)
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=rank)

    # Two calls, mirroring the rank-0 tests; the second must still be None.
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    second_result = collector.maybe_collect_rejsample_metrics(k=5)
    assert second_result is None
def test_noop_until_time():
    """Metrics must stay ``None`` until the fake clock passes the collect
    interval, and be returned once it has.
    """
    counter_kwargs = dict(dtype=torch.long, device='cuda')

    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0, **counter_kwargs)
    rej_sampler.num_emitted_tokens = torch.tensor(0, **counter_kwargs)
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    just_before = collect_interval_s - 0.1
    just_after = collect_interval_s + 0.1

    # Fake clock: successive timer() calls return these values in order.
    timer = MagicMock()
    timer.side_effect = [0.0, just_before, just_before, just_after, just_after]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)

    # Clock is still inside the interval: nothing is reported.
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    early_result = collector.maybe_collect_rejsample_metrics(k=5)
    assert early_result is None

    # Clock has moved past the interval: metrics are reported.
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    late_result = collector.maybe_collect_rejsample_metrics(k=5)
    assert late_result is not None
@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
    """Check the values reported in the collected metrics.

    With data, the counters and derived ratios must match the sampler's
    counters; without data, the counters are zero and both ratios are NaN.
    """
    if has_data:
        num_accepted_tokens, num_emitted_tokens, num_draft_tokens = (103, 104,
                                                                     105)
    else:
        num_accepted_tokens, num_emitted_tokens, num_draft_tokens = (0, 0, 0)

    k = 5
    num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
        num_draft_tokens, k)

    counter_kwargs = dict(dtype=torch.long, device='cuda')
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
                                                   **counter_kwargs)
    rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
                                                  **counter_kwargs)
    rej_sampler.num_draft_tokens = num_draft_tokens

    collect_interval_s = 5.0
    # Fake clock: successive timer() calls return these values in order.
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    _ = collector.maybe_collect_rejsample_metrics(k)
    metrics = collector.maybe_collect_rejsample_metrics(k)

    assert metrics.num_spec_tokens == k
    assert metrics.accepted_tokens == num_accepted_tokens
    assert metrics.draft_tokens == num_draft_tokens
    assert metrics.emitted_tokens == num_emitted_tokens

    if not has_data:
        assert math.isnan(metrics.draft_acceptance_rate)
        assert math.isnan(metrics.system_efficiency)
    else:
        assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens
        assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens
...@@ -3,14 +3,15 @@ import random ...@@ -3,14 +3,15 @@ import random
import pytest import pytest
from unittest.mock import MagicMock from unittest.mock import MagicMock
from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker, DraftModelTop1Proposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput
from .utils import (create_execute_model_data, create_worker, from .utils import (create_execute_model_data, create_worker,
create_seq_group_metadata_from_prompts, zero_kv_cache, create_seq_group_metadata_from_prompts, zero_kv_cache,
patch_execute_model_with_seeds, patch_execute_model_with_seeds,
assert_logprobs_dict_allclose) assert_logprobs_dict_allclose, create_batch)
@pytest.mark.parametrize('num_steps', list(range(1, 17))) @pytest.mark.parametrize('num_steps', list(range(1, 17)))
...@@ -259,3 +260,160 @@ def test_same_output_for_multi_step(): ...@@ -259,3 +260,160 @@ def test_same_output_for_multi_step():
multi_step_output_logprobs, single_step_output_logprobs): multi_step_output_logprobs, single_step_output_logprobs):
assert_logprobs_dict_allclose(multi_step_logprobs, assert_logprobs_dict_allclose(multi_step_logprobs,
single_step_logprobs) single_step_logprobs)
@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
    """DraftModelTop1Proposer: every sequence in the batch can speculate.

    The mocked draft worker returns ``k`` sampler outputs covering the whole
    batch, so all proposal lengths are expected to equal ``k``.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'

    draft_worker = MagicMock()
    proposer = DraftModelTop1Proposer(
        draft_worker=draft_worker,
        device=device,
        max_model_len=2048,
        vocab_size=vocab_size,
    )

    # One random SamplerOutput per speculative step.
    step_outputs = []
    for _ in range(k):
        step_probs = torch.rand(batch_size,
                                vocab_size,
                                device=device,
                                dtype=torch.float32)
        step_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, ),
                                       device=device,
                                       dtype=torch.long)
        step_outputs.append(
            SamplerOutput(
                outputs=[],
                sampled_token_probs=step_probs,
                sampled_token_ids=step_token_ids,
            ))
    draft_worker.execute_model_multi_step.return_value = step_outputs

    execute_model_data, _, _ = create_batch(batch_size, k)

    proposals = proposer.get_proposals(
        **execute_model_data.to_dict(),
        max_proposal_len=k,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    expected_leading_shape = torch.Size([batch_size, k])
    assert proposals.proposal_token_ids.shape == expected_leading_shape
    assert proposals.proposal_probs.shape[:-1] == expected_leading_shape

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [k] * batch_size
@torch.inference_mode()
def test_draft_proposals_no_speculations():
    """DraftModelTop1Proposer: no sequence in the batch can speculate.

    ``max_model_len`` is set to ``prompt_len + k - 1``; the expected result
    is an empty proposal batch and a proposal length of zero per sequence.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'
    prompt_len = 10

    draft_worker = MagicMock()
    proposer = DraftModelTop1Proposer(
        draft_worker=draft_worker,
        device=device,
        max_model_len=prompt_len + k - 1,
        vocab_size=vocab_size,
    )

    execute_model_data, _, _ = create_batch(batch_size,
                                            k,
                                            prompt_len=prompt_len)

    proposals = proposer.get_proposals(
        **execute_model_data.to_dict(),
        max_proposal_len=k,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    # No sequence speculated, so the proposal tensors have zero rows.
    empty_leading_shape = torch.Size([0, k])
    assert proposals.proposal_token_ids.shape == empty_leading_shape
    assert proposals.proposal_probs.shape[:-1] == empty_leading_shape

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [0] * batch_size
@torch.inference_mode()
def test_draft_proposals_mixed_k():
    """DraftModelTop1Proposer: some sequences can speculate, others cannot.

    Short-prompt sequences sit at the start and the very end of the batch and
    are expected to get proposal length ``k``; the long-prompt sequences in
    between are expected to get proposal length 0.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'

    small_prompt_len = 5
    long_prompt_len = 10
    prev_output_token_len = 20

    expected_num_proposal_seqs = 6
    expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs

    num_leading_small = expected_num_proposal_seqs - 1
    prompt_len = ([small_prompt_len] * num_leading_small +
                  [long_prompt_len] * expected_num_no_proposal_seqs +
                  [small_prompt_len])

    draft_worker = MagicMock()
    proposer = DraftModelTop1Proposer(
        draft_worker=draft_worker,
        device=device,
        max_model_len=long_prompt_len + prev_output_token_len + k - 1,
        vocab_size=vocab_size,
    )

    # The mocked draft worker only returns rows for the speculating sequences.
    step_outputs = []
    for _ in range(k):
        step_probs = torch.rand(expected_num_proposal_seqs,
                                vocab_size,
                                device=device,
                                dtype=torch.float32)
        step_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(expected_num_proposal_seqs, ),
                                       device=device,
                                       dtype=torch.long)
        step_outputs.append(
            SamplerOutput(
                outputs=[],
                sampled_token_probs=step_probs,
                sampled_token_ids=step_token_ids,
            ))
    draft_worker.execute_model_multi_step.return_value = step_outputs

    execute_model_data, _, _ = create_batch(
        batch_size,
        k,
        prompt_len=prompt_len,
        prev_output_token_len=prev_output_token_len,
    )

    proposals = proposer.get_proposals(
        **execute_model_data.to_dict(),
        max_proposal_len=k,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    expected_leading_shape = torch.Size([batch_size, k])
    assert proposals.proposal_token_ids.shape == expected_leading_shape
    assert proposals.proposal_probs.shape[:-1] == expected_leading_shape

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    expected_lens = ([k] * num_leading_small +
                     [0] * expected_num_no_proposal_seqs + [k])
    assert proposals.proposal_lens.tolist() == expected_lens
This diff is collapsed.
from vllm.spec_decode.util import get_all_seq_ids
from vllm.sequence import SequenceGroupMetadata
from vllm.spec_decode.util import split_batch_by_proposal_len
import pytest
from unittest.mock import MagicMock
def test_get_all_seq_ids():
    """get_all_seq_ids must return every seq id, in input order."""
    expected_seq_ids = [*range(10), *range(100, 110)]

    # One single-sequence group per id; everything but the ids is mocked.
    seq_group_metadata_list = []
    for seq_id in expected_seq_ids:
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=str(seq_id),
                is_prompt=True,
                seq_data={seq_id: MagicMock()},
                sampling_params=MagicMock(),
                block_tables={seq_id: MagicMock()},
                lora_request=None,
            ))

    assert get_all_seq_ids(seq_group_metadata_list) == expected_seq_ids
@pytest.fixture
def fake_sequence_group_metadata():
    """Three single-sequence groups with seq ids 0, 1 and 2 (contents mocked)."""
    groups = []
    for seq_id in range(3):
        groups.append(
            SequenceGroupMetadata(
                request_id=str(seq_id),
                is_prompt=True,
                seq_data={seq_id: MagicMock()},
                sampling_params=MagicMock(),
                block_tables={seq_id: MagicMock()},
                lora_request=None,
            ))
    return groups
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
    """Selecting zero-length proposals keeps exactly the groups whose
    proposal length is 0, together with their batch positions.
    """
    groups = fake_sequence_group_metadata
    proposal_lens = [0, 1, 0]
    expected_indices = [0, 2]

    filtered_groups, indices = split_batch_by_proposal_len(
        groups, proposal_lens, select_proposal_len_zero=True)

    assert filtered_groups == [groups[pos] for pos in expected_indices]
    assert indices == expected_indices
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
    """Selecting non-zero proposals keeps exactly the groups whose proposal
    length is not 0, together with their batch positions.
    """
    groups = fake_sequence_group_metadata
    proposal_lens = [0, 1, 2]
    expected_indices = [1, 2]

    filtered_groups, indices = split_batch_by_proposal_len(
        groups, proposal_lens, select_proposal_len_zero=False)

    assert filtered_groups == [groups[pos] for pos in expected_indices]
    assert indices == expected_indices
def test_empty_inputs():
    """An empty batch splits into an empty group list and empty indices."""
    no_groups, no_proposal_lens = [], []

    filtered_groups, indices = split_batch_by_proposal_len(
        no_groups, no_proposal_lens, select_proposal_len_zero=True)

    assert filtered_groups == []
    assert indices == []
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
    """Asking for non-zero proposals when all lengths are 0 selects nothing."""
    all_zero_lens = [0, 0, 0]

    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        all_zero_lens,
        select_proposal_len_zero=False)

    assert filtered_groups == []
    assert indices == []
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
    """Asking for zero-length proposals when none are 0 selects nothing."""
    all_non_zero_lens = [1, 1, 1]

    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        all_non_zero_lens,
        select_proposal_len_zero=True)

    assert filtered_groups == []
    assert indices == []
import torch import torch
from typing import List, Optional, Dict from typing import List, Optional, Dict, Iterable, Union
from unittest.mock import MagicMock
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData from vllm.sequence import (Logprob, SequenceGroupMetadata, SequenceData,
SamplerOutput, SequenceGroupOutput, SequenceOutput)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from itertools import count
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
...@@ -24,6 +27,11 @@ class ExecuteModelData: ...@@ -24,6 +27,11 @@ class ExecuteModelData:
return dict( return dict(
(field.name, getattr(self, field.name)) for field in fields(self)) (field.name, getattr(self, field.name)) for field in fields(self))
@classmethod
def from_dict(cls, d):
    """Build an instance from ``d`` using only the keys that name a field.

    Keys of ``d`` that are not dataclass fields of ``cls`` are ignored; a
    field that is missing from ``d`` raises ``KeyError``.
    """
    field_values = {field.name: d[field.name] for field in fields(cls)}
    return cls(**field_values)
def round_up_to_next_block(seq_len: int, block_size: int) -> int: def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size return (seq_len + block_size - 1) // block_size
...@@ -50,6 +58,21 @@ def create_execute_model_data( ...@@ -50,6 +58,21 @@ def create_execute_model_data(
) )
def mock_worker(cls=None,
                vocab_size: int = 30_000,
                max_model_len: int = 2048,
                rank: int = 0) -> MagicMock:
    """Return a ``MagicMock`` worker spec'd on ``cls`` (``Worker`` by default).

    The mock gets ``vocab_size``, ``max_model_len`` and ``rank`` set from the
    arguments, and ``device`` fixed to ``'cuda:0'``.
    """
    spec_cls = Worker if cls is None else cls
    worker = MagicMock(spec=spec_cls)

    preset_attributes = (
        ('vocab_size', vocab_size),
        ('max_model_len', max_model_len),
        ('rank', rank),
        ('device', 'cuda:0'),
    )
    for attr_name, attr_value in preset_attributes:
        setattr(worker, attr_name, attr_value)

    return worker
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]): def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
seed_iter = iter(rand_seeds) seed_iter = iter(rand_seeds)
original_execute_model = worker.execute_model original_execute_model = worker.execute_model
...@@ -117,25 +140,12 @@ def create_seq_group_metadata_from_prompts( ...@@ -117,25 +140,12 @@ def create_seq_group_metadata_from_prompts(
block_size: int, block_size: int,
final_seq_lens: List[int], final_seq_lens: List[int],
continuations: Optional[List[List[int]]] = None, continuations: Optional[List[List[int]]] = None,
num_tokens_processed: Optional[List[int]] = None,
seq_ids: Optional[List[int]] = None, seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]: ) -> List[SequenceGroupMetadata]:
if continuations is None: if continuations is None:
continuations = [[] for _ in prompts] continuations = [[] for _ in prompts]
if num_tokens_processed is None:
# Default to 1 token missing from kv cache for generation sequences.
num_tokens_processed = []
for continuation, prompt in zip(continuations, prompts):
# If prefill, then default to zero tokens processed.
if not continuation:
num_tokens_processed.append(0)
else:
# If generation, then default to all but one tokens processed.
num_tokens_processed.append(
len(continuation) + len(prompt) - 1)
if seq_ids is None: if seq_ids is None:
seq_ids = list(i for i, _ in enumerate(prompts)) seq_ids = list(i for i, _ in enumerate(prompts))
...@@ -155,13 +165,15 @@ def create_seq_group_metadata_from_prompts( ...@@ -155,13 +165,15 @@ def create_seq_group_metadata_from_prompts(
is_prompt=len(cont_token_ids) == 0, is_prompt=len(cont_token_ids) == 0,
seq_data={ seq_data={
i: i:
SequenceData(prompt_token_ids=prompt_token_ids[:] + SequenceData(
cont_token_ids[:]) prompt_token_ids=prompt_token_ids[:],
output_token_ids=cont_token_ids[:],
),
}, },
sampling_params=SamplingParams(temperature=0.0, ), sampling_params=SamplingParams(temperature=0.0, ),
block_tables={i: block_allocations[i][:]}, block_tables={i: block_allocations[i][:]},
) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in ) for i, (prompt_token_ids,
enumerate(zip(prompts, continuations, num_tokens_processed)) cont_token_ids) in enumerate(zip(prompts, continuations))
] ]
...@@ -178,3 +190,68 @@ def assert_logprobs_dict_allclose( ...@@ -178,3 +190,68 @@ def assert_logprobs_dict_allclose(
expected = torch.tensor( expected = torch.tensor(
single_step_expected_logprobs[token_id].logprob) single_step_expected_logprobs[token_id].logprob)
assert torch.allclose(actual, expected) assert torch.allclose(actual, expected)
def create_sampler_output_list(
        token_ids: torch.Tensor,
        probs: Iterable[Optional[torch.Tensor]],
        seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
    """Build one ``SamplerOutput`` per step from a 2-D token id tensor.

    ``token_ids`` is unpacked as ``(num_steps, batch_size)``. For every step,
    each token becomes a single-sample ``SequenceGroupOutput`` whose parent
    seq id is taken from ``seq_ids`` (``0..batch_size-1`` when omitted) and
    whose logprobs map the token to 0. ``probs`` is indexed by step.
    """
    num_steps, batch_size = token_ids.shape
    token_ids_by_step = token_ids.tolist()

    if seq_ids is None:
        seq_ids = list(range(batch_size))

    sampler_outputs = []
    for step in range(num_steps):
        group_outputs = []
        for seq_index, token_id in enumerate(token_ids_by_step[step]):
            sample = SequenceOutput(
                output_token=token_id,
                parent_seq_id=seq_ids[seq_index],
                logprobs={token_id: 0},
            )
            group_outputs.append(
                SequenceGroupOutput(
                    samples=[sample],
                    prompt_logprobs=None,
                ))

        sampler_outputs.append(
            SamplerOutput(outputs=group_outputs,
                          sampled_token_probs=probs[step],
                          sampled_token_ids=token_ids[step]))

    return sampler_outputs
def create_batch(batch_size,
                 k,
                 prompt_len: Union[int, List[int]] = 10,
                 prev_output_token_len: int = 10,
                 seq_ids: Optional[List[int]] = None,
                 num_gpu_blocks: Optional[int] = None,
                 block_size: Optional[int] = None):
    """Create execute-model data for a batch of sequences in generation.

    Token ids are drawn from one running counter: first all prompts, then
    all previous-output continuations. ``prompt_len`` may be a single length
    applied to every sequence or a list of per-sequence lengths.

    Returns:
        A tuple ``(execute_model_data, prompts, prev_output_tokens)``.
    """
    block_size = 8 if block_size is None else block_size
    if num_gpu_blocks is None:
        num_gpu_blocks = 2048 // block_size

    token_source = count()

    prompt_lens = ([prompt_len] * batch_size
                   if isinstance(prompt_len, int) else prompt_len)

    prompts = [[next(token_source) for _ in range(length)]
               for length in prompt_lens]

    prev_output_tokens = [[
        next(token_source) for _ in range(prev_output_token_len)
    ] for _ in range(batch_size)]

    # Final length covers prompt, previous outputs, and k + 1 further tokens.
    final_seq_lens = [
        len(prompt_tokens) + len(output_tokens) + k + 1
        for prompt_tokens, output_tokens in zip(prompts, prev_output_tokens)
    ]

    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts, num_gpu_blocks, block_size, final_seq_lens,
        prev_output_tokens, seq_ids)
    execute_model_data = create_execute_model_data(seq_group_metadata_list, )

    return execute_model_data, prompts, prev_output_tokens
import pytest
from vllm.sequence import SequenceGroupOutput, SamplerOutput, SequenceOutput
@pytest.fixture
def sample_outputs():
    """Five single-sample group outputs whose output tokens are 0..4."""
    group_outputs = []
    for token in range(5):
        sample = SequenceOutput(parent_seq_id=0,
                                output_token=token,
                                logprobs={})
        group_outputs.append(
            SequenceGroupOutput(samples=[sample], prompt_logprobs=None))
    return group_outputs
@pytest.fixture
def sampler_output(sample_outputs):
    """A ``SamplerOutput`` wrapping the ``sample_outputs`` fixture list."""
    wrapped = SamplerOutput(outputs=sample_outputs)
    return wrapped
def test_sampler_output_initialization(sampler_output, sample_outputs):
    """A freshly built SamplerOutput has the given length and no optional
    tensors or metrics set.
    """
    assert len(sampler_output) == len(sample_outputs)

    optional_values = (
        sampler_output.sampled_token_probs,
        sampler_output.sampled_token_ids,
        sampler_output.spec_decode_worker_metrics,
    )
    for value in optional_values:
        assert value is None
def test_sampler_output_getitem(sampler_output, sample_outputs):
    """Indexing a SamplerOutput returns the matching group output."""
    position = 2
    assert sampler_output[position] == sample_outputs[position]
def test_sampler_output_setitem(sampler_output):
    """Item assignment on a SamplerOutput replaces the stored group output."""
    replacement_sample = SequenceOutput(parent_seq_id=0,
                                        output_token=99,
                                        logprobs={})
    new_output = SequenceGroupOutput(samples=[replacement_sample],
                                     prompt_logprobs=None)

    position = 2
    sampler_output[position] = new_output
    assert sampler_output[position] == new_output
def test_sampler_output_len(sampler_output, sample_outputs):
    """len() of a SamplerOutput equals the number of wrapped group outputs."""
    expected_length = len(sample_outputs)
    assert len(sampler_output) == expected_length
def test_sampler_output_eq(sample_outputs):
    """SamplerOutputs compare equal when their outputs are equal, and
    unequal when one has fewer outputs.
    """
    original = SamplerOutput(outputs=sample_outputs)
    same_contents = SamplerOutput(outputs=list(sample_outputs))
    one_fewer = SamplerOutput(outputs=sample_outputs[:-1])

    assert original == same_contents
    assert original != one_fewer
...@@ -21,8 +21,6 @@ class RejectionSampler(nn.Module): ...@@ -21,8 +21,6 @@ class RejectionSampler(nn.Module):
nontrivial latency. nontrivial latency.
""" """
super().__init__() super().__init__()
self.probs_dtype = torch.float32
self.token_id_dtype = torch.int64
self._strict_mode = strict_mode self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are # NOTE: A "bonus token" is accepted iff all proposal tokens are
...@@ -44,6 +42,14 @@ class RejectionSampler(nn.Module): ...@@ -44,6 +42,14 @@ class RejectionSampler(nn.Module):
dtype=torch.long, dtype=torch.long,
device=device) device=device)
@property
def probs_dtype(self):
    """Fixed dtype for probabilities: always ``torch.float32``."""
    dtype = torch.float32
    return dtype
@property
def token_id_dtype(self):
    """Fixed dtype for token ids: always ``torch.int64``."""
    dtype = torch.int64
    return dtype
def forward( def forward(
self, self,
target_probs: torch.Tensor, target_probs: torch.Tensor,
......
...@@ -587,4 +587,4 @@ def _build_sampler_output( ...@@ -587,4 +587,4 @@ def _build_sampler_output(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append( sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
return sampler_output return SamplerOutput(outputs=sampler_output)
...@@ -2,12 +2,16 @@ ...@@ -2,12 +2,16 @@
import copy import copy
import enum import enum
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union, TYPE_CHECKING
from vllm.block import LogicalTokenBlock from vllm.block import LogicalTokenBlock
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
if TYPE_CHECKING:
import torch
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@dataclass @dataclass
class Logprob: class Logprob:
...@@ -81,6 +85,8 @@ class SequenceData: ...@@ -81,6 +85,8 @@ class SequenceData:
Args: Args:
prompt_token_ids: The token IDs of the prompt. prompt_token_ids: The token IDs of the prompt.
output_token_ids: The token IDs of the output. Set to an empty list if
None.
Attributes: Attributes:
prompt_token_ids: The token IDs of the prompt. prompt_token_ids: The token IDs of the prompt.
...@@ -91,9 +97,13 @@ class SequenceData: ...@@ -91,9 +97,13 @@ class SequenceData:
def __init__( def __init__(
self, self,
prompt_token_ids: List[int], prompt_token_ids: List[int],
output_token_ids: Optional[List[int]] = None,
) -> None: ) -> None:
if output_token_ids is None:
output_token_ids = []
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
self.output_token_ids: List[int] = [] self.output_token_ids = output_token_ids
self.cumulative_logprob = 0.0 self.cumulative_logprob = 0.0
def append_token_id(self, token_id: int, logprob: float) -> None: def append_token_id(self, token_id: int, logprob: float) -> None:
...@@ -117,6 +127,12 @@ class SequenceData: ...@@ -117,6 +127,12 @@ class SequenceData:
return self.prompt_token_ids[-1] return self.prompt_token_ids[-1]
return self.output_token_ids[-1] return self.output_token_ids[-1]
def get_prompt_token_ids(self) -> List[int]:
    """Return the prompt token ids.

    The stored list itself is returned (no copy is made).
    """
    return self.prompt_token_ids
def get_output_token_ids(self) -> List[int]:
    """Return the output token ids.

    The stored list itself is returned (no copy is made).
    """
    return self.output_token_ids
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceData(" return (f"SequenceData("
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
...@@ -506,6 +522,35 @@ class SequenceGroupOutput: ...@@ -506,6 +522,35 @@ class SequenceGroupOutput:
and self.prompt_logprobs == other.prompt_logprobs) and self.prompt_logprobs == other.prompt_logprobs)
# For each sequence group, we generate a list of SequenceOutput object, @dataclass
# each of which contains one possible candidate for the next token. class SamplerOutput:
SamplerOutput = List[SequenceGroupOutput] """For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This datastructure implements methods so it can be used like a list, but
also has optional fields for device tensors.
"""
outputs: List[SequenceGroupOutput]
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional["torch.Tensor"] = None
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional["torch.Tensor"] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
def __getitem__(self, idx: int):
    """Return the entry of ``self.outputs`` at position ``idx``."""
    group_outputs = self.outputs
    return group_outputs[idx]
def __setitem__(self, idx: int, value):
    """Replace the entry of ``self.outputs`` at position ``idx``."""
    group_outputs = self.outputs
    group_outputs[idx] = value
def __len__(self):
    """Return the number of entries in ``self.outputs``."""
    group_outputs = self.outputs
    return len(group_outputs)
def __eq__(self, other: object):
    """Compare by ``outputs`` only.

    Objects that are not instances of this object's class compare unequal;
    the optional tensor and metrics fields are not part of the comparison.
    """
    if not isinstance(other, self.__class__):
        return False
    return self.outputs == other.outputs
from typing import Iterator, List, Tuple, Optional, Dict
from itertools import chain, count
import torch
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData)
from vllm.worker.worker import Worker
from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids, split_batch_by_proposal_len
from vllm.spec_decode.interfaces import SpeculativeScorer, SpeculativeProposals, SpeculativeScores
SeqId = int
TargetSeqId = int
TokenId = int
class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Implements a speculative scorer that uses batch expansion to get
probabilities of speculative tokens according to the scoring model.
Batch expansion converts a list of sequences and multiple query positions
to a new batch of sequences, each with a single query position. This allows
for MQA-like scoring in speculative decoding without requiring an MQA
kernel.
It is strictly less efficient than MQA scoring.
It only supports scoring the top1 proposal tokens of the proposer, instead
of topk/tree.
"""
def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
    """Store the scoring worker, the device string and the vocab size."""
    self._scorer_worker, self._device, self._vocab_size = (scorer_worker,
                                                           device, vocab_size)
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Optional[Dict[int, int]],
    blocks_to_swap_out: Optional[Dict[int, int]],
    blocks_to_copy: Optional[Dict[int, List[int]]],
    k: int,
    proposals: SpeculativeProposals,
) -> SpeculativeScores:
    """Score the proposed tokens via the scorer model.

    The input batch is expanded into target sequences (each input sequence
    becomes a set of k+1 target sequences with unique continuations and
    sequence ids distinct from all input ids), the scorer worker runs on the
    expanded batch, and the result is contracted back to the original batch.

    If a speculative sequence length would exceed the max model length, then
    no speculation is produced for that sequence.

    Args:
        seq_group_metadata_list: The input sequence group metadata.
        blocks_to_swap_in: This is passed to the worker during scoring.
        blocks_to_swap_out: This is passed to the worker during scoring.
        blocks_to_copy: This is passed to the worker during scoring.
        k: The fixed proposal length.
        proposals: The speculative proposals to score.

    Returns:
        SpeculativeScores: The probabilities and token ids produced by
            contracting the scorer output back to the original batch.
    """
    # TODO(cade) perform this on GPU to remove blocking call.
    proposal_lens_list = proposals.proposal_lens.tolist()
    proposal_token_ids_list = proposals.proposal_token_ids.tolist()

    (
        spec_indices,
        non_spec_indices,
        target_seq_group_metadata_list,
        num_scoring_tokens,
    ) = self._expand_batch(
        seq_group_metadata_list=seq_group_metadata_list,
        proposal_token_ids_list=proposal_token_ids_list,
        proposal_lens_list=proposal_lens_list,
    )

    # Run the scoring model on the expanded batch.
    target_sampler_output = self._scorer_worker.execute_model(
        seq_group_metadata_list=target_seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
        return_python_output=False)

    scored_token_ids, scored_probs = self._contract_batch(
        original_bs=len(seq_group_metadata_list),
        target_sampler_output=target_sampler_output,
        proposals=proposals,
        num_scoring_tokens=num_scoring_tokens,
        non_spec_indices=non_spec_indices,
        spec_indices=spec_indices,
        k=k,
    )

    return SpeculativeScores(
        probs=scored_probs,
        token_ids=scored_token_ids,
    )
def _expand_batch(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_token_ids_list: List[List[TokenId]],
    proposal_lens_list: List[int],
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
    """Given the input sequences and potentially multiple corresponding
    proposal tokens, create a new batch where each sequence has a single
    query token.

    Args:
        seq_group_metadata_list: The input sequence group metadata.
        proposal_token_ids_list: Proposal token ids, one row per input
            sequence (row i belongs to seq_group_metadata_list[i]).
        proposal_lens_list: The proposal length of each input sequence.

    Returns:
        A tuple of:
          * the indices (into the input batch) of speculative sequences,
          * the indices (into the input batch) of non-speculative sequences,
          * the expanded batch: scoring sequences first, followed by the
            non-speculative sequences unchanged,
          * the number of scoring sequences at the front of that batch.
    """
    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len. This adds some complexity (splitting the batch into spec
    # and non spec sequences) and should be removed in the future. It can be
    # done by supporting per-sequence proposal lens.
    spec_seqs, spec_indices = split_batch_by_proposal_len(
        seq_group_metadata_list,
        proposal_lens_list,
        select_proposal_len_zero=False)
    non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
        seq_group_metadata_list,
        proposal_lens_list,
        select_proposal_len_zero=True)

    # The proposal rows are laid out by position in the *input* batch, while
    # _create_scoring_model_input indexes rows by position within the list it
    # receives. Select the rows of the speculative sequences so that both
    # lists line up even when a non-speculative sequence precedes a
    # speculative one.
    spec_proposal_token_ids = [
        proposal_token_ids_list[i] for i in spec_indices
    ]

    target_seq_group_metadata_list = self._create_scoring_model_input(
        spec_seqs, spec_proposal_token_ids)

    # Must be measured before the non-speculative sequences are appended; it
    # marks where the scoring output ends in the model's output.
    num_scoring_tokens = len(target_seq_group_metadata_list)
    target_seq_group_metadata_list.extend(non_spec_seqs)

    return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
            num_scoring_tokens)
def _contract_batch(self, original_bs: int,
                    target_sampler_output: List[SamplerOutput],
                    proposals: SpeculativeProposals,
                    num_scoring_tokens: int, non_spec_indices: List[int],
                    spec_indices: List[int],
                    k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Contract the expanded batch back into its original size.

    This maps the scores of speculative tokens back to their original
    sequences.

    Args:
        original_bs: Number of sequences in the input (unexpanded) batch.
        target_sampler_output: Output of the scoring forward pass.
        proposals: The proposals that were scored; the proposal length is
            read from the second dimension of ``proposal_token_ids``.
        num_scoring_tokens: Number of leading scoring samples in the output.
        non_spec_indices: Input-batch indices of non-speculative sequences.
        spec_indices: Input-batch indices of speculative sequences.
        k: The proposal length. Superseded by the shape of
            ``proposals.proposal_token_ids``; kept for interface stability.

    Returns:
        ``(all_tokens, all_probs)`` with shapes ``[original_bs, k + 1]`` and
        ``[original_bs, k + 1, vocab_size]``. Unfilled token slots are -1 and
        unfilled probability slots are 0.
    """
    (target_token_ids, target_probs, non_spec_target_token_ids,
     non_spec_target_probs) = self._split_scoring_output(
         target_sampler_output, num_scoring_tokens)

    _, k = proposals.proposal_token_ids.shape

    # Only the speculative sequences were expanded into k + 1 scoring
    # sequences each, so the scores fold back into len(spec_indices) rows.
    # Using the row count of the proposal tensor instead would fail whenever
    # the batch also contains non-speculative sequences.
    num_spec_seqs = len(spec_indices)
    target_token_ids = target_token_ids.reshape(num_spec_seqs, k + 1)
    target_probs = target_probs.reshape(num_spec_seqs, k + 1,
                                        self._vocab_size)

    all_tokens = torch.full(size=(original_bs, k + 1),
                            fill_value=-1,
                            device=self._device,
                            dtype=torch.long)
    all_probs = torch.zeros(original_bs,
                            k + 1,
                            self._vocab_size,
                            device=self._device,
                            dtype=torch.float32)

    if non_spec_indices:
        # Non-speculative sequences contribute exactly one sampled token.
        # Flatten explicitly so the assignment does not depend on whether the
        # helper returned a trailing singleton dimension.
        all_tokens[non_spec_indices, 0] = (
            non_spec_target_token_ids.reshape(-1))
        all_probs[non_spec_indices, :1, :] = non_spec_target_probs.reshape(
            -1, 1, self._vocab_size)

    if spec_indices:
        all_tokens[spec_indices] = target_token_ids
        all_probs[spec_indices] = target_probs

    return all_tokens, all_probs
def _create_scoring_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
) -> List[SequenceGroupMetadata]:
    """Build the flat list of target sequences used for scoring.

    Every input sequence is expanded, in order, into the target sequences
    produced by ``_create_target_seq_group_metadata``; the results are
    concatenated. All target sequences draw their ids from one shared
    iterator. An empty input yields an empty list.
    """
    if not seq_group_metadata_list:
        return []

    # One iterator for the whole batch, seeded from the input sequence ids.
    seq_id_source = self._create_target_seq_id_iterator(
        get_all_seq_ids(seq_group_metadata_list))

    expanded: List[SequenceGroupMetadata] = []
    for batch_index, seq_group_metadata in enumerate(
            seq_group_metadata_list):
        expanded.extend(
            self._create_target_seq_group_metadata(
                seq_group_metadata,
                proposal_token_ids,
                batch_index,
                seq_id_source,
            ))
    return expanded
def _create_target_seq_group_metadata(
    self,
    input_seq_group_metadata: SequenceGroupMetadata,
    proposal_token_ids: List[TokenId],  # shape: [batch_size, k]
    batch_index: int,
    target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
    """Expand one input sequence group into its target sequence groups.

    One target SequenceGroupMetadata is created per token-id prefix returned
    by ``_get_token_ids_to_score`` for row ``batch_index`` of
    ``proposal_token_ids``. With K draft tokens that is K+1 targets: one per
    draft token plus one for the bonus token, which may be sampled from the
    model when every draft token is accepted.

    Prompts and sequence groups holding more than one sequence (beam search)
    are rejected by assertion.
    """
    assert not input_seq_group_metadata.is_prompt, (
        "Speculating on prompts not yet supported")
    assert len(input_seq_group_metadata.seq_data) == 1, (
        "Beam search not supported in speculative decoding")

    # The group holds exactly one sequence; take its id.
    (input_seq_id, ) = tuple(input_seq_group_metadata.seq_data.keys())

    prefixes = self._get_token_ids_to_score(proposal_token_ids[batch_index])

    # Each prefix consumes the next fresh target sequence id, in order.
    return [
        self._create_single_target_seq_group_metadata(
            input_seq_group_metadata,
            input_seq_id,
            next(target_seq_ids_iter),
            prefix,
        ) for prefix in prefixes
    ]
def _create_single_target_seq_group_metadata(
    self,
    seq_group_metadata: SequenceGroupMetadata,
    seq_id: SeqId,
    target_seq_id: TargetSeqId,
    token_ids: List[TokenId],
) -> SequenceGroupMetadata:
    """Build one target SequenceGroupMetadata.

    The target keeps the request id, prompt flag, sampling params and block
    table of the input sequence, but is keyed by ``target_seq_id`` and has
    ``token_ids`` appended to the input sequence's output tokens. The LoRA
    request of the target is always None.

    Args:
        seq_group_metadata: The metadata for the input sequence.
        seq_id: The input sequence ID.
        target_seq_id: The corresponding target sequence ID.
        token_ids: The list of token ids that are to be appended to the
            input sequence.
    """
    source = seq_group_metadata.seq_data[seq_id]
    prompt_token_ids = source.get_prompt_token_ids()
    extended_output_token_ids = list(source.get_output_token_ids())
    extended_output_token_ids.extend(token_ids)

    target_seq_data = SequenceData(
        prompt_token_ids=prompt_token_ids,
        output_token_ids=extended_output_token_ids,
    )
    # The target reuses the block table object of the input sequence.
    target_block_table = seq_group_metadata.block_tables[seq_id]

    return SequenceGroupMetadata(
        request_id=seq_group_metadata.request_id,
        is_prompt=seq_group_metadata.is_prompt,
        seq_data={target_seq_id: target_seq_data},
        sampling_params=seq_group_metadata.sampling_params,
        block_tables={target_seq_id: target_block_table},
        lora_request=None,
    )
def _split_scoring_output(
    self, sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split the target model output into speculative and non-speculative
    output.

    The first ``num_scoring_tokens`` samples are treated as speculative
    scoring output and the remainder as non-speculative output. Each half is
    converted with ``sampler_output_to_torch``.

    NOTE: ``sampler_output`` is modified in place. Its ``sampled_token_probs``
    and ``sampled_token_ids`` are overwritten twice, and on return they hold
    the non-speculative slices.

    Args:
        sampler_output: The scoring model output to split.
        num_scoring_tokens: Number of leading samples that belong to
            speculative scoring.

    Returns:
        ``(target_token_ids, target_probs, non_spec_target_token_ids,
        non_spec_target_probs)``.
    """
    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len. This adds some complexity (splitting the batch into spec
    # and non spec sequences) and should be removed in the future. It can be
    # done by supporting per-sequence proposal lens.
    #
    # First samples are from speculative scoring, latter samples are non-
    # speculative samples.
    split_sizes = [
        num_scoring_tokens,
        sampler_output.sampled_token_ids.numel() - num_scoring_tokens
    ]
    (spec_probs, non_spec_probs
     ) = sampler_output.sampled_token_probs.split(split_sizes)
    (spec_sampled_tokens, non_spec_sampled_tokens
     ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
    # Convert scores to tensors.
    # The helper reads the fields of sampler_output, so the speculative slices
    # are installed on it before the call.
    sampler_output.sampled_token_probs = spec_probs
    sampler_output.sampled_token_ids = spec_sampled_tokens
    target_token_ids, target_probs = sampler_output_to_torch(
        [sampler_output])
    # Convert non-speculative output tokens to tensors.
    # Same object, now carrying the non-speculative slices.
    sampler_output.sampled_token_probs = non_spec_probs
    sampler_output.sampled_token_ids = non_spec_sampled_tokens
    non_spec_target_token_ids, non_spec_target_probs = sampler_output_to_torch(
        [sampler_output])
    return target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs
def _create_target_seq_id_iterator(
        self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
    """Return an endless iterator of fresh target sequence ids.

    Target sequence ids are distinct from input sequence ids because each
    proposal token to be scored gets its own target sequence. The ids count
    upwards from one past the largest id in ``seq_ids``, so none of them
    collides with an input id. ``seq_ids`` must not be empty.
    """
    first_free_id = max(seq_ids) + 1
    return count(start=first_free_id)
def _get_token_ids_to_score(
    self,
    full_spec_token_ids: List[TokenId]  # shape: [k]
) -> List[List[TokenId]]:
    """Return every prefix of the proposal that has to be scored.

    For k proposal tokens this yields k+1 entries: the empty prefix first
    (used for generating the bonus token), then the prefixes of length
    1 through k.

    Example:
        Input: [0, 1, 2, 3] (k=4)
        Output: (k+1 lists)
            []
            [0]
            [0, 1]
            [0, 1, 2]
            [0, 1, 2, 3]
    """
    prefixes = [[]]
    for prefix_len in range(1, len(full_spec_token_ids) + 1):
        prefixes.append(full_spec_token_ids[:prefix_len])
    return prefixes
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass
from abc import ABC, abstractmethod
import torch
from vllm.sequence import SequenceGroupMetadata
@dataclass
class SpeculativeProposals:
    """Proposal tokens produced by some proposer, together with the number
    of speculative tokens each sequence has.
    """

    # Speculative proposal tokens.
    proposal_token_ids: torch.Tensor

    # Probabilities of the proposal tokens according to the proposer.
    proposal_probs: torch.Tensor

    # The valid length of each proposal; can be zero.
    proposal_lens: torch.Tensor

    def __repr__(self):
        # Show tensor shapes only; the contents are too large to be useful.
        shown = ("proposal_token_ids", "proposal_probs", "proposal_lens")
        shapes = ", ".join(f"{name}={getattr(self, name).shape}"
                           for name in shown)
        return f"SpeculativeProposals({shapes})"
@dataclass
class SpeculativeScores:
    """Scores of speculative tokens according to the scoring model."""

    # Probabilities of the speculative tokens according to the scoring model.
    probs: torch.Tensor

    # Token ids sampled from the scoring model. Used for speculative bonus
    # tokens and also non-speculative normal decoding.
    token_ids: torch.Tensor

    def __repr__(self):
        # Show tensor shapes only; the contents are too large to be useful.
        shapes = ", ".join(f"{name}={getattr(self, name).shape}"
                           for name in ("probs", "token_ids"))
        return f"SpeculativeScores({shapes})"
class SpeculativeProposer(ABC):
    """Interface for components that produce speculative proposals."""

    @abstractmethod
    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Produce proposals for the given batch.

        Subclasses must override this method; the base implementation always
        raises NotImplementedError.
        """
        raise NotImplementedError
class SpeculativeScorer(ABC):
    """Interface for components that score speculative proposals."""

    @abstractmethod
    def score_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        k: int,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score ``proposals`` for the given batch.

        Args:
            seq_group_metadata_list: The input sequence group metadata.
            blocks_to_swap_in: Cache blocks to swap in, if any.
            blocks_to_swap_out: Cache blocks to swap out, if any.
            blocks_to_copy: Cache blocks to copy, if any.
            k: The proposal length.
            proposals: The speculative proposals to score.

        Returns:
            The scores wrapped in a SpeculativeScores object, which is what
            the batch-expansion scorer in this code base returns.

        Subclasses must override this method; the base implementation always
        raises NotImplementedError.
        """
        raise NotImplementedError
import torch
from dataclasses import dataclass
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from typing import Optional
from vllm.utils import in_wsl
import time
from typing import Callable
@dataclass
class SpecDecodeWorkerMetrics:
    """Dataclass holding metrics emitted from the spec decode worker.

    The field order below is also the positional order of the generated
    constructor.
    """
    # The empirical acceptance rate of the proposal method on a per-token basis
    # (accepted tokens divided by draft tokens; NaN when there are no draft
    # tokens). This is useful for evaluating how well the proposal method
    # aligns with the scoring method.
    draft_acceptance_rate: float
    # The empirical efficiency, measured as the number of tokens emitted by the
    # system divided by the number of tokens that could be emitted by the system
    # if the proposal method were perfect (NaN when that number is zero).
    system_efficiency: float
    # The number of speculative tokens produced by the proposal method.
    draft_tokens: int
    # The number of tokens emitted by the entire system.
    emitted_tokens: int
    # The number of tokens accepted by the scoring model and verification
    # routine, e.g. Llama2-70B and lossless rejection sampling.
    #
    # NOTE: Any token accepted by the verification routine is considered
    # accepted (regardless of whether the speculative prefix is also accepted).
    # The user will usually see fewer accepted tokens. This metric is helpful
    # when evaluating alignment of the proposal method with the scoring model.
    accepted_tokens: int
    # The number of speculative tokens per sequence.
    num_spec_tokens: int
# A zero-argument callable returning the current time as a float. The metrics
# collector falls back to time.time when no timer is supplied.
Timer = Callable[[], float]
class AsyncMetricsCollector:
    """Class which copies rejection sampler metrics from the device to CPU on a
    non-default Torch stream.

    Collection takes two calls to ``maybe_collect_rejsample_metrics``: one
    starts the asynchronous copy, the next one waits for it and returns the
    metrics.
    """

    def __init__(self,
                 rejection_sampler: RejectionSampler,
                 timer: Optional[Timer] = None,
                 collect_interval_s: float = 5.0):
        """
        Args:
            rejection_sampler: Source of the counters that are copied.
            timer: Clock used to pace collection; defaults to ``time.time``.
            collect_interval_s: Minimum time between two collections, in the
                units of ``timer``.
        """
        self._rejection_sampler = rejection_sampler
        self._timer = time.time if timer is None else timer

        self._rank: Optional[int] = None

        # We don't have a device set yet.
        self._copy_stream: Optional[torch.cuda.Stream] = None

        self._in_flight_copy: Optional[torch.cuda.Event] = None

        # CPU-side destinations of the asynchronous copies. Pinned memory is
        # not used under WSL.
        pin_memory = not in_wsl()
        self._aggregate_num_accepted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_emitted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_draft_tokens = 0

        self._rejsample_metrics_collect_interval_s = collect_interval_s
        self._last_metrics_collect_time = self._timer()

    def init_gpu_tensors(self, rank: int) -> None:
        """Record the worker rank and create the stream used for copies."""
        self._rank = rank
        self._copy_stream = torch.cuda.Stream()

    def maybe_collect_rejsample_metrics(
            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
        """Return metrics if a copy is in flight, otherwise maybe start one.

        Args:
            k: The number of speculative tokens per sequence.

        Returns:
            The collected metrics when a copy was started by a previous call,
            else None.
        """
        # If a copy was initiated in the previous call, collect and return.
        if self._in_flight_copy is not None:
            ready_event = self._in_flight_copy
            self._in_flight_copy = None
            return self._collect_rejsample_metrics(k, ready_event)

        # Otherwise, check if we should start a new copy.
        if self._should_collect_rejsample_metrics(self._timer()):
            assert self._in_flight_copy is None
            self._in_flight_copy = self._copy_rejsample_metrics_async()

        return None

    def _should_collect_rejsample_metrics(self, now: float) -> bool:
        """Return whether or not this iteration should print rejection sampling
        metrics.

        Only rank 0 collects, and only once the collect interval has elapsed
        since the last collection.
        """
        if self._rank != 0:
            return False

        if (now - self._last_metrics_collect_time <
                self._rejsample_metrics_collect_interval_s):
            return False
        return True

    def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
        """Copy rejection sampling metrics (number of accepted tokens, etc) to
        CPU asynchronously.

        Returns a CUDA event recording when the copy is complete.
        """
        # Order the copy after the work already queued on the current stream.
        self._copy_stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self._copy_stream):
            self._aggregate_num_accepted_tokens.copy_(
                self._rejection_sampler.num_accepted_tokens, non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                self._rejection_sampler.num_emitted_tokens, non_blocking=True)
            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = (
                self._rejection_sampler.num_draft_tokens)

        aggregate_metrics_ready = torch.cuda.Event()
        aggregate_metrics_ready.record(self._copy_stream)

        return aggregate_metrics_ready

    def _collect_rejsample_metrics(
            self, k: int,
            ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
        """Create metrics object from statistics copied asynchronously.

        Args:
            k: int. The number of speculative tokens; used to determine system
                efficiency.
            ready_event: torch.cuda.Event. The CUDA event recording when the
                async GPU->CPU copy is complete.
        """
        ready_event.synchronize()

        # Restart the interval. Without this the interval check stays
        # satisfied forever after it first elapses, and a new copy would be
        # started on every other call.
        self._last_metrics_collect_time = self._timer()

        accepted_tokens = self._aggregate_num_accepted_tokens.item()
        emitted_tokens = self._aggregate_num_emitted_tokens.item()
        draft_tokens = self._aggregate_num_draft_tokens

        num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)

        if draft_tokens > 0:
            draft_acceptance_rate = accepted_tokens / draft_tokens
        else:
            draft_acceptance_rate = float("nan")

        if num_possible_tokens > 0:
            system_efficiency = emitted_tokens / num_possible_tokens
        else:
            system_efficiency = float("nan")

        return SpecDecodeWorkerMetrics(
            num_spec_tokens=k,
            draft_acceptance_rate=draft_acceptance_rate,
            system_efficiency=system_efficiency,
            accepted_tokens=accepted_tokens,
            draft_tokens=draft_tokens,
            emitted_tokens=emitted_tokens,
        )

    @staticmethod
    def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
        """Return how many tokens could be emitted if every draft token were
        accepted.

        Each speculated sequence can emit its k draft tokens plus one bonus
        token, so the result is (number of speculated sequences) * (k + 1).
        Returns 0 when ``k`` is not positive.
        """
        if k <= 0:
            return 0

        # Divide by k since batch size can be variable.
        total_num_spec_seqs = draft_tokens // k
        num_accepted_per_seq_if_all_accepted = k + 1
        return total_num_spec_seqs * num_accepted_per_seq_if_all_accepted
from typing import List, Dict from typing import List, Dict, Optional, Tuple
import copy import copy
import torch import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
from vllm.spec_decode.util import sampler_output_to_torch
class MultiStepWorker(Worker): class MultiStepWorker(Worker):
...@@ -19,6 +21,21 @@ class MultiStepWorker(Worker): ...@@ -19,6 +21,21 @@ class MultiStepWorker(Worker):
requires more thought for MultiStepWorker support. requires more thought for MultiStepWorker support.
""" """
def __init__(self, *args, **kwargs):
    """Initialize the worker; all arguments are forwarded unchanged to the
    parent class initializer.
    """
    super().__init__(*args, **kwargs)
    # Stays None until init_model() creates the proposer.
    self._proposer: Optional[DraftModelTop1Proposer] = None
def init_model(self):
    """Initialize the model via the parent class, then create the proposer
    that wraps this worker.
    """
    super().init_model()
    # The proposer is built after the parent initialization and is handed
    # this worker plus its device, max model length and vocab size.
    self._proposer = DraftModelTop1Proposer(
        self,
        self.device,
        self.max_model_len,
        self.vocab_size,
    )
@torch.inference_mode() @torch.inference_mode()
def execute_model_multi_step( def execute_model_multi_step(
self, self,
...@@ -58,6 +75,26 @@ class MultiStepWorker(Worker): ...@@ -58,6 +75,26 @@ class MultiStepWorker(Worker):
return model_outputs return model_outputs
def get_spec_proposals(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Dict[int, int],
    blocks_to_swap_out: Dict[int, int],
    blocks_to_copy: Dict[int, List[int]],
    max_proposal_len: int,
) -> SpeculativeProposals:
    """Produce speculations for an input batch of sequences.

    The call is delegated unchanged to the proposer held by this worker;
    the number of speculative tokens per sequence is determined by
    ``max_proposal_len``.
    """
    proposer = self._proposer
    proposals = proposer.get_proposals(
        seq_group_metadata_list,
        blocks_to_swap_in,
        blocks_to_swap_out,
        blocks_to_copy,
        max_proposal_len,
    )
    return proposals
def _append_new_tokens( def _append_new_tokens(
self, model_output: SamplerOutput, self, model_output: SamplerOutput,
seq_group_metadata_list: SequenceGroupMetadata) -> None: seq_group_metadata_list: SequenceGroupMetadata) -> None:
...@@ -85,21 +122,9 @@ class MultiStepWorker(Worker): ...@@ -85,21 +122,9 @@ class MultiStepWorker(Worker):
"""Copy input data structures to remove side-effects when input data """Copy input data structures to remove side-effects when input data
structures are shared with other modules. structures are shared with other modules.
The multi-step worker must be able to append tokens to sequences after Helpful when the vLLM scheduler runs in the same process as the worker.
a forward pass. This necessitates modification of the data structures The alternative is deep-copying (or other form of deep copy); this has
used by the worker. Since these data structures are shared with other performance downsides.
parts of vLLM, like the scheduler, we must take care not to introduce
unexpected side-effects.
When Ray is used to orchestrate worker processes (such as when the
tensor-parallel degree is >1), this is not a problem because the input
datastructures will be serialized and created anew in the worker
process.
However, when Ray is not used to orchestrate the worker processes (such
as when the tensor-parallel degree is 1), this is a problem. We avoid
the problem by shallow-copying the input datastructures (specifically,
the parts that will change in multiple steps).
""" """
# Shallow-copy the list of SequenceGroupMetadata. This allows us to # Shallow-copy the list of SequenceGroupMetadata. This allows us to
...@@ -176,3 +201,166 @@ class MultiStepWorker(Worker): ...@@ -176,3 +201,166 @@ class MultiStepWorker(Worker):
for seq_group_metadata in seq_group_metadata_list): for seq_group_metadata in seq_group_metadata_list):
raise NotImplementedError( raise NotImplementedError(
"MultiStepWorker does not support beam search.") "MultiStepWorker does not support beam search.")
class DraftModelTop1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
        self,
        draft_worker: MultiStepWorker,
        device: str,
        max_model_len: int,
        vocab_size: int,
    ):
        """
        Args:
            draft_worker: Worker used to run the multi-step draft forward pass.
            device: Device on which the output tensors are created.
            max_model_len: Length limit used to decide which sequences may be
                speculated upon.
            vocab_size: Size of the last dimension of the probability tensor.
        """
        self._draft_worker = draft_worker
        self._device = device
        self._max_model_len = max_model_len
        self._vocab_size = vocab_size

    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        # Split speculative- and non-speculative- sequences.
        (proposal_lens, nonzero_proposal_len_seqs,
         nonzero_proposal_len_indices) = self._split_by_max_model_len(
             seq_group_metadata_list, max_proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            maybe_sampler_output = self._draft_worker.execute_model_multi_step(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                num_steps=max_proposal_len,
            )
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        (proposal_tokens, proposal_probs,
         proposal_lens_tensor) = self._merge_outputs(
             batch_size=len(seq_group_metadata_list),
             max_proposal_len=max_proposal_len,
             maybe_sampler_output=maybe_sampler_output,
             proposal_lens=proposal_lens,
             nonzero_proposal_len_indices=nonzero_proposal_len_indices,
         )

        return SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens_tensor,
        )

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        max_proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length.

        Returns:
            A tuple of: the proposal length of every input sequence (either 0
            or ``max_proposal_len``), the sequences that will be speculated
            upon, and the batch indices of those sequences.
        """
        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            # Only the first sequence of each group is considered.
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal len
            # are supported.
            if seq_len + max_proposal_len < self._max_model_len:
                proposal_lens.append(max_proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return (proposal_lens, nonzero_proposal_len_seqs,
                nonzero_proposal_len_indices)

    def _merge_outputs(
        self,
        batch_size: int,
        max_proposal_len: int,
        maybe_sampler_output: Optional[SamplerOutput],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.

        Returns:
            ``(proposal_tokens, proposal_probs, proposal_lens)`` as tensors on
            the proposer's device. When nothing was speculated, the token and
            probability tensors have zero rows and every proposal length is 0.
            Otherwise all three have one row per input sequence; skipped
            sequences get token id -1, probability 0 and proposal length 0.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty tensors.
            empty_tokens = torch.zeros(0,
                                       max_proposal_len,
                                       dtype=torch.long,
                                       device=self._device)
            empty_probs = torch.zeros(0,
                                      max_proposal_len,
                                      self._vocab_size,
                                      dtype=torch.float32,
                                      device=self._device)
            zero_lens = torch.zeros(len(proposal_lens),
                                    dtype=torch.long,
                                    device=self._device)
            return empty_tokens, empty_probs, zero_lens

        proposal_tokens, proposal_probs = sampler_output_to_torch(
            maybe_sampler_output)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
        entire_proposal_tokens = torch.full(size=(batch_size,
                                                  *proposal_tokens.shape[1:]),
                                            fill_value=-1,
                                            dtype=torch.long,
                                            device=self._device)
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens

        entire_proposal_probs = torch.zeros(batch_size,
                                            *proposal_probs.shape[1:],
                                            dtype=torch.float32,
                                            device=self._device)
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len

        return (entire_proposal_tokens, entire_proposal_probs,
                proposal_lens_tensor)
from typing import List, Tuple, Optional, Dict
from functools import cached_property
import torch
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
from vllm.worker.worker import Worker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.config import CacheConfig
from vllm.spec_decode.util import nvtx_range, get_all_seq_ids, split_batch_by_proposal_len
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeScorer
class SpecDecodeWorker:
"""Worker which implements speculative decoding.
Speculative decoding reduces decoding per-token latency by using a proposal
method, such as a small draft model, to speculate ahead of a larger LLM. The
probabilities of the speculative tokens are then determined by the larger
LLM, after which some verification routine determines which (if any) of the
speculative tokens are accepted by the larger LLM.
See https://github.com/vllm-project/vllm/pull/2188 and
https://github.com/vllm-project/vllm/pull/3103 for more info.
The current implementation has the following limitations:
* Only draft-model proposal is implemented (contributions for more forms are
welcome!).
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
future work.
* Only lossless rejection sampling is supported. Contributions adding lossy
verification routines are welcome (e.g. Medusa's typical acceptance).
* All sequences in a batch must have the same proposal length, or zero. This
can be improved by having per-sequence speculation in the future.
* The scoring forward pass is done without an MQA kernel, which is
suboptimal especially as the batch size, proposal length, and sequence
lengths grow. Contributions to add a MQA scoring are welcome once
correctness tests pass.
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
"""
def __init__(
    self,
    proposer_worker: MultiStepWorker,
    scorer_worker: Worker,
    rejection_sampler: RejectionSampler,
    metrics_collector: Optional[AsyncMetricsCollector] = None,
):
    """Create a SpecDecodeWorker.

    Args:
        proposer_worker: A worker that can produce speculative tokens for
            sequences.
        scorer_worker: A worker that produces probabilities of speculative
            tokens according to some base model. Typically a vanilla vLLM
            Worker.
        rejection_sampler: A Torch module used to perform modified rejection
            sampling for speculative decoding.
        metrics_collector: Helper class for collecting metrics; can be set
            for testing purposes. When omitted, an AsyncMetricsCollector
            wrapping ``rejection_sampler`` is created.
    """
    self.proposer_worker = proposer_worker
    self.scorer_worker = scorer_worker
    self.rejection_sampler = rejection_sampler

    if metrics_collector is None:
        self._metrics = AsyncMetricsCollector(rejection_sampler)
    else:
        self._metrics = metrics_collector

    # Dtypes are dictated by the rejection sampler.
    sampler = self.rejection_sampler
    self.probs_dtype = sampler.probs_dtype
    self.token_id_dtype = sampler.token_id_dtype

    # Assigned in init_model().
    self.scorer: SpeculativeScorer = None
def init_model(self) -> None:
    """Initialize both scorer and proposer models, then the GPU-side state of
    the metrics collector and rejection sampler, and finally the scorer.
    """
    # The scorer worker model is initialized first in case the proposer
    # model has a smaller TP degree than the target worker.
    for worker in (self.scorer_worker, self.proposer_worker):
        worker.init_model()

    for component in (self._metrics, self.rejection_sampler):
        component.init_gpu_tensors(self.rank)

    scorer = BatchExpansionTop1Scorer(scorer_worker=self.scorer_worker,
                                      device=self.device,
                                      vocab_size=self._vocab_size)
    self.scorer = scorer
def profile_num_available_blocks(self, block_size: int,
                                 gpu_memory_utilization: float,
                                 cpu_swap_space: int,
                                 cache_dtype: str) -> Tuple[int, int]:
    """Determine the number of cache blocks to use.

    The scorer model (typically the larger of the two) is profiled first.
    The GPU block count it reports is then reduced via
    ``split_num_cache_blocks_evenly`` so that the memory is shared between
    the proposer and scorer KV caches with an equal number of blocks in
    each. The CPU block count of the scorer is returned unchanged.

    Returns:
        ``(num_gpu_blocks, num_cpu_blocks)``.
    """
    profiled = self.scorer_worker.profile_num_available_blocks(
        block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype)
    scorer_gpu_blocks, cpu_blocks = profiled

    scorer_block_bytes = self.scorer_worker.get_cache_block_size_bytes(
        block_size, cache_dtype)
    proposer_block_bytes = self.proposer_worker.get_cache_block_size_bytes(
        block_size, cache_dtype)

    shared_gpu_blocks = split_num_cache_blocks_evenly(scorer_block_bytes,
                                                      proposer_block_bytes,
                                                      scorer_gpu_blocks)
    return shared_gpu_blocks, cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig):
    """Initialize the cache engines of both workers with the same config.

    The scorer worker is initialized before the proposer worker.
    """
    for worker in (self.scorer_worker, self.proposer_worker):
        worker.init_cache_engine(cache_config)
@torch.inference_mode()
def execute_model(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Optional[Dict[int, int]],
    blocks_to_swap_out: Optional[Dict[int, int]],
    blocks_to_copy: Optional[Dict[int, List[int]]],
    num_spec_tokens: int,
) -> List[SamplerOutput]:
    """Perform speculative decoding on the input batch.

    When ``num_spec_tokens`` is zero or the batch is empty, the batch is run
    without speculation (this path is used for prefill). Otherwise one
    speculative decoding step with ``k = num_spec_tokens`` is executed.
    """
    assert seq_group_metadata_list is not None, (
        "speculative decoding requires non-None seq_group_metadata_list")

    # Arguments shared by both execution paths.
    batch_kwargs = dict(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
    )

    # If no spec tokens, call the proposer and scorer workers normally.
    # Used for prefill.
    skip_speculation = (num_spec_tokens == 0
                        or len(seq_group_metadata_list) == 0)
    if skip_speculation:
        return self._run_no_spec(**batch_kwargs)

    return self._run_speculative_decoding_step(k=num_spec_tokens,
                                               **batch_kwargs)
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Optional[Dict[int, int]],
    blocks_to_swap_out: Optional[Dict[int, int]],
    blocks_to_copy: Optional[Dict[int, List[int]]],
) -> List[SamplerOutput]:
    """Run a prefill step, without any speculation. The input is sent to the
    proposer and scorer model so that the KV cache is consistent between the
    two.

    Returns:
        A single-element list holding the scorer's sampler output, with its
        tensor fields cleared.
    """
    # The proposer output is discarded; the call is made only for its effect
    # on the proposer's KV cache.
    self.proposer_worker.execute_model(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
        return_python_output=False)

    sampler_output = self.scorer_worker.execute_model(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
    )

    # Clear device tensors from sampler output. This reduces communication
    # overhead when the engine runs in a different process than the workers.
    # The tensors are stored under sampled_token_probs / sampled_token_ids
    # (the names the scorer reads); assigning to any other attribute would
    # leave them attached to the output.
    sampler_output.sampled_token_probs = None
    sampler_output.sampled_token_ids = None
    return [sampler_output]
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Optional[Dict[int, int]],
    blocks_to_swap_out: Optional[Dict[int, int]],
    blocks_to_copy: Optional[Dict[int, List[int]]],
    k: int,
) -> List[SamplerOutput]:
    """Execute a single step of speculative decoding.

    The step has four stages: the proposer worker produces k speculative
    tokens per sequence, the scorer scores them, the tokens are verified,
    and the accepted tokens are turned into the output.

    Returns a list of SamplerOutput, each containing a single token per
    sequence.
    """
    # Stage 1: generate proposals using the draft worker.
    draft = self.proposer_worker.get_spec_proposals(seq_group_metadata_list,
                                                    blocks_to_swap_in,
                                                    blocks_to_swap_out,
                                                    blocks_to_copy, k)

    # Stage 2: score the proposals.
    scores = self.scorer.score_proposals(
        seq_group_metadata_list,
        blocks_to_swap_in,
        blocks_to_swap_out,
        blocks_to_copy,
        k,
        draft,
    )

    # Stage 3: decide which proposal tokens are accepted.
    accepted = self._verify_tokens(seq_group_metadata_list, scores, draft, k)

    # Stage 4: package the accepted tokens as sampler outputs.
    return self._create_output_sampler_list(seq_group_metadata_list, accepted,
                                            k)
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_scores: SpeculativeScores,
    proposals: SpeculativeProposals,
    max_proposal_len: int,
) -> torch.Tensor:
    """Decide which speculative tokens are accepted.

    Acceptance is delegated to ``self.rejection_sampler``, which receives the
    scorer's probabilities and bonus tokens together with the proposer's
    probabilities and token ids. Sequences whose proposal length is zero
    bypass the sampler: they keep the first scored token and every later
    position is set to -1.

    The returned tensor has one row per sequence, in the same order as
    seq_group_metadata_list.
    """
    lens_per_seq = proposals.proposal_lens.tolist()

    def _indices_where(proposal_len_is_zero: bool) -> List[int]:
        _, indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            lens_per_seq,
            select_proposal_len_zero=proposal_len_is_zero)
        return indices

    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len, so the batch is split into speculative and
    # non-speculative rows. This can be removed once per-sequence proposal
    # lens are supported.
    spec_rows = _indices_where(False)
    plain_rows = _indices_where(True)
    row_order = spec_rows + plain_rows

    scorer_probs = proposal_scores.probs[spec_rows, :-1]
    bonus_token_ids = proposal_scores.token_ids[spec_rows, -1:]
    plain_token_ids = proposal_scores.token_ids[plain_rows]

    accepted = self.rejection_sampler(
        scorer_probs,
        bonus_token_ids,
        proposals.proposal_probs,
        proposals.proposal_token_ids,
    )

    # Non-speculative rows contribute only their first token; pad the rest
    # of each row with -1 so it lines up with the speculative rows.
    plain_token_ids = plain_token_ids.expand(-1,
                                             max_proposal_len + 1).clone()
    plain_token_ids[:, 1:] = -1

    combined = torch.cat([accepted, plain_token_ids])

    # Scatter rows back to their positions in the original batch order. The
    # clone keeps the right-hand side stable while writing in place.
    combined[row_order] = combined.clone()
    return combined
def _create_output_sampler_list(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
    k: int,
) -> List[SamplerOutput]:
    """Convert accepted token ids into one SamplerOutput per decoding step.

    Every emitted step holds one entry per sequence; sequences that have no
    token at that step carry -1, so all sequences have the same number of
    outputs. Steps stop being emitted at the first step where every token
    is -1. If rejection-sampling metrics are available they are attached to
    the first SamplerOutput.
    """
    seq_ids = get_all_seq_ids(seq_group_metadata_list)

    # shape: [k+1, batch_size]
    tokens_per_step = accepted_token_ids.transpose(0, 1).tolist()

    sampler_output_list = []
    for step_tokens in tokens_per_step:
        # Nothing but padding at this step: no sequence produced a token.
        if not any(token_id != -1 for token_id in step_tokens):
            break

        step_outputs = [
            SequenceGroupOutput(
                samples=[
                    SequenceOutput(
                        parent_seq_id=seq_id,
                        output_token=token_id,
                        # TODO Add verifier logprobs.
                        logprobs={token_id: 0.0},
                    )
                ],
                prompt_logprobs=None,
            ) for token_id, seq_id in zip(step_tokens, seq_ids)
        ]
        sampler_output_list.append(SamplerOutput(outputs=step_outputs))

    rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(k)
    if rejsample_metrics is not None:
        first_output = sampler_output_list[0]
        first_output.spec_decode_worker_metrics = rejsample_metrics

    return sampler_output_list
@cached_property
def _vocab_size(self) -> int:
    """Return the vocab size shared by the draft and target workers.

    Asserts that the proposer worker and the scorer worker report the same
    vocab size.
    """
    proposer_vocab_size = self.proposer_worker.vocab_size
    scorer_vocab_size = self.scorer_worker.vocab_size
    assert proposer_vocab_size == scorer_vocab_size
    return proposer_vocab_size
@property
def rank(self):
    """Rank of this worker, as reported by the scorer worker."""
    scorer = self.scorer_worker
    return scorer.rank
@property
def device(self):
    """Device of this worker, as reported by the scorer worker."""
    scorer = self.scorer_worker
    return scorer.device
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Compute how many KV cache blocks each model gets when sharing a GPU.

    total_num_gpu_blocks is the number of GPU blocks that could be allocated
    to the target (scorer) model if it were alone. The draft (proposer) and
    target models allocate the same number of blocks, but a block usually
    has a different size in bytes for each model, since that size is a
    function of the number of KV/layer, the number of heads, and the hidden
    dimension size.

    The result is the block count at which, if allocated by both models, the
    total KV cache memory is no larger than what the target model alone
    could have allocated. The quotient is truncated to an int.
    """
    # Memory budget: what the scorer alone would use for its KV cache.
    scorer_only_bytes = total_num_gpu_blocks * scorer_cache_block_size_bytes

    # Cost of giving one block to each of the two models.
    bytes_per_block_pair = (proposer_cache_block_size_bytes +
                            scorer_cache_block_size_bytes)

    return int(scorer_only_bytes / bytes_per_block_pair)
import torch
from typing import List, Tuple
from vllm.sequence import SequenceGroupMetadata, SamplerOutput
from contextlib import contextmanager
from itertools import chain
SeqId = int
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
    """Collect the sequence ids of every sequence group into one flat list.

    Ids appear in the order of the groups in seq_group_metadata_list, and
    within a group in the iteration order of its ``seq_data`` keys.
    """
    all_seq_ids = []
    for seq_group_metadata in seq_group_metadata_list:
        all_seq_ids.extend(seq_group_metadata.seq_data.keys())
    return all_seq_ids
def split_batch_by_proposal_len(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_lens: List[int], select_proposal_len_zero: bool
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
    """Select the part of a batch whose proposal len is (or is not) zero.

    When select_proposal_len_zero is true, the entries with a proposal len
    of zero are selected; otherwise the entries with a non-zero proposal len
    are selected. Returns the selected sequence groups and their indices in
    the input batch, both in batch order.

    This should be removed once vLLM supports per-sequence proposal lens in
    a batch.
    """
    selected_seq_groups = []
    selected_indices = []

    paired = zip(seq_group_metadata_list, proposal_lens)
    for index, (seq_group, proposal_len) in enumerate(paired):
        if select_proposal_len_zero:
            is_selected = proposal_len == 0
        else:
            is_selected = proposal_len != 0

        if is_selected:
            selected_seq_groups.append(seq_group)
            selected_indices.append(index)

    return selected_seq_groups, selected_indices
def sampler_output_to_torch(
    sampler_output_list: List[SamplerOutput],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Stack the tensors of a list of SamplerOutput into batch-major tensors.

    Each output's ``sampled_token_probs`` and flattened ``sampled_token_ids``
    are stacked along a new leading dimension, which is then swapped with
    the batch dimension.

    Returns:
        sampled_token_ids: torch.Tensor
            shape: [batch_size, len(sampler_output_list)]
        sampled_token_probs: torch.Tensor
            shape: [batch_size, len(sampler_output_list), vocab_size]
    """
    # shape: [batch_size, num_sampler_output, vocab_size]
    probs_per_output = [
        output.sampled_token_probs for output in sampler_output_list
    ]
    sampled_token_probs = torch.stack(probs_per_output, dim=0)
    sampled_token_probs = sampled_token_probs.transpose(0, 1)

    # shape: [batch_size, num_sampler_output]
    ids_per_output = [
        output.sampled_token_ids.flatten() for output in sampler_output_list
    ]
    sampled_token_ids = torch.stack(ids_per_output, dim=0)
    sampled_token_ids = sampled_token_ids.transpose(0, 1)

    return sampled_token_ids, sampled_token_probs
@contextmanager
def nvtx_range(msg, *args, **kwargs):
    """
    Context manager / decorator that opens an NVTX range when its scope is
    entered and closes it when the scope is left, including when the scope
    exits with an exception.

    If running with cuda graphs, you must enable nsys cuda graph profiling.

    Arguments:
        msg (string): message to associate with the range. Any extra
            positional or keyword arguments are passed to msg.format().
    """
    range_label = msg.format(*args, **kwargs)
    torch.cuda.nvtx.range_push(range_label)
    try:
        yield
    finally:
        # Always pop so pushes and pops stay balanced.
        torch.cuda.nvtx.range_pop()
...@@ -97,8 +97,6 @@ class ModelRunner: ...@@ -97,8 +97,6 @@ class ModelRunner:
f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB" f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB"
) )
vocab_size = self.model.config.vocab_size
if self.lora_config: if self.lora_config:
assert hasattr( assert hasattr(
self.model, "supported_lora_modules" self.model, "supported_lora_modules"
...@@ -111,7 +109,7 @@ class ModelRunner: ...@@ -111,7 +109,7 @@ class ModelRunner:
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_num_batched_tokens +
self.scheduler_config.max_paddings, vocab_size, self.scheduler_config.max_paddings, self.vocab_size,
self.lora_config, self.device, self.model.embedding_modules, self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules) self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
...@@ -607,8 +605,7 @@ class ModelRunner: ...@@ -607,8 +605,7 @@ class ModelRunner:
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage. # Enable top-k sampling to reflect the accurate memory usage.
vocab_size = self.model_config.get_vocab_size() sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs max_num_seqs = self.scheduler_config.max_num_seqs
...@@ -774,6 +771,10 @@ class ModelRunner: ...@@ -774,6 +771,10 @@ class ModelRunner:
self.graph_runners.clear() self.graph_runners.clear()
self.cupy_nccl_backend = None self.cupy_nccl_backend = None
@property
def vocab_size(self) -> int:
    """Vocab size obtained from the model config."""
    model_config = self.model_config
    return model_config.get_vocab_size()
class CUDAGraphRunner: class CUDAGraphRunner:
......
...@@ -130,8 +130,8 @@ class Worker: ...@@ -130,8 +130,8 @@ class Worker:
# GPU did not change their memory usage during the profiling. # GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory peak_memory = self.init_gpu_memory - free_gpu_memory
cache_block_size = CacheEngine.get_cache_block_size( cache_block_size = self.get_cache_block_size_bytes(
block_size, cache_dtype, self.model_config, self.parallel_config) block_size, cache_dtype)
num_gpu_blocks = int( num_gpu_blocks = int(
(total_gpu_memory * gpu_memory_utilization - peak_memory) // (total_gpu_memory * gpu_memory_utilization - peak_memory) //
cache_block_size) cache_block_size)
...@@ -232,6 +232,22 @@ class Worker: ...@@ -232,6 +232,22 @@ class Worker:
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
return self.model_runner.list_loras() return self.model_runner.list_loras()
@property
def max_model_len(self) -> int:
    """The ``max_model_len`` value of this worker's model config."""
    model_config = self.model_config
    return model_config.max_model_len
@property
def vocab_size(self) -> int:
    """Vocab size as reported by this worker's model runner."""
    model_runner = self.model_runner
    return model_runner.vocab_size
def get_cache_block_size_bytes(self, block_size: int,
                               cache_dtype: str) -> int:
    """Get the size of one KV cache block in bytes.

    Delegates to CacheEngine.get_cache_block_size, passing this worker's
    model and parallel configs along with the given block size and cache
    dtype.
    """
    cache_block_size_bytes = CacheEngine.get_cache_block_size(
        block_size,
        cache_dtype,
        self.model_config,
        self.parallel_config,
    )
    return cache_block_size_bytes
def init_distributed_environment( def init_distributed_environment(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment