Unverified Commit 947b7941 authored by Zhuohan Li, committed by GitHub

[Sampler] Vectorized sampling (simplified) (#1048)


Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
parent 8d926e91
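This commit replaces the per-sequence-group Python loop in the sampler with batched sampling: sequence groups are bucketed by SamplingType (GREEDY, RANDOM, BEAM), hidden states are reordered so each bucket is contiguous, and each bucket is then sampled with a single tensor operation. Below is a minimal, self-contained sketch of that dispatch pattern; the names and data layout are illustrative stand-ins, not the vLLM API.

from enum import IntEnum
from typing import Dict, List

import torch


class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    BEAM = 2


def batched_dispatch(
    sampling_types: List[SamplingType],  # one entry per sequence
    logprobs: torch.Tensor,  # [num_seqs, vocab], rows pre-grouped by type
) -> Dict[int, int]:
    """Sample every sequence, doing one batched op per sampling type."""
    # Bucket sequence indices by sampling type (mirrors the categorization
    # done by _prune_hidden_states / _sample in this commit).
    buckets: Dict[SamplingType, List[int]] = {t: [] for t in SamplingType}
    for i, t in enumerate(sampling_types):
        buckets[t].append(i)

    next_tokens: Dict[int, int] = {}
    start = 0
    for t in SamplingType:
        idxs = buckets[t]
        if not idxs:
            continue
        chunk = logprobs[start:start + len(idxs)]
        if t == SamplingType.GREEDY:
            tokens = torch.argmax(chunk, dim=-1)
        else:
            # RANDOM and BEAM are more involved in the real sampler; a
            # single multinomial draw stands in for both here.
            tokens = torch.multinomial(torch.softmax(chunk, dim=-1), 1)[:, 0]
        for i, tok in zip(idxs, tokens.tolist()):
            next_tokens[i] = tok
        start += len(idxs)
    return next_tokens

Greedy and random buckets map directly onto batched argmax and multinomial calls; beam search stays loop-based in the real sampler, as the comment in _beam_search_sample below notes.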
tests/samplers/test_sampler.py (new file)

import pytest
import random
from typing import Tuple
from unittest.mock import patch

import torch

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MockLogitsSampler(Sampler):

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
                   lambda x, y: x):
            with patch("vllm.model_executor.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
                return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(32000, fake_logits)
    worker = Worker(None, None, None)
    worker.block_size = 16
    return input_tensor, fake_logits, sampler, worker


RANDOM_SEEDS = list(range(128))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0, ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler(embedding=None,
            hidden_states=input_tensor,
            input_metadata=input_metadata)
    # No assertion here: it is unclear how to verify beam-search outputs
    # deterministically, so this test only checks that the sampler runs
    # without raising when every request uses beam search.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output:
            assert nth_output.output_token in expected_tokens
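These tests mock out _get_logits and _prune_hidden_states so the sampler sees fixed logits, but they still allocate CUDA tensors, so a GPU is required. A quick way to run a single parametrized case from Python; the path is an assumption about where this new test file lives in the checkout:

# Invoke pytest programmatically; equivalent to running it from the CLI.
import pytest

ret = pytest.main(
    ["tests/samplers/test_sampler.py", "-k", "test_sampler_all_greedy", "-x"])
print("pytest exit code:", ret)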
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Optional, Tuple
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
gather_from_tensor_model_parallel_region) gather_from_tensor_model_parallel_region)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceOutputs from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
...@@ -44,12 +43,8 @@ class Sampler(nn.Module): ...@@ -44,12 +43,8 @@ class Sampler(nn.Module):
hidden_states = _prune_hidden_states(hidden_states, input_metadata) hidden_states = _prune_hidden_states(hidden_states, input_metadata)
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t()) logits = _get_logits(hidden_states, embedding, embedding_bias,
if embedding_bias is not None: self.vocab_size)
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :self.vocab_size]
# Apply presence and frequency penalties. # Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata) output_tokens = _get_output_tokens(input_metadata)
...@@ -59,7 +54,7 @@ class Sampler(nn.Module): ...@@ -59,7 +54,7 @@ class Sampler(nn.Module):
assert len(presence_penalties) == logits.shape[0] assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0] assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(logits, output_tokens, presence_penalties, logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties, self.vocab_size) frequency_penalties)
# Apply temperature scaling. # Apply temperature scaling.
temperatures = _get_temperatures(input_metadata) temperatures = _get_temperatures(input_metadata)
...@@ -90,19 +85,47 @@ class Sampler(nn.Module): ...@@ -90,19 +85,47 @@ class Sampler(nn.Module):
return _sample(probs, logprobs, input_metadata) return _sample(probs, logprobs, input_metadata)
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor],
vocab_size: int) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :vocab_size]
return logits
def _prune_hidden_states( def _prune_hidden_states(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
last_token_indices = {t: [] for t in SamplingType}
start_idx = 0 start_idx = 0
last_token_indicies: List[int] = [] for i, seq_group in enumerate(input_metadata.seq_groups):
for prompt_len in input_metadata.prompt_lens: seq_ids, sampling_params = seq_group
last_token_indicies.append(start_idx + prompt_len - 1) sampling_type = sampling_params.sampling_type
start_idx += prompt_len if i < input_metadata.num_prompts:
last_token_indicies.extend( assert len(seq_ids) == 1, "Prompt input should have only one seq."
range(start_idx, start_idx + input_metadata.num_generation_tokens)) prompt_len = input_metadata.prompt_lens[i]
return hidden_states.index_select( last_token_indices[sampling_type].append(start_idx + prompt_len -
0, torch.tensor(last_token_indicies, device=hidden_states.device)) 1)
start_idx += prompt_len
else:
num_seqs = len(seq_ids)
last_token_indices[sampling_type].extend(
range(start_idx, start_idx + num_seqs))
start_idx += num_seqs
all_last_token_indices = []
for sampling_type in SamplingType:
all_last_token_indices.extend(last_token_indices[sampling_type])
all_last_token_indices = torch.tensor(all_last_token_indices,
dtype=torch.long,
device=hidden_states.device)
return hidden_states.index_select(0, all_last_token_indices)
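To make the index bookkeeping in _prune_hidden_states concrete, here is a standalone toy example with made-up shapes: two greedy prompts of lengths 3 and 2 followed by one random-sampling generation group with two running sequences. The gathered rows come out grouped by sampling type, which is what lets _sample slice one contiguous block per category.

import torch

# Hand-built example: hidden_states has one row per input token.
# Two prompt groups (greedy, lengths 3 and 2) and one generation group
# (random sampling, 2 running sequences) -> 3 + 2 + 2 = 7 rows.
hidden_states = torch.arange(7, dtype=torch.float32).unsqueeze(1)  # [7, 1]

prompt_lens = [3, 2]          # greedy prompts
num_generation_seqs = 2       # random-sampling generation group

greedy_indices, random_indices = [], []
start_idx = 0
for plen in prompt_lens:
    greedy_indices.append(start_idx + plen - 1)  # last token of each prompt
    start_idx += plen
random_indices.extend(range(start_idx, start_idx + num_generation_seqs))

# Concatenate per-type indices in SamplingType order (GREEDY, RANDOM, BEAM).
all_indices = torch.tensor(greedy_indices + random_indices)
print(all_indices.tolist())                               # [2, 4, 5, 6]
print(hidden_states.index_select(0, all_indices)[:, 0])   # tensor([2., 4., 5., 6.])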
def _get_penalties(

@@ -149,11 +172,8 @@ def _apply_penalties(
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    for i in range(num_seqs):
        if not output_tokens[i]:
            continue

@@ -161,33 +181,40 @@ def _apply_penalties(
        f = frequency_penalties[i]
        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
            continue
        break
    else:
        # Return early if all sequences have zero penalties.
        return logits

    max_output_len = max(len(tokens) for tokens in output_tokens)
    padded_output_tokens = [
        tokens + [vocab_size] * (max_output_len - len(tokens))
        for tokens in output_tokens
    ]
    output_tokens_tensor = torch.tensor(padded_output_tokens,
                                        dtype=torch.long,
                                        device=logits.device)

    # Compute the bin counts for the output tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=logits.device)
    bin_counts.scatter_add_(1, output_tokens_tensor,
                            torch.ones_like(output_tokens_tensor))
    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.

    frequency_penalties = torch.tensor(frequency_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
    presence_penalties = torch.tensor(presence_penalties,
                                      dtype=logits.dtype,
                                      device=logits.device)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
    return logits
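The bin-count trick above replaces the old per-sequence numpy.bincount loop: every token list is padded with a sentinel id equal to vocab_size, counted into vocab_size + 1 bins with scatter_add_, and the padding bin is dropped. A tiny standalone reproduction with a made-up vocabulary and token lists:

import torch

vocab_size = 5
output_tokens = [[0, 0, 3], [4]]  # made-up previously generated tokens
max_len = max(len(t) for t in output_tokens)

# Pad every row with the sentinel id `vocab_size` so rows have equal length.
padded = torch.tensor(
    [t + [vocab_size] * (max_len - len(t)) for t in output_tokens])

# One extra bin collects the padding sentinel; scatter_add_ counts per row.
bin_counts = torch.zeros((len(output_tokens), vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, padded, torch.ones_like(padded))
bin_counts = bin_counts[:, :vocab_size]  # drop the padding bin

print(bin_counts)
# tensor([[2, 0, 0, 1, 0],
#         [0, 0, 0, 0, 1]])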
@@ -268,95 +295,154 @@ def _apply_top_p_top_k(
def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: Optional[int],
) -> List[Dict[int, float]]:
    num_seqs = logprobs.size(0)
    if num_logprobs is None or num_logprobs == 0:
        return [{} for _ in range(num_seqs)]

    all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)
    all_topk_logprobs = all_topk_logprobs.cpu()
    all_topk_ids = all_topk_ids.cpu()
    all_token_to_logprob = []
    for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
        token_to_logprob: Dict[int, float] = {}
        for token_id, logprob in zip(topk_ids, topk_logprobs):
            token_to_logprob[token_id.item()] = logprob.item()
        all_token_to_logprob.append(token_to_logprob)
    return all_token_to_logprob


def _build_sequence_outputs(
    parent_ids: List[int],
    next_token_ids: List[int],
    selected_token_logprobs: torch.Tensor,
    parent_seq_ids: List[int],
    parent_logprobs: torch.Tensor,
    num_output_logprobs: Optional[int],
) -> List[SequenceOutputs]:
    # Get top-k log probabilities for the next tokens.
    next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)

    seq_outputs: List[SequenceOutputs] = []
    for parent_id, next_token_id, token_logprob in zip(
            parent_ids, next_token_ids, selected_token_logprobs):
        output_logprobs = next_logprobs[parent_id].copy()
        output_logprobs[next_token_id] = token_logprob
        seq_outputs.append(
            SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
                            output_logprobs))
    return seq_outputs


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    samples = torch.argmax(logprobs, dim=-1).cpu()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx].item()]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    probs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    max_best_of = 1
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        if is_prompt:
            seq_ids, sampling_params = seq_group
            max_best_of = max(max_best_of, sampling_params.best_of)
    random_samples = torch.multinomial(probs,
                                       num_samples=max_best_of,
                                       replacement=True).cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == probs.size(0)
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # Note: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results
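In the generation phase of _beam_search_sample, candidates are picked jointly across all parent sequences by flattening the [num_parents, vocab_size] logprobs and decoding each flat index back into a (parent, token) pair with integer division and modulo. A standalone illustration with a small made-up tensor:

import torch

vocab_size = 4
# Made-up per-parent logprobs for 2 parent sequences (already including
# each parent's cumulative logprob, as in _beam_search_sample above).
seq_group_logprobs = torch.tensor([[-1.0, -3.0, -0.5, -4.0],
                                   [-0.2, -2.5, -5.0, -0.9]])

beam_width = 2
_, topk_ids = torch.topk(seq_group_logprobs.flatten(), 2 * beam_width)
topk_ids = topk_ids.tolist()

parent_ids = [i // vocab_size for i in topk_ids]      # which parent sequence
next_token_ids = [i % vocab_size for i in topk_ids]   # which token in vocab
print(parent_ids, next_token_ids)
# [1, 0, 1, 0] [0, 2, 3, 0]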
@@ -364,65 +450,80 @@ def _sample(
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> SamplerOutput:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    category_num_tokens = {t: 0 for t in SamplingType}
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)
        num_seqs = len(seq_ids)
        category_num_tokens[sampling_type] += num_seqs

    seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
    category_start_idx = 0
    for sampling_type in SamplingType:
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
        num_tokens = category_num_tokens[sampling_type]
        if num_tokens == 0:
            continue
        category_logprobs = logprobs[category_start_idx:category_start_idx +
                                     num_tokens]
        category_probs = probs[category_start_idx:category_start_idx +
                               num_tokens]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, category_logprobs)
        elif sampling_type == SamplingType.RANDOM:
            sample_results = _random_sample(seq_groups, is_prompts,
                                            category_probs)
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 input_metadata.seq_data,
                                                 category_logprobs)
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

        # Batched query for logprobs of selected token
        batched_logprobs_query_seq_indices: List[int] = []
        batched_logprobs_query_token_indices: List[int] = []
        sample_idx = 0
        for seq_group_id, seq_group, sample_result in zip(
                seq_group_ids, seq_groups, sample_results):
            seq_ids, sampling_params = seq_group
            next_token_ids, parent_ids = sample_result
            num_parent_seqs = len(seq_ids)
            batched_logprobs_query_seq_indices.extend(
                [sample_idx + parent_id for parent_id in parent_ids])
            batched_logprobs_query_token_indices.extend(next_token_ids)
            sample_idx += num_parent_seqs
        assert sample_idx == num_tokens
        batched_logprobs_query_result = category_logprobs[[
            batched_logprobs_query_seq_indices,
            batched_logprobs_query_token_indices
        ]].tolist()

        # Build the sequence outputs.
        sample_idx = 0
        result_idx = 0
        for seq_group_id, seq_group, sample_result in zip(
                seq_group_ids, seq_groups, sample_results):
            seq_ids, sampling_params = seq_group
            next_token_ids, parent_ids = sample_result
            num_results = len(next_token_ids)
            num_parent_seqs = len(seq_ids)
            parent_logprobs = category_logprobs[sample_idx:sample_idx +
                                                num_parent_seqs]
            selected_token_logprobs = batched_logprobs_query_result[
                result_idx:result_idx + num_results]
            seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
                                                 selected_token_logprobs,
                                                 seq_ids, parent_logprobs,
                                                 sampling_params.logprobs)
            seq_outputs_dict[seq_group_id] = seq_output
            sample_idx += num_parent_seqs
            result_idx += num_results
        assert sample_idx == num_tokens
        category_start_idx += num_tokens

    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
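The batched logprob query in _sample relies on advanced indexing with two parallel index lists: one gather returns the logprob of every sampled token at once instead of one .item() call per token. A small standalone example of the indexing pattern, written with an explicit index tuple and made-up values:

import torch

# Made-up [num_seqs, vocab_size] logprobs for one sampling category.
category_logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)

# One (sequence row, sampled token) pair per produced sample; here the
# last sequence produced two samples (e.g. best_of=2 on a prompt).
seq_indices = torch.tensor([0, 1, 2, 2])
token_indices = torch.tensor([4, 0, 1, 3])

# Advanced indexing gathers one scalar per (row, column) pair in a single op.
selected = category_logprobs[seq_indices, token_indices]
print(selected.shape)     # torch.Size([4])
print(selected.tolist())  # four Python floats, one per sampled token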
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
from enum import IntEnum
from functools import cached_property
from typing import List, Optional, Union from typing import List, Optional, Union
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
class SamplingType(IntEnum):
GREEDY = 0
RANDOM = 1
BEAM = 2
class SamplingParams: class SamplingParams:
"""Sampling parameters for text generation. """Sampling parameters for text generation.
...@@ -166,6 +174,14 @@ class SamplingParams: ...@@ -166,6 +174,14 @@ class SamplingParams:
if self.top_k != -1: if self.top_k != -1:
raise ValueError("top_k must be -1 when using greedy sampling.") raise ValueError("top_k must be -1 when using greedy sampling.")
@cached_property
def sampling_type(self) -> SamplingType:
if self.use_beam_search:
return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
return SamplingType.RANDOM
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SamplingParams(n={self.n}, " return (f"SamplingParams(n={self.n}, "
f"best_of={self.best_of}, " f"best_of={self.best_of}, "
......
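The new cached_property maps a request onto one of the three buckets used throughout the sampler. The constructions below mirror the ones in the new tests and assume a vLLM checkout that includes this change:

from vllm.sampling_params import SamplingParams, SamplingType

assert SamplingParams(temperature=0).sampling_type == SamplingType.GREEDY
assert SamplingParams(temperature=1.0, n=4).sampling_type == SamplingType.RANDOM
assert SamplingParams(temperature=0, best_of=2,
                      use_beam_search=True).sampling_type == SamplingType.BEAM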