Unverified commit d4201e06, authored by Thomas Parnell and committed by GitHub
Browse files

[Bugfix] Make spec. decode respect per-request seed. (#6034)


Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
parent b5672a11
......@@ -150,9 +150,54 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
generators = [None] * batch_size
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
draft_token_ids, generators)
@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                   frac_seeded: float, n_rep: int,
                                   device: str):
    """Check that seeded batch rows give identical output on every call.

    A random subset of the batch (roughly ``frac_seeded`` of it) receives a
    freshly seeded generator on each repetition, using the row index as the
    seed; the remaining rows get ``None``. The sampler is invoked ``n_rep``
    times on the same inputs and, for every seeded row, each repetition must
    match the first one exactly. Unseeded rows are not checked.
    """
    torch.set_default_device(device)
    rejection_sampler = RejectionSampler()
    rejection_sampler.init_gpu_tensors(rank=0)

    # Inputs are drawn once and shared by every repetition.
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
    # Pull the mask to the host once instead of indexing the tensor per row.
    is_seeded = seeded_mask.tolist()

    def fresh_generators():
        # Re-seed from scratch so each repetition starts from the same state.
        return [
            torch.Generator(device=device).manual_seed(row) if seeded else None
            for row, seeded in enumerate(is_seeded)
        ]

    results = [
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids, fresh_generators())
        for _ in range(n_rep)
    ]

    reference = results[0]
    for row, seeded in enumerate(is_seeded):
        if not seeded:
            continue
        for repeat in results[1:]:
            assert torch.equal(repeat[row], reference[row])
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
......@@ -197,10 +242,11 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
raise AssertionError()
oob_token_ids[0][0] = rogue_token_id
generators = [None] * batch_size
with pytest.raises(AssertionError):
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
draft_token_ids, generators)
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
......@@ -371,11 +417,15 @@ class _CorrectnessTestHelper:
dtype=torch.int64,
device="cuda").repeat(num_samples, 1)
# unseeded
generators = [None]
# Get output tokens via rejection sampling.
output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
bonus_token_ids.to("cuda"),
draft_probs.to("cuda"),
draft_token_ids.to("cuda"))
draft_token_ids.to("cuda"),
generators)
# Remove bonus tokens
output_token_ids = output_token_ids[:, :-1].flatten()
......
import asyncio
from itertools import cycle
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union
import pytest
import ray
......@@ -128,7 +128,9 @@ class AsyncLLM:
try:
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
res = asyncio.run(get_output(prompt, sampling_params))
params = sampling_params[i] if isinstance(
sampling_params, Sequence) else sampling_params
res = asyncio.run(get_output(prompt, params))
outputs.append(res)
finally:
ray.shutdown()
......@@ -267,7 +269,31 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
the same when temperature is zero.
"""
temperature = 0.0
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len,
temperature=0.0,
seeded=False,
print_tokens=print_tokens,
ensure_all_accepted=ensure_all_accepted)
def run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts that the outputs are exactly the same, which is
expected when temperature is zero (or when temperature is > 0 and seeded).
"""
prompts = [
"Hello, my name is",
......@@ -286,6 +312,16 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
# sampling params to ignore eos token.
ignore_eos = force_output_len
if seeded:
sampling_params = [
SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
seed=i,
) for i in range(len(prompts))
]
else:
sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
......
import pytest
from .conftest import run_equality_correctness_test
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Eager mode avoids cuda graph recording, keeping the test fast.
        "enforce_eager": True,

        # Speculative decoding needs the v2 block manager.
        "use_v2_block_manager": True,

        # Draft (speculative) model.
        "speculative_model": "JackFram/llama-160m",

        # Number of tokens speculated per step.
        "num_speculative_tokens": 3,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1, 8, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output length keeps the test fast.
        10,
    ])
@pytest.mark.parametrize("seed", [1])
def test_seeded_consistency(baseline_llm_generator, batch_size: int,
                            temperature: float, output_len: int):
    """Check that seeded sampling gives the same outputs on repeated runs.

    The same generator fixture is passed as both the baseline and the test
    side of the comparison, so the helper compares two runs of one
    configuration with per-request seeds at a non-zero temperature.
    """
    run_equality_correctness_test(
        baseline_llm_generator=baseline_llm_generator,
        test_llm_generator=baseline_llm_generator,
        batch_size=batch_size,
        max_output_len=output_len,
        force_output_len=True,
        temperature=temperature,
        seeded=True,
    )
