Unverified Commit 9d9072a0 authored by Zhuohan Li, committed by GitHub

Implement prompt logprobs & Batched topk for computing logprobs (#1328)


Co-authored-by: Yunmo Chen <16273544+wanmok@users.noreply.github.com>
parent 928de468
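This commit adds two knobs to SamplingParams: `logprobs` (top-k log probabilities for each sampled token) and `prompt_logprobs` (top-k log probabilities for each prompt token), and surfaces the latter on RequestOutput. A rough usage sketch, based on the test added in this diff (model name and values are illustrative, not part of the commit):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0,
                        max_tokens=16,
                        logprobs=5,         # top-5 logprobs per generated token
                        prompt_logprobs=5)  # top-5 logprobs per prompt token
outputs = llm.generate(["A robot may not injure a human being"], params)

for output in outputs:
    # One {token_id: logprob} dict per prompt position; the first entry is None
    # because the first prompt token has no preceding context.
    print(output.prompt_logprobs)
    # One {token_id: logprob} dict per generated token of the first completion.
    print(output.outputs[0].logprobs)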
@@ -11,7 +11,7 @@ def main(args: argparse.Namespace):
     # Test the following prompts.
     test_prompts = [
         ("A robot may not injure a human being",
-         SamplingParams(temperature=0.0)),
+         SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
...
@@ -64,7 +64,7 @@ def test_request_tracker():
     stream_5 = tracker.add_request("5")
     assert tracker.new_requests_event.flag
     tracker.process_request_output(
-        RequestOutput("2", "output", [], [], finished=True))
+        RequestOutput("2", "output", [], [], [], finished=True))
     new, finished = tracker.get_new_and_finished_requests()
     assert not tracker.new_requests_event.flag
     assert len(finished) == 1
...
@@ -107,6 +107,39 @@ class HfRunner:
             outputs[i] = (output_ids, output_str)
         return outputs
 
+    def generate_greedy_logprobs(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+    ) -> List[List[torch.Tensor]]:
+        all_logprobs = []
+        for prompt in prompts:
+            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+            output = self.model.generate(
+                input_ids.cuda(),
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+            seq_logprobs = []
+            for hidden_states in output.hidden_states:
+                last_hidden_states = hidden_states[-1][0]
+                logits = torch.matmul(
+                    last_hidden_states,
+                    self.model.get_output_embeddings().weight.t(),
+                )
+                if self.model.get_output_embeddings().bias is not None:
+                    logits += self.model.get_output_embeddings(
+                    ).bias.unsqueeze(0)
+                logprobs = torch.nn.functional.log_softmax(logits,
+                                                           dim=-1,
+                                                           dtype=torch.float32)
+                seq_logprobs.append(logprobs)
+            all_logprobs.append(seq_logprobs)
+        return all_logprobs
+
 
 @pytest.fixture
 def hf_runner():
...
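The helper above recomputes reference logprobs by projecting each generation step's final hidden state through the model's output embedding. A minimal standalone sketch of that equivalence, assuming an OPT-style HF checkpoint whose output embedding is a plain Linear with no extra projection, checked against the logits the model reports directly:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
input_ids = tokenizer("To be or not to be,", return_tensors="pt").input_ids

with torch.no_grad():
    out = model(input_ids, output_hidden_states=True)
    hidden = out.hidden_states[-1][0]          # (seq_len, hidden_size)
    lm_head = model.get_output_embeddings()
    logits = hidden @ lm_head.weight.t()       # project to vocabulary
    if lm_head.bias is not None:
        logits = logits + lm_head.bias
    # Matches the logits the model computes itself (up to numerical noise).
    torch.testing.assert_close(logits, out.logits[0], atol=1e-4, rtol=1e-4)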
import pytest
import torch

from vllm import SamplingParams

MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_get_prompt_logprobs(
    hf_runner,
    vllm_runner,
    model,
    dtype,
    example_prompts,
):
    max_tokens = 5
    hf_model = hf_runner(model, dtype=dtype)
    hf_logprobs = hf_model.generate_greedy_logprobs(
        example_prompts,
        max_tokens=max_tokens,
    )
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                          logprobs=5,
                                          prompt_logprobs=5,
                                          temperature=0.0)
    vllm_results = vllm_model.model.generate(
        example_prompts, sampling_params=vllm_sampling_params)

    # Test whether logprobs are included in the results.
    for result in vllm_results:
        assert result.prompt_logprobs is not None
        assert result.outputs[0].logprobs is not None

    # Test whether prompt logprobs are consistent with HF
    for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
        # Check prompt logprobs
        vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
        for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
            for token_id, logprob in vllm_prompt_logprob_dict.items():
                torch.testing.assert_close(logprob,
                                           hf_logprob[0][i][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
        vllm_sample_logprobs = vllm_result.outputs[0].logprobs
        for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs):
            for token_id, logprob in vllm_sample_logprob_dict.items():
                torch.testing.assert_close(logprob,
                                           hf_logprob[i][-1][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
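The index juggling in the comparison above is easy to misread, so here is the intended alignment on a toy prompt (purely illustrative, no real model involved): vLLM reports one entry per prompt token with None in position 0, while hf_logprob[0] holds one next-token distribution per prompt position from the first generation step.

# Toy prompt of 4 tokens: positions 0..3.
vllm_prompt_logprobs = [None, "lp(t1 | t0)", "lp(t2 | t0 t1)", "lp(t3 | t0 t1 t2)"]
hf_step0_positions = ["predicts t1", "predicts t2", "predicts t3",
                      "predicts the first sampled token"]

# Dropping the leading None lines entry i up with HF position i,
# which is exactly what `vllm_result.prompt_logprobs[1:]` does above.
for i, entry in enumerate(vllm_prompt_logprobs[1:]):
    print(f"vllm[{i + 1}] = {entry}  <->  hf_logprob[0][{i}] ({hf_step0_positions[i]})")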
@@ -143,7 +143,7 @@ class ModelConfig:
     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU worker."""
         # For GPTBigCode & Falcon:
-        # Note: for falcon, when new_decoder_architecture is True, the
+        # NOTE: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
         # KV heads.
         falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
...
@@ -12,8 +12,8 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                           SequenceGroupMetadata, SequenceOutputs,
-                           SequenceStatus)
+                           SequenceGroupMetadata, SequenceGroupOutputs,
+                           SequenceOutputs, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
@@ -350,9 +350,15 @@ class LLMEngine:
                 eos_token_id=self.tokenizer.eos_token_id))
         return current_worst_score >= highest_attainable_score
 
-    def _process_sequence_group_samples(
-            self, seq_group: SequenceGroup,
-            samples: List[SequenceOutputs]) -> None:
+    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
+                                        outputs: SequenceGroupOutputs) -> None:
+        # Process prompt logprobs
+        prompt_logprobs = outputs.prompt_logprobs
+        if prompt_logprobs is not None:
+            seq_group.prompt_logprobs = prompt_logprobs
+
+        # Process samples
+        samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
         existing_finished_seqs = seq_group.get_finished_seqs()
         parent_child_dict = {
@@ -520,8 +526,8 @@ class LLMEngine:
                       scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
         # Update the scheduled sequence groups with the model outputs.
         scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
-        for seq_group, samples in zip(scheduled_seq_groups, output):
-            self._process_sequence_group_samples(seq_group, samples)
+        for seq_group, outputs in zip(scheduled_seq_groups, output):
+            self._process_sequence_group_outputs(seq_group, outputs)
 
         # Free the finished sequence groups.
         self.scheduler.free_finished_seq_groups()
...
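With this change the engine consumes one SequenceGroupOutputs per scheduled sequence group instead of a bare list of samples. A sketch of the per-step shape, built only from the classes introduced in this diff (token ids and logprob values are made up):

from vllm.sequence import SequenceGroupOutputs, SequenceOutputs

step_output = [
    SequenceGroupOutputs(
        # One SequenceOutputs per sample: (parent seq id, next token id, logprobs).
        samples=[SequenceOutputs(0, 464, {464: -0.05, 318: -3.1})],
        # One dict per prompt token, None for the first position.
        prompt_logprobs=[None, {2387: -1.25}, {991: -0.31}],
    ),
]

for outputs in step_output:
    # Prompt logprobs are attached to the sequence group once; samples are
    # appended to the running sequences exactly as before.
    print(outputs.prompt_logprobs)
    print(outputs.samples)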
@@ -420,7 +420,7 @@ class PagedAttentionWithALiBi(PagedAttention):
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
             bias = torch.arange(prompt_len, dtype=dtype)
-            # Note(zhuohan): HF uses
+            # NOTE(zhuohan): HF uses
             #     `bias = bias[None, :].repeat(prompt_len, 1)`
             # here. We find that both biases give the same results, but
             # the bias below more accurately follows the original ALiBi
...
@@ -8,7 +8,8 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather)
 from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
+from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
+                           SequenceData, SequenceGroupOutputs, SequenceOutputs)
 
 _SAMPLING_EPS = 1e-5
@@ -82,7 +83,12 @@ class Sampler(nn.Module):
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
 
         # Sample the next tokens.
-        return _sample(probs, logprobs, input_metadata)
+        sample_results = _sample(probs, logprobs, input_metadata)
+        # Get the logprobs query results.
+        prompt_logprobs, sample_logprobs = _get_logprobs(
+            logprobs, input_metadata, sample_results)
+        return _build_sampler_output(sample_results, input_metadata,
+                                     prompt_logprobs, sample_logprobs)
 
 
 def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
@@ -102,24 +108,28 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    last_token_indices = []
+    selected_token_indices: List[int] = []
     start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, _ = seq_group
+        seq_ids, sampling_params = seq_group
         if i < input_metadata.num_prompts:
             assert len(seq_ids) == 1, "Prompt input should have only one seq."
             prompt_len = input_metadata.prompt_lens[i]
-            last_token_indices.append(start_idx + prompt_len - 1)
+            if sampling_params.prompt_logprobs is not None:
+                selected_token_indices.extend(
+                    range(start_idx, start_idx + prompt_len - 1))
+            selected_token_indices.append(start_idx + prompt_len - 1)
             start_idx += prompt_len
         else:
             num_seqs = len(seq_ids)
-            last_token_indices.extend(range(start_idx, start_idx + num_seqs))
+            selected_token_indices.extend(
+                range(start_idx, start_idx + num_seqs))
             start_idx += num_seqs
-    last_token_indices = torch.tensor(last_token_indices,
-                                      dtype=torch.long,
-                                      device=hidden_states.device)
-    return hidden_states.index_select(0, last_token_indices)
+    selected_token_indices = torch.tensor(selected_token_indices,
+                                          dtype=torch.long,
+                                          device=hidden_states.device)
+    return hidden_states.index_select(0, selected_token_indices)
 
 
 def _get_penalties(
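A worked example of the index selection in the hunk above may help: when prompt_logprobs is requested, every position of that prompt is kept so its logprob can be queried later, not only the last one. Illustrative sketch with two prompts (lengths 4 and 3), only the first requesting prompt logprobs:

prompt_lens = [4, 3]
wants_prompt_logprobs = [True, False]

selected_token_indices = []
start_idx = 0
for prompt_len, wants in zip(prompt_lens, wants_prompt_logprobs):
    if wants:
        # Keep every prompt position so their logprobs can be queried later.
        selected_token_indices.extend(range(start_idx, start_idx + prompt_len - 1))
    # The last prompt position is always kept: it produces the first sampled token.
    selected_token_indices.append(start_idx + prompt_len - 1)
    start_idx += prompt_len

print(selected_token_indices)  # [0, 1, 2, 3, 6] instead of just [3, 6]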
@@ -127,10 +137,17 @@ def _get_penalties(
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
-    for seq_group in input_metadata.seq_groups:
+    for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            # NOTE: We do not apply presence and frequency penalties for the
+            # prompt token positions where we don't sample new tokens.
+            prompt_len = input_metadata.prompt_lens[i]
+            presence_penalties += [0] * (prompt_len - 1)
+            frequency_penalties += [0] * (prompt_len - 1)
         presence_penalties += [p] * len(seq_ids)
         frequency_penalties += [f] * len(seq_ids)
     return presence_penalties, frequency_penalties
@@ -138,8 +155,14 @@ def _get_penalties(
 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
     output_tokens: List[List[int]] = []
-    for seq_group in input_metadata.seq_groups:
-        seq_ids, _ = seq_group
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            # NOTE: prompt token positions do not need output tokens to
+            # compute penalties.
+            prompt_len = input_metadata.prompt_lens[i]
+            output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
             output_tokens.append(seq_data.output_token_ids)
@@ -200,7 +223,7 @@ def _apply_penalties(
 def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
-    for seq_group in input_metadata.seq_groups:
+    for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
         if temperature < _SAMPLING_EPS:
@@ -208,6 +231,10 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
             temperature = 1.0
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            prompt_len = input_metadata.prompt_lens[i]
+            temperatures += [temperature] * (prompt_len - 1)
         temperatures += [temperature] * len(seq_ids)
     return temperatures
@@ -218,13 +245,18 @@ def _get_top_p_top_k(
 ) -> Tuple[List[float], List[int]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
-    for seq_group in input_metadata.seq_groups:
+    for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            prompt_len = input_metadata.prompt_lens[i]
+            top_ps += [top_p] * (prompt_len - 1)
+            top_ks += [top_k] * (prompt_len - 1)
         top_ps += [top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
     return top_ps, top_ks
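The same padding pattern appears in the penalty, temperature, and top-p/top-k hunks above: each per-token sampling list gains prompt_len - 1 extra entries so that its length keeps matching the number of rows selected by _prune_hidden_states. A small sketch of that invariant (illustrative values only):

# Two prompt sequence groups, lengths 4 and 3; only the first requests prompt logprobs.
prompt_lens = [4, 3]
wants_prompt_logprobs = [True, False]
group_temperatures = [0.7, 0.9]

num_selected_rows = 0
temperatures = []
for prompt_len, wants, temp in zip(prompt_lens, wants_prompt_logprobs,
                                   group_temperatures):
    if wants:
        # Padding entries for the extra prompt positions kept for logprob queries.
        num_selected_rows += prompt_len - 1
        temperatures += [temp] * (prompt_len - 1)
    # The position that actually samples the next token (one per prompt sequence).
    num_selected_rows += 1
    temperatures += [temp]

assert len(temperatures) == num_selected_rows == 5  # one entry per logits row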
@@ -259,49 +291,6 @@ def _apply_top_p_top_k(
     return logits
 
-def _get_topk_logprobs(
-    logprobs: torch.Tensor,
-    num_logprobs: Optional[int],
-) -> List[Dict[int, float]]:
-    num_seqs = logprobs.size(0)
-    if num_logprobs is None or num_logprobs == 0:
-        return [{} for _ in range(num_seqs)]
-
-    all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
-                                                 num_logprobs,
-                                                 dim=-1)
-    all_topk_logprobs = all_topk_logprobs.cpu()
-    all_topk_ids = all_topk_ids.cpu()
-    all_token_to_logprob = []
-    for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
-        token_to_logprob: Dict[int, float] = {}
-        for token_id, logprob in zip(topk_ids, topk_logprobs):
-            token_to_logprob[token_id.item()] = logprob.item()
-        all_token_to_logprob.append(token_to_logprob)
-    return all_token_to_logprob
-
-
-def _build_sequence_outputs(
-    parent_ids: List[int],
-    next_token_ids: List[int],
-    selected_token_logprobs: List[float],
-    parent_seq_ids: List[int],
-    parent_logprobs: torch.Tensor,
-    num_output_logprobs: Optional[int],
-) -> List[SequenceOutputs]:
-    # Get top-k log probabilities for the next tokens.
-    next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
-    seq_outputs: List[SequenceOutputs] = []
-    for parent_id, next_token_id, token_logprob in zip(
-            parent_ids, next_token_ids, selected_token_logprobs):
-        output_logprobs = next_logprobs[parent_id].copy()
-        output_logprobs[next_token_id] = token_logprob
-        seq_outputs.append(
-            SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
-                            output_logprobs))
-    return seq_outputs
-
-
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
@@ -372,7 +361,7 @@ def _beam_search_sample(
     # for details. See also HF reference:
     # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
     #
-    # Note: Beam search is not vectorized, so its speed can be slower than
+    # NOTE: Beam search is not vectorized, so its speed can be slower than
     # other sampling methods.
     sample_idx = 0
     results = []
@@ -416,79 +405,186 @@ def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     input_metadata: InputMetadata,
-) -> SamplerOutput:
+) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
+    categorized_sample_indices = {t: [] for t in SamplingType}
     start_idx = 0
-    categorized_seq_ids = {t: [] for t in SamplingType}
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            # NOTE: prompt token positions do not need sample, skip
+            prompt_len = input_metadata.prompt_lens[i]
+            start_idx += prompt_len - 1
         categorized_seq_group_ids[sampling_type].append(i)
         num_seqs = len(seq_ids)
-        categorized_seq_ids[sampling_type].extend(
+        categorized_sample_indices[sampling_type].extend(
             range(start_idx, start_idx + num_seqs))
         start_idx += num_seqs
 
-    seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
+    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:
         seq_group_ids = categorized_seq_group_ids[sampling_type]
         seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
         is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
-        num_tokens = len(categorized_seq_ids[sampling_type])
+        sample_indices = categorized_sample_indices[sampling_type]
+        num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
-        category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
-        category_probs = probs[categorized_seq_ids[sampling_type]]
         if sampling_type == SamplingType.GREEDY:
+            category_logprobs = logprobs[sample_indices]
             sample_results = _greedy_sample(seq_groups, category_logprobs)
         elif sampling_type == SamplingType.RANDOM:
+            category_probs = probs[sample_indices]
            sample_results = _random_sample(seq_groups, is_prompts,
                                             category_probs)
         elif sampling_type == SamplingType.BEAM:
+            category_logprobs = logprobs[sample_indices]
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  input_metadata.seq_data,
                                                  category_logprobs)
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
+        sample_results_dict.update(zip(seq_group_ids, sample_results))
 
-        # Batched query for logprobs of selected token
-        batched_logprobs_query_seq_indices: List[int] = []
-        batched_logprobs_query_token_indices: List[int] = []
-        sample_idx = 0
-        for seq_group_id, seq_group, sample_result in zip(
-                seq_group_ids, seq_groups, sample_results):
-            seq_ids, sampling_params = seq_group
-            next_token_ids, parent_ids = sample_result
-            num_parent_seqs = len(seq_ids)
-            batched_logprobs_query_seq_indices.extend(
-                [sample_idx + parent_id for parent_id in parent_ids])
-            batched_logprobs_query_token_indices.extend(next_token_ids)
-            sample_idx += num_parent_seqs
-        assert sample_idx == num_tokens
-        batched_logprobs_query_result = category_logprobs[[
-            batched_logprobs_query_seq_indices,
-            batched_logprobs_query_token_indices
-        ]].tolist()
-
-        # Build the sequence outputs.
-        sample_idx = 0
-        result_idx = 0
-        for seq_group_id, seq_group, sample_result in zip(
-                seq_group_ids, seq_groups, sample_results):
-            seq_ids, sampling_params = seq_group
-            next_token_ids, parent_ids = sample_result
-            num_results = len(next_token_ids)
-            num_parent_seqs = len(seq_ids)
-            parent_logprobs = category_logprobs[sample_idx:sample_idx +
-                                                num_parent_seqs]
-            selected_token_logprobs = batched_logprobs_query_result[
-                result_idx:result_idx + num_results]
-            seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
-                                                 selected_token_logprobs,
-                                                 seq_ids, parent_logprobs,
-                                                 sampling_params.logprobs)
-            seq_outputs_dict[seq_group_id] = seq_output
-            sample_idx += num_parent_seqs
-            result_idx += num_results
-        assert sample_idx == num_tokens
-
-    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
+    sample_results = [
+        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
+    ]
+    return sample_results
+
+
+def _get_logprobs(
+    logprobs: torch.Tensor,
+    input_metadata: InputMetadata,
+    sample_results: List[Tuple[List[int], List[int]]],
+) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
+        int, float]]]]:
+    # Prepare query indices
+    batched_logprobs_query_seq_indices: List[int] = []
+    batched_logprobs_query_token_indices: List[int] = []
+    largest_num_logprobs = 0
+    sample_idx = 0
+    for i, (seq_group, sample_result) in enumerate(
+            zip(input_metadata.seq_groups, sample_results)):
+        seq_ids, sampling_params = seq_group
+        next_token_ids, parent_ids = sample_result
+        num_parent_seqs = len(seq_ids)
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            largest_num_logprobs = max(largest_num_logprobs,
+                                       sampling_params.prompt_logprobs)
+            prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens = input_metadata.seq_data[
+                seq_ids[0]].prompt_token_ids
+            batched_logprobs_query_seq_indices.extend(
+                sample_idx + j for j in range(prompt_len - 1))
+            batched_logprobs_query_token_indices.extend(
+                token_id for token_id in prompt_tokens[1:])
+            sample_idx += prompt_len - 1
+        batched_logprobs_query_seq_indices.extend(
+            [sample_idx + parent_id for parent_id in parent_ids])
+        batched_logprobs_query_token_indices.extend(next_token_ids)
+        if sampling_params.logprobs is not None:
+            largest_num_logprobs = max(largest_num_logprobs,
+                                       sampling_params.logprobs)
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+
+    # Batched query for logprobs of selected token
+    batched_logprobs_query_result = logprobs[[
+        batched_logprobs_query_seq_indices,
+        batched_logprobs_query_token_indices
+    ]].cpu()
+
+    # Batched query for logprobs of topk tokens
+    if largest_num_logprobs > 0:
+        top_logprobs, top_token_ids = torch.topk(logprobs,
+                                                 largest_num_logprobs,
+                                                 dim=-1)
+        top_logprobs = top_logprobs.cpu()
+        top_token_ids = top_token_ids.cpu()
+    else:
+        top_logprobs, top_token_ids = None, None
+
+    # Gather results
+    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
+    result_sample_logprobs: List[SampleLogprobs] = []
+    sample_idx = 0
+    query_result_idx = 0
+    for i, (seq_group, sample_result) in enumerate(
+            zip(input_metadata.seq_groups, sample_results)):
+        seq_ids, sampling_params = seq_group
+        next_token_ids, parent_ids = sample_result
+
+        # Prompt logprobs
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            num_logprobs = sampling_params.prompt_logprobs
+            prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens = input_metadata.seq_data[
+                seq_ids[0]].prompt_token_ids
+            group_prompt_logprobs: PromptLogprobs = [None]
+            for token_id in prompt_tokens[1:]:
+                prompt_logprobs_dict = {
+                    token_id:
+                    batched_logprobs_query_result[query_result_idx].item()
+                }
+                if num_logprobs > 0:
+                    prompt_logprobs_dict.update(
+                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
+                            top_logprobs[sample_idx, :num_logprobs].tolist()))
+                group_prompt_logprobs.append(prompt_logprobs_dict)
+                sample_idx += 1
+                query_result_idx += 1
+            result_prompt_logprobs.append(group_prompt_logprobs)
+        else:
+            result_prompt_logprobs.append(None)
+
+        # Sample logprobs
+        num_logprobs = sampling_params.logprobs
+        if num_logprobs is None:
+            num_logprobs = 0
+        group_sample_logprobs: SampleLogprobs = []
+        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
+            sample_logprobs_dict = {
+                next_token_id:
+                batched_logprobs_query_result[query_result_idx].item()
+            }
+            query_result_idx += 1
+            if num_logprobs > 0:
+                sample_logprobs_dict.update(
+                    zip(
+                        top_token_ids[sample_idx +
+                                      parent_id, :num_logprobs].tolist(),
+                        top_logprobs[sample_idx +
+                                     parent_id, :num_logprobs].tolist()))
+            group_sample_logprobs.append(sample_logprobs_dict)
+        result_sample_logprobs.append(group_sample_logprobs)
+        sample_idx += len(seq_ids)
+
+    return result_prompt_logprobs, result_sample_logprobs
+
+
+def _build_sampler_output(
+    sample_results: List[Tuple[List[int], List[int]]],
+    input_metadata: InputMetadata,
+    prompt_logprobs: List[Optional[PromptLogprobs]],
+    sample_logprobs: List[SampleLogprobs],
+) -> SamplerOutput:
+    sampler_output = []
+    for (seq_group, sample_result, group_prompt_logprobs,
+         group_sample_logprobs) in zip(input_metadata.seq_groups,
+                                       sample_results, prompt_logprobs,
+                                       sample_logprobs):
+        seq_ids, _ = seq_group
+        next_token_ids, parent_ids = sample_result
+        seq_outputs = []
+        for parent_id, next_token_id, logprobs in zip(parent_ids,
+                                                      next_token_ids,
+                                                      group_sample_logprobs):
+            seq_outputs.append(
+                SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
+        sampler_output.append(
+            SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
+    return sampler_output
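The batching trick in _get_logprobs is worth seeing in isolation: one advanced-indexing gather pulls the logprob of every selected token (prompt and sampled) in a single operation, and one torch.topk call, sized for the largest logprobs/prompt_logprobs value in the batch, serves every request. A standalone sketch with dummy tensors (shapes, ids, and values are made up):

import torch

torch.manual_seed(0)
num_rows, vocab_size, largest_num_logprobs = 5, 16, 3
logprobs = torch.log_softmax(torch.randn(num_rows, vocab_size), dim=-1)

# Rows and token ids whose logprobs we need (prompt tokens + sampled tokens).
row_indices = [0, 1, 2, 3, 4]
token_indices = [7, 2, 9, 0, 5]

# One gather for all "selected token" logprobs...
selected = logprobs[row_indices, token_indices]              # shape: (5,)
# ...and one top-k call sized for the largest request in the batch.
top_logprobs, top_token_ids = torch.topk(logprobs, largest_num_logprobs, dim=-1)

# Per-row result dict, mirroring the {token_id: logprob} structure built above.
row = 0
result = {token_indices[row]: selected[row].item()}
result.update(zip(top_token_ids[row].tolist(), top_logprobs[row].tolist()))
print(result)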
@@ -9,7 +9,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 def tensor_model_parallel_all_reduce(input_):
     """All-reduce the input tensor across model parallel group.
 
-    Note: This operation is applied in-place on the input tensor.
+    NOTE: This operation is applied in-place on the input tensor.
     """
     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
...
@@ -133,7 +133,7 @@ class ColumnParallelLinear(torch.nn.Module):
             params_dtype = torch.get_default_dtype()
 
         # Parameters.
-        # Note: torch.nn.functional.linear performs XA^T + b and as a result
+        # NOTE: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
         self.create_weights(params_dtype)
...
@@ -41,7 +41,7 @@ def split_tensor_along_last_dim(
     last_dim_size = divide(tensor.size()[last_dim], num_partitions)
     # Split.
     tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
-    # Note: torch.split does not create contiguous tensors by default.
+    # NOTE: torch.split does not create contiguous tensors by default.
     if contiguous_split_chunks:
         return tuple(chunk.contiguous() for chunk in tensor_list)
...
-from typing import Dict, List, Optional
+from typing import List, Optional
 
-from vllm.sequence import SequenceGroup, SequenceStatus
+from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
+                           SequenceStatus)
 
 
 class CompletionOutput:
@@ -23,7 +24,7 @@ class CompletionOutput:
         text: str,
         token_ids: List[int],
         cumulative_logprob: float,
-        logprobs: Optional[List[Dict[int, float]]],
+        logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
     ) -> None:
         self.index = index
@@ -61,12 +62,14 @@ class RequestOutput:
         request_id: str,
         prompt: str,
         prompt_token_ids: List[int],
+        prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
         finished: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
+        self.prompt_logprobs = prompt_logprobs
         self.outputs = outputs
         self.finished = finished
@@ -91,7 +94,7 @@ class RequestOutput:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
-                logprobs = {}
+                logprobs = None
             finshed_reason = SequenceStatus.get_finished_reason(seq.status)
             output = CompletionOutput(seqs.index(seq), seq.output_text,
                                       seq.get_output_token_ids(),
@@ -100,15 +103,17 @@ class RequestOutput:
             outputs.append(output)
 
         # Every sequence in the sequence group should have the same prompt.
-        prompt = top_n_seqs[0].prompt
-        prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
+        prompt = seq_group.prompt
+        prompt_token_ids = seq_group.prompt_token_ids
+        prompt_logprobs = seq_group.prompt_logprobs
         finished = seq_group.is_finished()
-        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
-                   finished)
+        return cls(seq_group.request_id, prompt, prompt_token_ids,
+                   prompt_logprobs, outputs, finished)
 
     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
                 f"prompt={self.prompt!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
+                f"prompt_logprobs={self.prompt_logprobs}, "
                 f"outputs={self.outputs}, "
                 f"finished={self.finished})")
@@ -60,6 +60,12 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+            Note that the implementation follows the OpenAI API: The return
+            result includes the log probabilities on the `logprobs` most likely
+            tokens, as well the chosen tokens. The API will always return the
+            log probability of the sampled token, so there may be up to
+            `logprobs+1` elements in the response.
+        prompt_logprobs: Number of log probabilities to return per prompt token.
         skip_special_tokens: Whether to skip special tokens in the output.
     """
@@ -80,6 +86,7 @@ class SamplingParams:
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,
     ) -> None:
         self.n = n
@@ -105,6 +112,7 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.prompt_logprobs = prompt_logprobs
         self.skip_special_tokens = skip_special_tokens
         self._verify_args()
@@ -142,6 +150,9 @@ class SamplingParams:
         if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
+        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
+            raise ValueError(f"prompt_logprobs must be non-negative, got "
+                             f"{self.prompt_logprobs}.")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
@@ -200,4 +211,5 @@ class SamplingParams:
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs}, "
+                f"prompt_logprobs={self.prompt_logprobs}, "
                 f"skip_special_tokens={self.skip_special_tokens})")
@@ -6,6 +6,9 @@ from typing import Dict, List, Optional, Union
 from vllm.block import LogicalTokenBlock
 from vllm.sampling_params import SamplingParams
 
+PromptLogprobs = List[Optional[Dict[int, float]]]
+SampleLogprobs = List[Dict[int, float]]
+
 
 class SequenceStatus(enum.Enum):
     """Status of a sequence."""
@@ -116,7 +119,7 @@ class Sequence:
         self.block_size = block_size
         self.data = SequenceData(prompt_token_ids)
-        self.output_logprobs: List[Dict[int, float]] = []
+        self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -196,7 +199,7 @@ class Sequence:
         """
         if seq_len is None:
             seq_len = self.get_len()
-        # Note: HF implementation does not count the EOS token
+        # NOTE: HF implementation does not count the EOS token
         # towards the length, we align with that here for testing.
         if (eos_token_id is not None
                 and self.get_last_token_id() == eos_token_id):
@@ -238,6 +241,19 @@ class SequenceGroup:
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time
+        self.prompt_logprobs: Optional[PromptLogprobs] = None
+
+    @property
+    def prompt(self) -> str:
+        # All sequences in the group should have the same prompt.
+        # We use the prompt of an arbitrary sequence.
+        return next(iter(self.seqs_dict.values())).prompt
+
+    @property
+    def prompt_token_ids(self) -> List[int]:
+        # All sequences in the group should have the same prompt.
+        # We use the prompt of an arbitrary sequence.
+        return next(iter(self.seqs_dict.values())).data.prompt_token_ids
 
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
@@ -370,6 +386,22 @@ class SequenceOutputs:
                 and self.logprobs == other.logprobs)
 
 
+class SequenceGroupOutputs:
+    """The model outputs associated with a sequence group."""
+
+    def __init__(
+        self,
+        samples: List[SequenceOutputs],
+        prompt_logprobs: Optional[PromptLogprobs],
+    ) -> None:
+        self.samples = samples
+        self.prompt_logprobs = prompt_logprobs
+
+    def __repr__(self) -> str:
+        return (f"SequenceGroupOutputs(samples={self.samples}, "
+                f"prompt_logprobs={self.prompt_logprobs})")
+
+
 # For each sequence group, we generate a list of SequenceOutputs object,
 # each of which contains one possible candidate for the next token.
-SamplerOutput = List[List[SequenceOutputs]]
+SamplerOutput = List[SequenceGroupOutputs]