Unverified Commit 7d2dcce1 authored by Nick Hill, committed by GitHub

Support per-request seed (#2514)

parent dc903e70
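For orientation, here is a minimal usage sketch of what this commit adds: passing the same seed in SamplingParams makes random sampling reproducible for that request. The model name and prompt are placeholders, and this assumes the offline LLM entrypoint used in the tests below.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
prompts = ["Hello, my name is"]

# Same seed, same sampled tokens; omitting seed keeps the old non-deterministic behavior.
seeded = SamplingParams(temperature=1.0, seed=42, max_tokens=16)

first = llm.generate(prompts, seeded)
second = llm.generate(prompts, seeded)
assert [o.outputs[0].token_ids for o in first] == \
       [o.outputs[0].token_ids for o in second]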
import random
from typing import Tuple
from typing import Tuple, List
from unittest.mock import patch
import pytest
import torch
from transformers import GenerationConfig, GenerationMixin
from typing import Optional
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
@@ -46,15 +47,13 @@ CUDA_DEVICES = [
]
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
def _do_sample(
batch_size: int,
input_tensor: torch.Tensor,
sampler: MockLogitsSampler,
model_runner: ModelRunner,
sampling_params: SamplingParams,
):
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
@@ -63,7 +62,7 @@ def test_sampler_all_greedy(seed: int, device: str):
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0, ),
sampling_params=sampling_params,
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
@@ -71,9 +70,23 @@ def test_sampler_all_greedy(seed: int, device: str):
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
return sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@@ -94,28 +107,40 @@ def test_sampler_all_random(seed: int, device: str):
for i in range(batch_size):
fake_logits[i, i] = 1e2
seq_group_metadata_list = []
prompt_lens = []
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
fake_logits[i, i] = 1e2
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i
@@ -123,6 +148,31 @@ def test_sampler_all_random(seed: int, device: str):
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
assert first_sampler_output == second_sampler_output
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
@@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, input_tensor, sampler, model_runner,
sampling_params)
# no assertion here since it is unclear how to verify beam search
# outputs - this just checks that the sampler runs without raising
# exceptions
@@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
batch_size)
seq_group_metadata_list = []
expected_tokens = []
expected_tokens: List[Optional[List[int]]] = []
prompt_lens = []
for i in range(batch_size):
n = 1
sampling_type = random.randint(0, 2)
expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
elif sampling_type == 1:
expected = [torch.argmax(fake_logits[i], dim=-1).item()]
elif sampling_type in (1, 2):
n = random.randint(1, 10)
sampling_params = SamplingParams(
temperature=random.random() + 0.1,
@@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
n=n,
presence_penalty=random.randint(0, 1),
)
if sampling_type == 2:
sampling_params.seed = random.randint(0, 10000)
else:
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected = list(range(i, i + n))
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected_tokens.append(i + idx)
expected_tokens.append(expected)
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
@@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for i, sequence_output in enumerate(sampler_output):
if seq_group_metadata_list[i].sampling_params.use_beam_search:
continue
for nth_output in sequence_output.samples:
assert nth_output.output_token in expected_tokens
def test_sampling(model_runner: ModelRunner):
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for i, (sequence_output, metadata) in enumerate(
zip(sampler_output, seq_group_metadata_list)):
if metadata.sampling_params.use_beam_search:
continue
if metadata.sampling_params.seed is not None \
and expected_tokens[i] is None:
# Record the seeded random result to compare against the second invocation
expected_tokens[i] = [
nth_output.output_token
for nth_output in sequence_output.samples
]
continue
for n, nth_output in enumerate(sequence_output.samples):
if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
# Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens[i][n]
else:
# For non-seeded random sampling, check that one of the
# high-logit tokens was chosen
assert nth_output.output_token in expected_tokens[i]
# Test batch
test_sampling(model_runner)
# Shuffle the batch and resample
target_index = list(range(batch_size))
for list_to_shuffle in (target_index, seq_group_metadata_list,
expected_tokens, prompt_lens):
random.Random(seed).shuffle(list_to_shuffle)
target_index = torch.tensor(target_index)
input_tensor.data = input_tensor.index_select(0, target_index)
fake_logits.data = fake_logits.index_select(0, target_index)
# This time, results of seeded random samples will be compared with the corresponding
# sample in the pre-shuffled batch
test_sampling(model_runner)
del model_runner
......
"""Verify that seeded random sampling is deterministic.
Run `pytest tests/samplers/test_seeded_generate.py --forked`.
"""
import copy
import random
from itertools import combinations
import pytest
from vllm.model_executor.utils import set_random_seed
from vllm import SamplingParams
MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(5))
@pytest.fixture
def vllm_model(vllm_runner):
vllm_model = vllm_runner(MODEL, dtype="half")
yield vllm_model
del vllm_model
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
vllm_model,
example_prompts,
seed: int,
) -> None:
set_random_seed(seed)
sampling_params = SamplingParams(
# Parameters to ensure sufficient randomness
temperature=2.0,
top_p=min(random.random() + 0.3, 1),
top_k=random.randint(5, 20),
n=random.randint(1, 10),
presence_penalty=random.randint(0, 1),
max_tokens=8,
ignore_eos=True,
)
sampling_params_seed_1 = copy.deepcopy(sampling_params)
sampling_params_seed_1.seed = 100
sampling_params_seed_2 = copy.deepcopy(sampling_params)
sampling_params_seed_2.seed = 200
llm = vllm_model.model
for prompt in example_prompts:
for params in (
sampling_params,
sampling_params_seed_1,
sampling_params_seed_2,
sampling_params,
sampling_params_seed_1,
sampling_params_seed_2,
):
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=params,
)
results = llm._run_engine(use_tqdm=False)
all_outputs = [[out.token_ids for out in output.outputs]
for output in results]
for i in range(0, len(example_prompts), 6):
outputs = all_outputs[i:i + 6]
# verify outputs with different (or no) seeds all differ from each other
for output_a, output_b in combinations(
(outputs[0], outputs[1], outputs[2], outputs[3]),
2,
):
assert output_a != output_b
# verify requests with the same seed match
assert outputs[1] == outputs[4]
assert outputs[2] == outputs[5]
@@ -387,6 +387,7 @@ class Scheduler:
block_tables=block_tables,
lora_request=seq_group.lora_request,
prefix=seq_group.prefix,
state=seq_group.state,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs
......
@@ -173,7 +173,6 @@ class EngineArgs:
default=EngineArgs.block_size,
choices=[8, 16, 32],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
......
@@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
@@ -90,6 +91,7 @@ class ChatCompletionRequest(BaseModel):
temperature=self.temperature,
top_p=self.top_p,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
@@ -117,6 +119,7 @@ class CompletionRequest(BaseModel):
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
seed: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
@@ -147,6 +150,7 @@ class CompletionRequest(BaseModel):
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
ignore_eos=self.ignore_eos,
......
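The protocol additions above expose the seed through the OpenAI-compatible server. A hedged request sketch follows; the host, port, and model name are assumptions for a locally started server.

import requests

# Assumes a server started with something like:
#   python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "The capital of France is",
        "temperature": 1.0,
        "max_tokens": 8,
        "seed": 42,  # repeating the request with the same seed repeats the sample
    },
)
print(resp.json()["choices"][0]["text"])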
@@ -342,7 +342,9 @@ def _beam_search_sample(
def _multinomial(
probs: torch.Tensor,
num_samples: int,
):
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
@@ -352,7 +354,15 @@ def _multinomial(
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs).exponential_(1)
q = torch.empty_like(probs)
if seq_groups is None:
q.exponential_()
else:
sample_idx = 0
for (seq_ids, _), generator in zip(seq_groups, generators):
next_sample_idx = sample_idx + len(seq_ids) * num_samples
q[sample_idx:next_sample_idx].exponential_(generator=generator)
sample_idx = next_sample_idx
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
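For intuition about the change above: dividing the probabilities by Exponential(1) noise and taking the argmax is equivalent to a multinomial draw (index i wins the "exponential race" with probability probs[i]), and threading a per-request torch.Generator through exponential_() is what makes that draw reproducible. A standalone sketch, not the vLLM code path:

import torch

probs = torch.tensor([0.1, 0.6, 0.3])
gen = torch.Generator().manual_seed(1234)

# argmax(probs / q) with q ~ Exp(1) samples index i with probability probs[i].
q = torch.empty_like(probs).exponential_(generator=gen)
sample = torch.argmax(probs / q)

# Re-seeding the generator reproduces the same draw, which is exactly what
# the per-request generator buys us.
gen.manual_seed(1234)
q2 = torch.empty_like(probs).exponential_(generator=gen)
assert torch.argmax(probs / q2) == sample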
@@ -370,6 +380,7 @@ def _sample(
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
sample_metadata = {}
multinomial_samples = {}
# Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
@@ -385,14 +396,18 @@
is_prompts, sample_indices)
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
elif sampling_type == SamplingType.RANDOM:
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts):
if is_prompt:
_, sampling_params = seq_group
max_best_of = max(max_best_of, sampling_params.best_of)
multinomial_samples = _multinomial(probs[sample_indices],
max_best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups,
"generators": sampling_metadata.generators,
}
multinomial_samples[sampling_type] = _multinomial(
probs[sample_indices], max_best_of, **seeded_args)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
@@ -407,9 +422,9 @@
sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type == SamplingType.RANDOM:
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups, is_prompts,
multinomial_samples)
multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts,
sampling_metadata.seq_data,
......
@@ -19,6 +19,7 @@ class SamplingMetadata:
prompt_lens: Lengths of prompts.
selected_token_indices: Token indices selected for sampling.
categorized_sample_indices: SamplingType -> token indices to sample.
generators: List of torch.Generator instances to use for seeded sampling.
perform_sampling: Whether to perform sampling. This option is used to
make the sampling only happens in the driver worker, and disable
sampling in other worker processes.
@@ -31,6 +32,7 @@ class SamplingMetadata:
prompt_lens: Optional[List[int]],
selected_token_indices: torch.Tensor,
categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
generators: Optional[List[torch.Generator]] = None,
perform_sampling: bool = True,
) -> None:
self.seq_groups = seq_groups
@@ -38,6 +40,7 @@ class SamplingMetadata:
self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.generators = generators
self.perform_sampling = perform_sampling
self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
......
@@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
class SamplingType(IntEnum):
GREEDY = 0
RANDOM = 1
BEAM = 2
RANDOM_SEED = 2
BEAM = 3
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
@@ -56,6 +57,7 @@ class SamplingParams:
min_p: Float that represents the minimum probability for a token to be
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
@@ -101,6 +103,7 @@ class SamplingParams:
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
@@ -124,6 +127,7 @@ class SamplingParams:
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
self.seed = seed
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
@@ -229,6 +233,8 @@ class SamplingParams:
return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
if self.seed is not None:
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
def __repr__(self) -> str:
@@ -242,6 +248,7 @@
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
f"min_p={self.min_p}, "
f"seed={self.seed}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
......
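To make the precedence above concrete (beam search, then greedy, then seeded random): a seed combined with temperature=0 still resolves to GREEDY, so the seed only changes behavior for non-greedy sampling. A small sketch, assuming the property shown above is exposed as SamplingParams.sampling_type:

from vllm import SamplingParams
from vllm.sampling_params import SamplingType

assert SamplingParams(temperature=0.0, seed=42).sampling_type == SamplingType.GREEDY
assert SamplingParams(temperature=1.0, seed=42).sampling_type == SamplingType.RANDOM_SEED
assert SamplingParams(temperature=1.0).sampling_type == SamplingType.RANDOM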
@@ -248,6 +248,14 @@ class Sequence:
f"num_blocks={len(self.logical_token_blocks)})")
@dataclass
class SequenceGroupState:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
generator: Optional = None
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
@@ -280,6 +288,7 @@ class SequenceGroup:
self.lora_request = lora_request
self.prefix: Optional[Prefix] = prefix
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
@property
def prompt(self) -> str:
@@ -397,6 +406,7 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
state: Internal state tied to this sequence group.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
"""
@@ -410,6 +420,7 @@
block_tables: Dict[int, List[int]],
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
state: Optional[SequenceGroupState] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
@@ -418,6 +429,7 @@
self.block_tables = block_tables
self.lora_request = lora_request
self.prefix = prefix
self.state = SequenceGroupState() if state is None else state
@property
def lora_int_id(self) -> int:
......
@@ -389,6 +389,7 @@ class ModelRunner:
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
@@ -419,6 +420,10 @@
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += max_subquery_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device="cuda").manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
@@ -432,6 +437,9 @@
categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = _async_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
@@ -454,6 +462,7 @@
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
@@ -536,6 +545,7 @@
prompt_lens=None,
selected_token_indices=metadata_dict["selected_token_indices"],
categorized_sample_indices=None,
generators=None,
perform_sampling=False,
)
......
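One detail worth noting from the hunks above: the torch.Generator is created once, when the seeded request is first scheduled as a prompt, and is persisted on SequenceGroupState, which the scheduler forwards with every SequenceGroupMetadata. Later decode steps therefore continue the same random stream instead of restarting it. Below is a small CPU-only sketch of why that persistence matters; draw() is a hypothetical stand-in for one step's seeded multinomial draw.

import torch

def draw(gen: torch.Generator) -> int:
    # stand-in for one decode step's exponential-race sample
    probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
    q = torch.empty_like(probs).exponential_(generator=gen)
    return int(torch.argmax(probs / q))

# Reusing one seeded generator across steps continues a single stream ...
gen = torch.Generator().manual_seed(7)
stream = [draw(gen) for _ in range(4)]

# ... whereas re-seeding a fresh generator every step would just repeat
# the first draw.
repeated = [draw(torch.Generator().manual_seed(7)) for _ in range(4)]
assert len(set(repeated)) == 1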