from functools import cached_property
from typing import Tuple
from typing import List, Optional, Tuple
import torch
import torch.jit
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
SpecDecodeStochasticBaseSampler)
class RejectionSampler(SpecDecodeBaseSampler):
class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
......@@ -36,6 +36,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
......@@ -82,6 +83,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
target_probs,
draft_probs,
draft_token_ids,
generators,
))
output_token_ids = self._create_output(
......@@ -98,6 +100,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
......@@ -114,15 +117,25 @@ class RejectionSampler(SpecDecodeBaseSampler):
# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids)
draft_token_ids, generators)
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators, k=k)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(recovered_probs,
num_samples=1).reshape(
batch_size, k)
recovered_token_ids = _multinomial(
recovered_probs,
num_samples=1,
k=k,
generators=generators,
seed_indices=seed_indices,
# this arg is unused when None but torch.jit requires a list
non_seed_indices=non_seed_indices or [],
).reshape(batch_size, k)
return accepted, recovered_token_ids
def _get_accepted(
......@@ -130,6 +143,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
......@@ -164,10 +178,28 @@ class RejectionSampler(SpecDecodeBaseSampler):
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]
uniform_rand = torch.rand(batch_size,
seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators)
if len(seed_indices) == 0:
uniform_rand = torch.rand_like(selected_target_probs)
else:
uniform_rand = torch.empty_like(selected_target_probs)
for idx in seed_indices:
uniform_rand[idx, :] = torch.rand(1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generators[idx])
if non_seed_indices:
uniform_rand[non_seed_indices, :] = torch.rand(
len(non_seed_indices),
k,
dtype=self.probs_dtype,
device=target_probs.device)
capped_ratio = torch.minimum(
selected_target_probs / selected_draft_probs,
torch.full((1, ), 1, device=target_probs.device))
......@@ -240,6 +272,27 @@ class RejectionSampler(SpecDecodeBaseSampler):
"""
return torch.finfo(self.probs_dtype).tiny
# Partition the batch into indices for which a generator is provided
# and indices for which no generator is provided.
@staticmethod
def _split_batch_by_seeded(
generators: List[Optional[torch.Generator]],
k: int = 1,
) -> Tuple[List[int], Optional[List[int]]]:
if all(generator is None for generator in generators):
seed_indices: List[int] = []
non_seed_indices: Optional[List[int]] = None
else:
seed_indices, non_seed_indices = [], []
for i, generator in enumerate(generators):
if generator is None:
non_seed_indices.extend(range(k * i, k * (i + 1)))
else:
seed_indices.extend(range(k * i, k * (i + 1)))
return seed_indices, non_seed_indices
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
......@@ -250,12 +303,25 @@ class RejectionSampler(SpecDecodeBaseSampler):
def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
generators: List[Optional[torch.Generator]],
seed_indices: List[int],
non_seed_indices: List[int],
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs).exponential_(1.0)
q = torch.empty_like(probs)
if len(seed_indices) == 0:
q.exponential_(1.0)
else:
q[non_seed_indices].exponential_(1.0)
for idx in seed_indices:
q[idx].exponential_(1.0, generator=generators[idx // k])
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
from abc import abstractmethod
from typing import Optional
from typing import List, Optional
import torch
import torch.jit
......@@ -54,16 +54,6 @@ class SpecDecodeBaseSampler(nn.Module):
def token_id_dtype(self):
return torch.int64
@abstractmethod
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]
......@@ -217,3 +207,36 @@ class SpecDecodeBaseSampler(nn.Module):
assert torch.all(bonus_token_ids >= 0)
assert torch.all(draft_token_ids < vocab_size)
assert torch.all(draft_token_ids >= 0)
class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
    """Abstract base for deterministic speculative-decoding verifiers.

    Subclasses implement ``forward`` for the verification step; since no
    randomness is involved, the signature takes no random generators.
    """

    @abstractmethod
    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        # Interface only; concrete samplers must override this.
        raise NotImplementedError
class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
    """Abstract base for stochastic speculative-decoding verifiers.

    Subclasses implement ``forward`` for the verification step. Unlike the
    deterministic variant, the signature also takes ``generators``: a list
    of optional ``torch.Generator`` objects, where ``None`` marks an entry
    without its own generator.
    """

    @abstractmethod
    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        generators: List[Optional[torch.Generator]],
    ) -> torch.Tensor:
        # Interface only; concrete samplers must override this.
        raise NotImplementedError
......@@ -2,10 +2,10 @@ import torch
import torch.jit
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
SpecDecodeDeterministicBaseSampler)
class TypicalAcceptanceSampler(SpecDecodeBaseSampler):
class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
"""Apply typical acceptance sampling as described in section 3.3.1 in
"MEDUSA: Simple LLM Inference Acceleration Framework with
Multiple Decoding Heads"
......
......@@ -4,7 +4,8 @@ from typing import Iterator, List, Tuple
import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
SequenceGroupMetadata, SequenceGroupState,
get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
......@@ -292,6 +293,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1)
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generator = torch.Generator(
device=seq_group_metadata.state.generator.device)
generator.set_state(seq_group_metadata.state.generator.get_state())
state = SequenceGroupState(generator=generator)
else:
state = None
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
......@@ -302,6 +312,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
},
lora_request=None,
token_chunk_size=1,
state=state,
)
def _split_scoring_output(
......
......@@ -9,7 +9,7 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
......@@ -521,11 +521,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
# Sampler arguments
sampler_extra_kwargs = {}
if isinstance(self.spec_decode_sampler,
SpecDecodeStochasticBaseSampler):
# Get sequence group state
generators = []
for seq_group_metadata in seq_group_metadata_list:
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generators.append(seq_group_metadata.state.generator)
else:
generators.append(None)
sampler_extra_kwargs["generators"] = generators
accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids,
**sampler_extra_kwargs,
)
# Append output tokens from non-speculative sequences to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment