"torchvision/vscode:/vscode.git/clone" did not exist on "4b2ad55f1b11d70cf2b31a903fbb685fc9f79e6a"
Unverified commit a7347d9a, authored by Antoni Baum and committed by GitHub

Make sampler less blocking (#1889)

parent f8c688d7
vllm/model_executor/layers/sampler.py
@@ -6,13 +6,11 @@ import torch.nn as nn
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
 
-_SAMPLING_EPS = 1e-5
-
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
@@ -32,6 +30,7 @@ class Sampler(nn.Module):
     def __init__(self, vocab_size: int) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        self._copy_stream: torch.cuda.Stream = torch.cuda.Stream()
 
     def forward(
         self,
@@ -47,40 +46,38 @@ class Sampler(nn.Module):
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)
+        _, vocab_size = logits.shape
 
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, sampling_metadata)
+
+        # Prepare sampling tensors in another stream to overlap
+        # CPU<->GPU data transfer with GPU computation in forward pass.
+        with torch.cuda.stream(self._copy_stream):
+            (sampling_tensors, do_penalties, do_top_p_top_k,
+             do_min_p) = SamplingTensors.from_sampling_metadata(
+                 sampling_metadata, vocab_size, logits.device, logits.dtype)
+
+        torch.cuda.current_stream().wait_stream(self._copy_stream)
+
         # Apply presence and frequency penalties.
-        presence_penalties, frequency_penalties, repetition_penalties = (
-            _get_penalties(sampling_metadata))
-        assert len(presence_penalties) == logits.shape[0]
-        assert len(frequency_penalties) == logits.shape[0]
-        assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, sampling_metadata,
-                                  presence_penalties, frequency_penalties,
-                                  repetition_penalties)
+        if do_penalties:
+            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
+                                      sampling_tensors.output_tokens,
+                                      sampling_tensors.presence_penalties,
+                                      sampling_tensors.frequency_penalties,
+                                      sampling_tensors.repetition_penalties)
 
         # Apply temperature scaling.
-        temperatures = _get_temperatures(sampling_metadata)
-        assert len(temperatures) == logits.shape[0]
-        if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures,
-                             dtype=logits.dtype,
-                             device=logits.device)
-            # Use in-place division to avoid creating a new tensor.
-            logits.div_(t.unsqueeze(dim=1))
+        # Use in-place division to avoid creating a new tensor.
+        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
 
-        # Apply top-p and top-k truncation.
-        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            sampling_metadata, self.vocab_size)
-        assert len(top_ps) == len(top_ks) == logits.shape[0]
-        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
-        do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
+        if do_top_p_top_k:
+            logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
+                                        sampling_tensors.top_ks)
 
-        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
         if do_min_p:
-            logits = _apply_min_p(logits, min_ps)
+            logits = _apply_min_p(logits, sampling_tensors.min_ps)
 
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
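The new `forward` prepares all sampling tensors on a dedicated copy stream so the host-to-device transfer can overlap with GPU work already queued by the model's forward pass, and only synchronizes via `wait_stream` right before the tensors are consumed. A minimal, self-contained sketch of that pattern; the names `forward_step`, `params_cpu`, and `copy_stream` are illustrative and not part of the commit:

```python
import torch

def forward_step(logits: torch.Tensor, params_cpu: torch.Tensor,
                 copy_stream: torch.cuda.Stream) -> torch.Tensor:
    # Launch the host-to-device copy on a side stream so it can overlap with
    # kernels already queued on the current stream (same idea as the hunk above).
    with torch.cuda.stream(copy_stream):
        # A pinned source buffer makes the copy truly asynchronous.
        params_gpu = params_cpu.pin_memory().to(logits.device, non_blocking=True)
    # The consuming stream must wait for the copy before touching params_gpu.
    torch.cuda.current_stream().wait_stream(copy_stream)
    return logits / params_gpu.unsqueeze(1)

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    out = forward_step(torch.randn(4, 32000, device="cuda"),
                       torch.full((4,), 0.7), stream)
```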
@@ -120,32 +117,6 @@ def _prune_hidden_states(
         sampling_metadata.selected_token_indices)
 
 
-def _get_penalties(
-        sampling_metadata: SamplingMetadata
-) -> Tuple[List[float], List[float], List[float]]:
-    # Collect the presence and frequency penalties.
-    presence_penalties: List[float] = []
-    frequency_penalties: List[float] = []
-    repetition_penalties: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        p = sampling_params.presence_penalty
-        f = sampling_params.frequency_penalty
-        r = sampling_params.repetition_penalty
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            # NOTE: We do not apply presence and frequency penalties for the
-            # prompt token positions where we don't sample new tokens.
-            prompt_len = sampling_metadata.prompt_lens[i]
-            presence_penalties += [0] * (prompt_len - 1)
-            frequency_penalties += [0] * (prompt_len - 1)
-            repetition_penalties += [1] * (prompt_len - 1)
-        presence_penalties += [p] * len(seq_ids)
-        frequency_penalties += [f] * len(seq_ids)
-        repetition_penalties += [r] * len(seq_ids)
-    return presence_penalties, frequency_penalties, repetition_penalties
-
-
 def _get_prompt_and_output_tokens(
     sampling_metadata: SamplingMetadata,
 ) -> Tuple[List[List[int]], List[List[int]]]:
@@ -168,25 +139,16 @@ def _get_prompt_and_output_tokens(
 def _get_bin_counts_and_mask(
-    logits: torch.Tensor,
-    tokens: List[List[int]],
+    tokens: torch.Tensor,
     vocab_size: int,
     num_seqs: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    max_len = max(len(tokens) for tokens in tokens)
-    padded_tokens = [
-        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
-    ]
-    tokens_tensor = torch.tensor(padded_tokens,
-                                 dtype=torch.long,
-                                 device=logits.device)
-
     # Compute the bin counts for the tokens.
     # vocab_size + 1 for padding.
     bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                              dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+                             device=tokens.device)
+    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
     bin_counts = bin_counts[:, :vocab_size]
     mask = bin_counts > 0
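`_get_bin_counts_and_mask` now receives an already padded token tensor; padding ids are mapped to the out-of-range index `vocab_size`, counted in an extra bin, and sliced off afterwards. A toy illustration of that counting trick, with made-up token ids and a 5-token vocabulary:

```python
import torch

vocab_size = 5
# Two sequences, padded with the out-of-range id `vocab_size` (= 5).
tokens = torch.tensor([[1, 2, 2, 5],
                       [0, 4, 5, 5]])
bin_counts = torch.zeros((2, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]   # drop the padding bin
print(bin_counts)      # tensor([[0, 1, 2, 0, 0], [1, 0, 0, 0, 1]])
print(bin_counts > 0)  # mask of tokens that appeared at least once
```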
@@ -217,45 +179,16 @@ def _apply_logits_processors(
     return logits
 
 
-def _apply_penalties(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-    presence_penalties: List[float],
-    frequency_penalties: List[float],
-    repetition_penalties: List[float],
-) -> torch.Tensor:
+def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
+                     output_tokens_tensor: torch.Tensor,
+                     presence_penalties: torch.Tensor,
+                     frequency_penalties: torch.Tensor,
+                     repetition_penalties: torch.Tensor) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
-    for i in range(num_seqs):
-        p = presence_penalties[i]
-        f = frequency_penalties[i]
-        r = repetition_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
-                r - 1.0) < _SAMPLING_EPS:
-            continue
-        break
-    else:
-        # Return early if all sequences have zero penalties.
-        return logits
-
-    prompt_tokens, output_tokens = (
-        _get_prompt_and_output_tokens(sampling_metadata))
-    assert len(prompt_tokens) == logits.shape[0]
-    assert len(output_tokens) == logits.shape[0]
-
-    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
-        logits, prompt_tokens, vocab_size, num_seqs)
+    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
+                                              num_seqs)
     output_bin_counts, output_mask = _get_bin_counts_and_mask(
-        logits, output_tokens, vocab_size, num_seqs)
-
-    repetition_penalties = torch.tensor(repetition_penalties,
-                                        dtype=logits.dtype,
-                                        device=logits.device)
-    frequency_penalties = torch.tensor(frequency_penalties,
-                                       dtype=logits.dtype,
-                                       device=logits.device)
-    presence_penalties = torch.tensor(presence_penalties,
-                                      dtype=logits.dtype,
-                                      device=logits.device)
+        output_tokens_tensor, vocab_size, num_seqs)
 
     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
     repetition_penalties[~(prompt_mask | output_mask)] = 1.0
@@ -264,109 +197,65 @@ def _apply_penalties(
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
+    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
     return logits
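For reference, the frequency and presence terms in this hunk follow the OpenAI API definition: the frequency penalty is subtracted once per occurrence of a generated token, the presence penalty once per token that has appeared at all. A toy single-sequence check with invented numbers:

```python
import torch

# Toy illustration of the frequency/presence terms above; values are made up.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.0]])
output_bin_counts = torch.tensor([[3, 0, 1, 0]])   # how often each token was generated
output_mask = output_bin_counts > 0
frequency_penalty = torch.tensor([0.2])
presence_penalty = torch.tensor([0.5])

logits -= frequency_penalty.unsqueeze(1) * output_bin_counts   # per occurrence
logits -= presence_penalty.unsqueeze(1) * output_mask          # once per seen token
print(logits)  # tensor([[ 0.9000,  1.0000, -0.2000,  0.0000]])
```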
-def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
-    # Collect the temperatures for the logits.
-    temperatures: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        temperature = sampling_params.temperature
-        if temperature < _SAMPLING_EPS:
-            # NOTE: Zero temperature means deterministic sampling
-            # (i.e., greedy sampling or beam search).
-            # Set the temperature to 1 to avoid division by zero.
-            temperature = 1.0
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_len = sampling_metadata.prompt_lens[i]
-            temperatures += [temperature] * (prompt_len - 1)
-        temperatures += [temperature] * len(seq_ids)
-    return temperatures
-def _get_top_p_top_k_min_p(
-    sampling_metadata: SamplingMetadata,
-    vocab_size: int,
-) -> Tuple[List[float], List[int], List[float]]:
-    top_ps: List[float] = []
-    top_ks: List[int] = []
-    min_ps: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        top_p = sampling_params.top_p
-        min_p = sampling_params.min_p
-        # k should not be greater than the vocab size.
-        top_k = min(sampling_params.top_k, vocab_size)
-        # k=-1 means no truncation.
-        top_k = vocab_size if top_k == -1 else top_k
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_len = sampling_metadata.prompt_lens[i]
-            top_ps += [top_p] * (prompt_len - 1)
-            top_ks += [top_k] * (prompt_len - 1)
-            min_ps += [min_p] * (prompt_len - 1)
-        top_ps += [top_p] * len(seq_ids)
-        top_ks += [top_k] * len(seq_ids)
-        min_ps += [min_p] * len(seq_ids)
-    return top_ps, top_ks, min_ps
 def _apply_top_p_top_k(
     logits: torch.Tensor,
-    top_ps: List[float],
-    top_ks: List[int],
+    p: torch.Tensor,
+    k: torch.Tensor,
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
     logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
 
     # Apply top-p.
     probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1)
-    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    logits_sort[top_p_mask] = -float("inf")
+    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
+    top_p_mask = probs_sum > p.unsqueeze_(dim=1)
 
     # Apply top-k.
     # Create a mask for the top-k elements.
     top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
     top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    logits_sort[top_k_mask] = -float("inf")
+    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
+
+    # Final mask.
+    mask = (top_p_mask | top_k_mask)
+    logits_sort.masked_fill_(mask, -float("inf"))
 
     # Re-sort the probabilities.
-    logits = torch.gather(logits_sort,
-                          dim=-1,
-                          index=torch.argsort(logits_idx, dim=-1))
+    src = torch.arange(logits_idx.shape[-1],
+                       device=logits_idx.device).expand_as(logits_idx)
+    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
+                                                           index=logits_idx,
+                                                           src=src)
+    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
     return logits
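The re-sort step now builds the inverse permutation of `logits_idx` with one `scatter_` of an `arange`, instead of a second `argsort`. The identity it relies on, shown in isolation on a tiny made-up tensor:

```python
import torch

# Recovering original order after a sort by building the inverse permutation
# with scatter_ -- the same trick as in _apply_top_p_top_k above.
x = torch.tensor([[0.1, 0.7, 0.2]])
x_sort, idx = x.sort(dim=-1, descending=True)   # idx: sorted position -> original position
src = torch.arange(idx.shape[-1]).expand_as(idx)
idx_inv = torch.empty_like(idx).scatter_(dim=-1, index=idx, src=src)
restored = torch.gather(x_sort, dim=-1, index=idx_inv)
assert torch.equal(restored, x)                  # original order recovered
```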
 def _apply_min_p(
     logits: torch.Tensor,
-    min_ps: List[float],
+    min_p: torch.Tensor,
 ) -> torch.Tensor:
     """
     Adapted from
     https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
     """
-    min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
     probs = torch.softmax(logits, dim=-1)
     top_probs, _ = probs.max(dim=-1, keepdim=True)
-    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
+    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
     tokens_to_remove = probs < scaled_min_p
-    logits = logits.masked_fill(tokens_to_remove, -float("inf"))
+    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
     return logits
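`_apply_min_p` keeps only tokens whose probability is at least `min_p` times the probability of the most likely token. A quick numeric sanity check with invented logits:

```python
import torch

# min_p filtering: drop tokens with prob < min_p * max_prob.
logits = torch.tensor([[3.0, 1.0, 0.0]])
min_p = torch.tensor([0.2])
probs = torch.softmax(logits, dim=-1)                        # ~[0.84, 0.11, 0.04]
top_probs, _ = probs.max(dim=-1, keepdim=True)
tokens_to_remove = probs < min_p.unsqueeze(1) * top_probs    # threshold ~0.17
print(logits.masked_fill(tokens_to_remove, -float("inf")))   # tensor([[3., -inf, -inf]])
```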
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
-    logprobs: torch.Tensor,
+    samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
-    samples = torch.argmax(logprobs, dim=-1).cpu()
+    samples = samples.tolist()
     sample_idx = 0
     results = []
     for seq_group in selected_seq_groups:
@@ -375,27 +264,19 @@ def _greedy_sample(
         assert num_parent_seqs == 1, (
             "Greedy sampling should have only one seq.")
         parent_ids = list(range(num_parent_seqs))
-        next_token_ids = [samples[sample_idx].item()]
+        next_token_ids = [samples[sample_idx]]
         results.append((next_token_ids, parent_ids))
         sample_idx += num_parent_seqs
-    assert sample_idx == logprobs.size(0)
     return results
 def _random_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     is_prompts: List[bool],
-    probs: torch.Tensor,
+    random_samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
     # Find the maximum best_of value of the prompt phase requests.
-    max_best_of = 1
-    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
-        if is_prompt:
-            seq_ids, sampling_params = seq_group
-            max_best_of = max(max_best_of, sampling_params.best_of)
-    random_samples = torch.multinomial(probs,
-                                       num_samples=max_best_of,
-                                       replacement=True).cpu()
+    random_samples = random_samples.cpu()
     sample_idx = 0
     results = []
     for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
@@ -403,8 +284,6 @@ def _random_sample(
         num_parent_seqs = len(seq_ids)
         if is_prompt:
             # Prompt phase.
-            assert num_parent_seqs == 1, (
-                "Prompt input should have only one seq.")
             parent_ids = [0] * sampling_params.best_of
             next_token_ids = random_samples[
                 sample_idx, :sampling_params.best_of].tolist()
@@ -415,7 +294,6 @@ def _random_sample(
                 num_parent_seqs, 0].tolist()
         results.append((next_token_ids, parent_ids))
         sample_idx += num_parent_seqs
-    assert sample_idx == probs.size(0)
     return results
@@ -472,6 +350,28 @@ def _beam_search_sample(
     return results
 
 
+# torch.multinomial forces a GPU<->CPU sync.
+# Therefore, we use an optimized implementation instead.
+# Note that we always sample with replacement.
+# probs will be modified in place, but this is fine, as we pass
+# in a copy already.
+def _multinomial(
+    probs: torch.Tensor,
+    num_samples: int,
+):
+    if num_samples > 1:
+        # This is equivalent to torch.repeat_interleave (which also
+        # forces a GPU<->CPU sync).
+        # This allows us to do sampling with replacement by creating
+        # num_samples copies of each row in the tensor, and then
+        # batch sampling the resulting tensor.
+        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
+                                         probs.shape[1]).contiguous().view(
+                                             -1, probs.shape[1])
+    q = torch.empty_like(probs).exponential_(1)
+    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
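`_multinomial` avoids `torch.multinomial` (and its device sync) by dividing the probabilities by independent Exp(1) noise and taking the argmax, which is the exponential-race (Gumbel-style) trick for drawing from a categorical distribution; expanding rows gives sampling with replacement. A standalone sanity check of that equivalence; `fast_multinomial` and the test values are illustrative, not part of the commit:

```python
import torch

def fast_multinomial(probs: torch.Tensor, num_samples: int) -> torch.Tensor:
    # Same idea as _multinomial above: argmax of probs / Exp(1) noise is a
    # categorical draw with probabilities proportional to probs.
    if num_samples > 1:
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).reshape(-1, probs.shape[1])
    q = torch.empty_like(probs).exponential_(1)
    return probs.div(q).argmax(dim=1).view(-1, num_samples)

torch.manual_seed(0)
p = torch.tensor([[0.1, 0.6, 0.3]])
draws = fast_multinomial(p.repeat(10000, 1), 1).flatten()
freq = torch.bincount(draws, minlength=3).float() / draws.numel()
print(freq)  # empirically close to [0.1, 0.6, 0.3]
```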
 def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
@@ -485,28 +385,51 @@ def _sample(
         categorized_seq_group_ids[sampling_type].append(i)
 
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
+    sample_metadata = {}
+
+    # Counterintuitively, having two loops here is actually faster.
+    # The first loop can run without waiting on GPU<->CPU sync.
     for sampling_type in SamplingType:
-        seq_group_ids = categorized_seq_group_ids[sampling_type]
-        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
-        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
         sample_indices = categorized_sample_indices[sampling_type]
         num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
+        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
+                                          is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
-            category_logprobs = logprobs[sample_indices]
-            sample_results = _greedy_sample(seq_groups, category_logprobs)
+            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
+        elif sampling_type == SamplingType.RANDOM:
+            max_best_of = 1
+            for seq_group, is_prompt in zip(seq_groups, is_prompts):
+                if is_prompt:
+                    _, sampling_params = seq_group
+                    max_best_of = max(max_best_of, sampling_params.best_of)
+            multinomial_samples = _multinomial(probs[sample_indices],
+                                               max_best_of)
+        elif sampling_type == SamplingType.BEAM:
+            beam_search_logprobs = logprobs[sample_indices]
+        else:
+            raise ValueError(f"Unsupported sampling type: {sampling_type}")
+
+    # GPU<->CPU sync happens in the loop below.
+
+    for sampling_type in SamplingType:
+        if sampling_type not in sample_metadata:
+            continue
+        seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
+            sampling_type]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(seq_groups, greedy_samples)
         elif sampling_type == SamplingType.RANDOM:
-            category_probs = probs[sample_indices]
             sample_results = _random_sample(seq_groups, is_prompts,
-                                            category_probs)
+                                            multinomial_samples)
         elif sampling_type == SamplingType.BEAM:
-            category_logprobs = logprobs[sample_indices]
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  sampling_metadata.seq_data,
-                                                 category_logprobs)
-        else:
-            raise ValueError(f"Unsupported sampling type: {sampling_type}")
+                                                 beam_search_logprobs)
         sample_results_dict.update(zip(seq_group_ids, sample_results))
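Splitting `_sample` into two loops defers every GPU-to-CPU read (the `.cpu()`/`.tolist()` calls inside the sampling helpers) until all sampling kernels have been queued, so the host stalls once instead of once per sampling type. Schematically, with hypothetical tensors and group names:

```python
import torch

# Sketch of the "queue everything, then read back once" structure of the
# two loops in _sample. Names and shapes are hypothetical.
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = {g: torch.randn(4, 320, device=device) for g in ("greedy", "random")}

# Pass 1: launch GPU work only; nothing here forces a host sync.
pending = {}
pending["greedy"] = torch.argmax(logits["greedy"], dim=-1)
probs = torch.softmax(logits["random"], dim=-1)
q = torch.empty_like(probs).exponential_(1)
pending["random"] = probs.div(q).argmax(dim=-1)

# Pass 2: the only host<->device sync points, after all kernels are queued.
results = {name: samples.tolist() for name, samples in pending.items()}
```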
     sample_results = [
@@ -557,7 +480,7 @@ def _get_logprobs(
     batched_logprobs_query_result = logprobs[[
         batched_logprobs_query_seq_indices,
         batched_logprobs_query_token_indices
-    ]].cpu()
+    ]]
 
     # Batched query for logprobs of topk tokens
     if largest_num_logprobs > 0:
@@ -569,6 +492,8 @@
     else:
         top_logprobs, top_token_ids = None, None
 
+    batched_logprobs_query_result = batched_logprobs_query_result.cpu()
+
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
     result_sample_logprobs: List[SampleLogprobs] = []
vllm/model_executor/sampling_metadata.py
+from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
 import torch
 
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
+from vllm.utils import in_wsl
+
+_SAMPLING_EPS = 1e-5
 
 
 class SamplingMetadata:
@@ -41,3 +45,186 @@ class SamplingMetadata:
             f"prompt_lens={self.prompt_lens}, "
             f"selected_token_indices={self.selected_token_indices}, "
             f"categorized_sample_indices={self.categorized_sample_indices})")
+
+
+@dataclass
+class SamplingTensors:
+    """Tensors for sampling."""
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
+    presence_penalties: torch.Tensor
+    frequency_penalties: torch.Tensor
+    repetition_penalties: torch.Tensor
+    prompt_tokens: torch.Tensor
+    output_tokens: torch.Tensor
+
+    @classmethod
+    def from_sampling_metadata(
+            cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
+            device: torch.device,
+            dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
+        prompt_tokens: List[List[int]] = []
+        output_tokens: List[List[int]] = []
+        top_ks: List[int] = []
+        temperatures: List[float] = []
+        top_ps: List[float] = []
+        min_ps: List[float] = []
+        presence_penalties: List[float] = []
+        frequency_penalties: List[float] = []
+        repetition_penalties: List[float] = []
+        do_penalties = False
+        do_top_p_top_k = False
+        do_min_p = False
+        for i, seq_group in enumerate(sampling_metadata.seq_groups):
+            seq_ids, sampling_params = seq_group
+            temperature = sampling_params.temperature
+            p = sampling_params.presence_penalty
+            f = sampling_params.frequency_penalty
+            r = sampling_params.repetition_penalty
+            top_p = sampling_params.top_p
+            min_p = sampling_params.min_p
+            # k should not be greater than the vocab size.
+            top_k = min(sampling_params.top_k, vocab_size)
+            top_k = vocab_size if top_k == -1 else top_k
+            if temperature < _SAMPLING_EPS:
+                # NOTE: Zero temperature means deterministic sampling
+                # (i.e., greedy sampling or beam search).
+                # Set the temperature to 1 to avoid division by zero.
+                temperature = 1.0
+            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
+                                       or top_k != vocab_size):
+                do_top_p_top_k = True
+            if not do_min_p and min_p > _SAMPLING_EPS:
+                do_min_p = True
+            if not do_penalties and (abs(p) >= _SAMPLING_EPS
+                                     or abs(f) >= _SAMPLING_EPS
+                                     or abs(r - 1.0) >= _SAMPLING_EPS):
+                do_penalties = True
+            if (i < sampling_metadata.num_prompts
+                    and sampling_params.prompt_logprobs is not None):
+                # For tokens in the prompt, we only need their logprobs.
+                prompt_len = sampling_metadata.prompt_lens[i]
+                temperatures += [temperature] * (prompt_len - 1)
+                top_ps += [top_p] * (prompt_len - 1)
+                top_ks += [top_k] * (prompt_len - 1)
+                min_ps += [min_p] * (prompt_len - 1)
+                presence_penalties += [0] * (prompt_len - 1)
+                frequency_penalties += [0] * (prompt_len - 1)
+                repetition_penalties += [1] * (prompt_len - 1)
+                prompt_tokens.extend([] for _ in range(prompt_len - 1))
+                output_tokens.extend([] for _ in range(prompt_len - 1))
+            for seq_id in seq_ids:
+                seq_data = sampling_metadata.seq_data[seq_id]
+                prompt_tokens.append(seq_data.prompt_token_ids)
+                output_tokens.append(seq_data.output_token_ids)
+            temperatures += [temperature] * len(seq_ids)
+            top_ps += [top_p] * len(seq_ids)
+            top_ks += [top_k] * len(seq_ids)
+            min_ps += [min_p] * len(seq_ids)
+            presence_penalties += [p] * len(seq_ids)
+            frequency_penalties += [f] * len(seq_ids)
+            repetition_penalties += [r] * len(seq_ids)
+
+        sampling_tensors = SamplingTensors.from_lists(
+            temperatures, top_ps, top_ks, min_ps, presence_penalties,
+            frequency_penalties, repetition_penalties, prompt_tokens,
+            output_tokens, vocab_size, device, dtype)
+        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
+
+    @classmethod
+    def from_lists(cls, temperatures: List[float], top_ps: List[float],
+                   top_ks: List[int], min_ps: List[float],
+                   presence_penalties: List[float],
+                   frequency_penalties: List[float],
+                   repetition_penalties: List[float],
+                   prompt_tokens: List[List[int]],
+                   output_tokens: List[List[int]], vocab_size: int,
+                   device: torch.device,
+                   dtype: torch.dtype) -> "SamplingTensors":
+        # Note that the performance will be very bad without
+        # pinned memory.
+        pin_memory = not in_wsl()
+        prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
+        prompt_padded_tokens = [
+            tokens + [vocab_size] * (prompt_max_len - len(tokens))
+            for tokens in prompt_tokens
+        ]
+        output_max_len = max(len(tokens) for tokens in output_tokens)
+        output_padded_tokens = [
+            tokens + [vocab_size] * (output_max_len - len(tokens))
+            for tokens in output_tokens
+        ]
+
+        temperatures_t = torch.tensor(
+            temperatures,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ps_t = torch.tensor(
+            top_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        min_ps_t = torch.tensor(
+            min_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        presence_penalties_t = torch.tensor(
+            presence_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        frequency_penalties_t = torch.tensor(
+            frequency_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        repetition_penalties_t = torch.tensor(
+            repetition_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ks_t = torch.tensor(
+            top_ks,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
+        prompt_tensor = torch.tensor(
+            prompt_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        output_tensor = torch.tensor(
+            output_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        # Because the memory is pinned, we can do non-blocking
+        # transfer to device.
+        return cls(
+            temperatures=temperatures_t.to(device=device, non_blocking=True),
+            top_ps=top_ps_t.to(device=device, non_blocking=True),
+            top_ks=top_ks_t.to(device=device, non_blocking=True),
+            min_ps=min_ps_t.to(device=device, non_blocking=True),
+            presence_penalties=presence_penalties_t.to(device=device,
+                                                       non_blocking=True),
+            frequency_penalties=frequency_penalties_t.to(device=device,
+                                                         non_blocking=True),
+            repetition_penalties=repetition_penalties_t.to(device=device,
+                                                           non_blocking=True),
+            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
+            output_tokens=output_tensor.to(device=device, non_blocking=True),
+        )
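`from_lists` stages every tensor in pinned host memory (skipped under WSL, where pinning is problematic) so that the following `.to(device, non_blocking=True)` calls are genuinely asynchronous copies. A minimal sketch of that staging pattern with illustrative values:

```python
import torch

# Pinned-memory staging: build on CPU with pin_memory=True, then issue
# non-blocking copies. Sizes and values here are illustrative only.
pin = torch.cuda.is_available()           # the real code also checks in_wsl()
temperatures = torch.tensor([1.0, 0.7, 0.7], device="cpu",
                            dtype=torch.float32, pin_memory=pin)
top_ks = torch.tensor([50, 100, 40], device="cpu",
                      dtype=torch.int, pin_memory=pin)
if torch.cuda.is_available():
    # With a pinned source buffer these copies can overlap with other GPU work.
    temperatures = temperatures.to("cuda", non_blocking=True)
    top_ks = top_ks.to("cuda", non_blocking=True)
```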