Unverified Commit 0630d453 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[V1] Logprobs and prompt logprobs support (#9880)



This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.

New behavior:

- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: default avatarAndrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Signed-off-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 538fab93
# SPDX-License-Identifier: Apache-2.0
import itertools
from dataclasses import dataclass
from typing import Dict, List, Optional
from vllm.logger import init_logger
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_ids_list_to_tokens)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
logger = init_logger(__name__)
@dataclass
class LogprobsProcessor:
# Tokenizer for this request
tokenizer: AnyTokenizer
# Logprobs for this request
logprobs: Optional[SampleLogprobs]
prompt_logprobs: Optional[PromptLogprobs]
cumulative_logprob: Optional[float]
num_logprobs: Optional[int]
num_prompt_logprobs: Optional[int]
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
request: EngineCoreRequest,
) -> "LogprobsProcessor":
num_logprobs = request.sampling_params.logprobs
num_prompt_logprobs = request.sampling_params.prompt_logprobs
return cls(
tokenizer=tokenizer,
cumulative_logprob=(None if num_logprobs is None else 0.),
logprobs=(None if num_logprobs is None else []),
# NOTE: logprob of first prompt token is None.
prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
num_prompt_logprobs=num_prompt_logprobs,
num_logprobs=num_logprobs,
)
def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
"""Update with sample logprobs from EngineCore.
Outer lists are only of len > 1 if EngineCore made
>1 tokens in prior step (e.g. in spec decoding).
Args:
logprobs_lists: the lists of logprob tokens, logprobs, and ranks.
"""
assert self.num_logprobs is not None
assert self.logprobs is not None
assert self.cumulative_logprob is not None
token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
token_ids_lst):
# Detokenize (non-incrementally).
decoded_tokens = convert_ids_list_to_tokens(
self.tokenizer, token_ids)
# Sampler puts the sampled logprob in first.
sampled_token_logprob = logprobs[0]
self.cumulative_logprob += sampled_token_logprob
# Update with the Logprob dictionary for this pos.
self.logprobs.append(
self._make_logprob_dict(
logprobs,
token_ids,
decoded_tokens,
rank,
self.num_logprobs,
))
def _update_prompt_logprobs(
self,
prompt_logprobs_tensors: LogprobsTensors,
) -> None:
"""Update with prompt logprobs from EngineCore.
Args:
prompt_logprobs_tensors: tuple containing the prompt logprobs
tensors.
"""
# Prompt logprobs are enabled.
assert self.num_prompt_logprobs is not None
assert self.prompt_logprobs is not None
token_ids, logprobs, ranks = prompt_logprobs_tensors
# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
decoded_tokens = convert_ids_list_to_tokens(
self.tokenizer,
token_ids.flatten().tolist())
# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape
# Pythonize the torch tensors.
# TODO(rob): experiment with doing this in EngineCore?
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
# Make Logprob for each position.
for pos in range(num_prompt_tokens):
# Handle flattening.
offset = pos * num_logprobs
offset_end = offset + num_logprobs
decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
# Update with the Logprob dictionary for this pos.
self.prompt_logprobs.append(
self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
decoded_tokens_for_pos,
prompt_token_ranks[pos],
self.num_prompt_logprobs))
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
"""Pop and return all request prompt logprobs
The logprobs processor aggregates prompt chunk logprobs
over one or more prefill chunks. This method returns
all prompt logprobs at once and then forgets them.
Ensures correct RequestOutputKind.DELTA semantics
wherein all prompt logprobs are returned at once at
the end of prefill.
Returns:
None if prompt logprobs are disabled for this request.
List of all prompt logprobs, otherwise.
"""
plp = self.prompt_logprobs
if plp:
self.prompt_logprobs = []
return plp
@staticmethod
def _make_logprob_dict(
logprobs: List[float],
logprob_token_ids: List[int],
decoded_tokens: List[str],
rank: int,
num_logprobs: int,
) -> Dict[int, Logprob]:
"""Make a Logprob dictionary for a position.
Args:
logprobs: list of log probabilities
logprob_token_ids: list of top token ids
decoded_tokens: list of decoded top tokens
rank: rank of the sampled token
num_logprobs: number of logprobs requested
by the user (in addition to sampled logprob)
Returns:
Dict[token id, Logprob]
"""
# We do not need a special case for the sampled token
# being in the topk, since inserting duplicated data
# into a dictionary twice is the same as doing it once.
topk_ranks = range(1, num_logprobs + 1)
ranks = itertools.chain((rank, ), topk_ranks)
return {
token_id: Logprob(
logprob=logprob,
rank=rank,
decoded_token=token,
)
for token_id, logprob, rank, token in zip(
logprob_token_ids, logprobs, ranks, decoded_tokens)
}
def update_from_output(self, output: EngineCoreOutput) -> None:
if output.new_logprobs is not None:
self._update_sample_logprobs(output.new_logprobs)
if output.new_prompt_logprobs_tensors is not None:
self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)
...@@ -5,11 +5,12 @@ from dataclasses import dataclass ...@@ -5,11 +5,12 @@ from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.transformers_utils.detokenizer_utils import AnyTokenizer from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import (DetokenizerOutput, from vllm.v1.engine.detokenizer import IncrementalDetokenizer
IncrementalDetokenizer) from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.metrics.stats import IterationStats, RequestStateStats from vllm.v1.metrics.stats import IterationStats, RequestStateStats
...@@ -26,16 +27,20 @@ class RequestState: ...@@ -26,16 +27,20 @@ class RequestState:
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
output_kind: RequestOutputKind,
prompt: Optional[str], prompt: Optional[str],
prompt_token_ids: List[int], prompt_token_ids: List[int],
logprobs_processor: LogprobsProcessor,
detokenizer: IncrementalDetokenizer, detokenizer: IncrementalDetokenizer,
arrival_time: float, arrival_time: float,
queue: Optional[asyncio.Queue[RequestOutput]], queue: Optional[asyncio.Queue[RequestOutput]],
): ):
self.request_id = request_id self.request_id = request_id
self.output_kind = output_kind
self.prompt = prompt self.prompt = prompt
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
self.prompt_len = len(prompt_token_ids) self.prompt_len = len(prompt_token_ids)
self.logprobs_processor = logprobs_processor
self.detokenizer = detokenizer self.detokenizer = detokenizer
self.is_prefilling = True self.is_prefilling = True
self.queue = queue self.queue = queue
...@@ -51,8 +56,13 @@ class RequestState: ...@@ -51,8 +56,13 @@ class RequestState:
) -> "RequestState": ) -> "RequestState":
return cls( return cls(
request_id=request.request_id, request_id=request.request_id,
output_kind=request.sampling_params.output_kind,
prompt=request.prompt, prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
logprobs_processor=LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
),
detokenizer=IncrementalDetokenizer.from_new_request( detokenizer=IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer, tokenizer=tokenizer,
request=request, request=request,
...@@ -127,13 +137,8 @@ class OutputProcessor: ...@@ -127,13 +137,8 @@ class OutputProcessor:
batch to ensure system overheads are minimized. This is the batch to ensure system overheads are minimized. This is the
only function that should loop over EngineCoreOutputs. only function that should loop over EngineCoreOutputs.
If you need to touch every element of the batch, implement a If you need to touch every element of the batch, do it from
method called XXXClass.update_from_output() to be called within the loop below.
within the loop below. For examples, see:
* IterationStats.update_from_output()
* Detokenizer.update_from_output()
TODO(rob): add Protocol makes update_from_output explicit.
********************************************************** **********************************************************
""" """
...@@ -154,17 +159,37 @@ class OutputProcessor: ...@@ -154,17 +159,37 @@ class OutputProcessor:
req_state.is_prefilling, req_state.is_prefilling,
req_state.prompt_len, req_state.prompt_len,
req_state.stats) req_state.stats)
req_state.is_prefilling = False
# 2) Detokenize the token ids into text.
detokenizer_output = req_state.detokenizer.update_from_output(
engine_core_output)
# 3) Create and handle RequestOutput objects.
if detokenizer_output is not None:
request_output = self._make_request_output(
req_state, detokenizer_output)
new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
# TODO(andy): prompt logprobs + chunked prefill can
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state.is_prefilling = not new_token_ids
# 2) Detokenize the token ids into text and check for stop
# strings.
stop_reason = req_state.detokenizer.update(new_token_ids)
if stop_reason:
finish_reason = FinishReason.STOP
# 3) Compute sample and prompt logprobs for request,
# if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
# 4) Create and handle RequestOutput objects.
if request_output := self._make_request_output(
req_state, new_token_ids, finish_reason, stop_reason):
if req_state.queue is not None: if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate(). # AsyncLLM: put into queue for handling by generate().
req_state.queue.put_nowait(request_output) req_state.queue.put_nowait(request_output)
...@@ -174,18 +199,16 @@ class OutputProcessor: ...@@ -174,18 +199,16 @@ class OutputProcessor:
# Free completed requests. # Free completed requests.
if request_output.finished: if request_output.finished:
assert detokenizer_output.finish_reason is not None
self.request_states.pop(req_id) self.request_states.pop(req_id)
if not engine_core_output.finished: if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer # If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore. # detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id) reqs_to_abort.append(req_id)
# Track per-request stats # Track per-request stats.
assert finish_reason is not None
iteration_stats.update_from_finished_request( iteration_stats.update_from_finished_request(
detokenizer_output.finish_reason, request_output, finish_reason, request_output, req_state.stats)
req_state.stats)
return OutputProcessorOutput( return OutputProcessorOutput(
request_outputs=request_outputs, request_outputs=request_outputs,
...@@ -196,20 +219,47 @@ class OutputProcessor: ...@@ -196,20 +219,47 @@ class OutputProcessor:
@staticmethod @staticmethod
def _make_request_output( def _make_request_output(
request_state: RequestState, request_state: RequestState,
detokenizer_output: DetokenizerOutput, new_token_ids: List[int],
) -> RequestOutput: finish_reason: Optional[FinishReason],
stop_reason: Optional[str],
) -> Optional[RequestOutput]:
finished = finish_reason is not None
output_kind = request_state.output_kind
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if not finished and (request_state.is_prefilling
or output_kind == RequestOutputKind.FINAL_ONLY):
# Only the final output is required in FINAL_ONLY mode.
return None
detokenizer = request_state.detokenizer
logprobs_processor = request_state.logprobs_processor
delta = output_kind == RequestOutputKind.DELTA
logprobs = logprobs_processor.logprobs
if delta:
if logprobs:
logprobs = logprobs[-len(new_token_ids):]
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = logprobs_processor.pop_prompt_logprobs()
else:
prompt_logprobs = logprobs_processor.prompt_logprobs
request_output = RequestOutput.new( request_output = RequestOutput.new(
request_state.request_id, request_id=request_state.request_id,
request_state.prompt, prompt=request_state.prompt,
request_state.prompt_token_ids, prompt_token_ids=request_state.prompt_token_ids,
detokenizer_output.output_text, text=detokenizer.get_next_output_text(finished, delta),
detokenizer_output.token_ids, token_ids=new_token_ids if delta else detokenizer.output_token_ids,
detokenizer_output.finished, logprobs=logprobs,
prompt_logprobs=prompt_logprobs,
cumulative_logprob=logprobs_processor.cumulative_logprob,
finished=finished,
) )
if detokenizer_output.finished: if finished:
completion_output = request_output.outputs[0] completion_output = request_output.outputs[0]
completion_output.finish_reason = str( completion_output.finish_reason = str(finish_reason)
detokenizer_output.finish_reason) completion_output.stop_reason = stop_reason
completion_output.stop_reason = detokenizer_output.stop_reason
return request_output return request_output
...@@ -33,6 +33,7 @@ class Processor: ...@@ -33,6 +33,7 @@ class Processor:
): ):
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -51,6 +52,37 @@ class Processor: ...@@ -51,6 +52,37 @@ class Processor:
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \ self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching cache_config.enable_prefix_caching
def _validate_logprobs(
self,
params: Union[SamplingParams, PoolingParams],
) -> None:
if not isinstance(params, SamplingParams):
return
max_logprobs = self.model_config.max_logprobs
# Validate sample logprobs.
if params.logprobs and params.logprobs > max_logprobs:
raise ValueError(
f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# Validate prompt logprobs.
if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
raise ValueError(
f"Requested prompt logprobs of {params.prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# TODO(andy): enable this in follow up by recomputing.
if (params.prompt_logprobs is not None
and self.cache_config.enable_prefix_caching):
raise ValueError("Prefix caching with prompt logprobs not yet "
"supported on VLLM V1.")
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
def process_inputs( def process_inputs(
self, self,
request_id: str, request_id: str,
...@@ -64,12 +96,11 @@ class Processor: ...@@ -64,12 +96,11 @@ class Processor:
) -> EngineCoreRequest: ) -> EngineCoreRequest:
# TODO(woosuk): Support pooling models. # TODO(woosuk): Support pooling models.
# TODO(woosuk): Check max_logprobs
# TODO(woosuk): Support encoder-decoder models. # TODO(woosuk): Support encoder-decoder models.
if lora_request is not None and not self.lora_config: self._validate_logprobs(params)
raise ValueError(f"Got lora_request {lora_request} but LoRA is " self._validate_lora(lora_request)
"not enabled!")
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
assert priority == 0, "vLLM V1 does not support priority at the moment." assert priority == 0, "vLLM V1 does not support priority at the moment."
......
...@@ -60,14 +60,17 @@ class IterationStats: ...@@ -60,14 +60,17 @@ class IterationStats:
self.num_generation_tokens += num_new_generation_tokens self.num_generation_tokens += num_new_generation_tokens
if is_prefilling: if is_prefilling:
# This relies on the invariant that EngineCore does # TODO(andy): we used to assert that num_new_generation_tokens
# not stream outputs for partially completed prefills # > 0 with an invariant that EngineCore does not stream outputs
# (scheduler.update_from_output makes EngineCoreOutput # for partially completed prefills (scheduler.update_from_output
# iff num_computed_tokens == num_tokens). # makes EngineCoreOutput iff num_computed_tokens == num_tokens).
assert (num_new_generation_tokens > 0) # When prompt logprobs are enabled, we currently stream out the
self.num_prompt_tokens += prompt_len # partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
self.time_to_first_tokens_iter.append(last_token_latency) # this assertion / invariant.
if num_new_generation_tokens > 0:
self.num_prompt_tokens += prompt_len
self.time_to_first_tokens_iter.append(last_token_latency)
else: else:
self.time_per_output_tokens_iter.append(last_token_latency) self.time_per_output_tokens_iter.append(last_token_latency)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, NamedTuple, Optional
import torch import torch
@dataclass class LogprobsLists(NamedTuple):
class SamplerOutput:
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: List[List[int]]
# [num_reqs, max_num_logprobs + 1]
logprobs: List[List[float]]
# [num_reqs] # [num_reqs]
sampled_token_ids: torch.Tensor sampled_token_ranks: List[int]
def slice(self, start: int, end: int):
return LogprobsLists(
self.logprob_token_ids[start:end],
self.logprobs[start:end],
self.sampled_token_ranks[start:end],
)
class LogprobsTensors(NamedTuple):
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
logprob_token_ids: Optional[torch.Tensor] logprob_token_ids: torch.Tensor
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
logprobs: Optional[torch.Tensor] logprobs: torch.Tensor
# [num_reqs]
selected_token_ranks: torch.Tensor
# TODO: Support prompt logprobs. def tolists(self):
prompt_logprob_token_ids: Optional[torch.Tensor] return LogprobsLists(
prompt_logprobs: Optional[torch.Tensor] self.logprob_token_ids.tolist(),
self.logprobs.tolist(),
self.selected_token_ranks.tolist(),
)
@dataclass
class SamplerOutput:
# [num_reqs]
sampled_token_ids: torch.Tensor
logprobs_tensors: Optional[LogprobsTensors]
# ModelRunnerOutput is serialized and sent to the scheduler process. # ModelRunnerOutput is serialized and sent to the scheduler process.
...@@ -36,6 +62,12 @@ class ModelRunnerOutput: ...@@ -36,6 +62,12 @@ class ModelRunnerOutput:
sampled_token_ids: List[int] sampled_token_ids: List[int]
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
logprob_token_ids_cpu: Optional[torch.Tensor]
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
logprobs_cpu: Optional[torch.Tensor] # [num_reqs]
logprobs: Optional[LogprobsLists]
# req_id -> (token_ids, logprobs, ranks)
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict: Dict[str, LogprobsTensors]
...@@ -20,7 +20,8 @@ class SamplingMetadata: ...@@ -20,7 +20,8 @@ class SamplingMetadata:
generators: Dict[int, torch.Generator] generators: Dict[int, torch.Generator]
max_num_logprobs: int # None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: Optional[int]
no_penalties: bool no_penalties: bool
prompt_token_ids: Optional[torch.Tensor] prompt_token_ids: Optional[torch.Tensor]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.v1.outputs import SamplerOutput from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_all_penalties, from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties) apply_min_token_penalties)
...@@ -25,20 +24,16 @@ class Sampler(nn.Module): ...@@ -25,20 +24,16 @@ class Sampler(nn.Module):
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> SamplerOutput:
needs_logprobs = sampling_metadata.max_num_logprobs > 0
if needs_logprobs: # NOTE(woosuk): Use the original logits (before any penalties or
# NOTE(woosuk): Use the original logits (before any penalties or # temperature scaling) for the top-k logprobs.
# temperature scaling) for the top-k logprobs. # This is different from the V0 sampler, which uses the logits that
# This is different from the V0 sampler, which uses the logits that # is used for sampling (after penalties and temperature scaling).
# is used for sampling (after penalties and temperature scaling). # TODO(rob): provide option for logprobs post sampling.
# NOTE: We compute logprobs first because the below ops may # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
# modify the logits tensor in-place (and we don't want to clone num_logprobs = sampling_metadata.max_num_logprobs
# the logits tensor for memory efficiency). if num_logprobs is not None:
topk_logprobs, topk_indices = self.get_topk_logprobs( raw_logprobs = self.compute_logprobs(logits)
logits, sampling_metadata)
else:
topk_logprobs = None
topk_indices = None
# Use float32 for the logits. # Use float32 for the logits.
logits = logits.to(torch.float32) logits = logits.to(torch.float32)
...@@ -48,15 +43,19 @@ class Sampler(nn.Module): ...@@ -48,15 +43,19 @@ class Sampler(nn.Module):
logits = self.apply_temperature(logits, sampling_metadata.temperature) logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Sample the next token. # Sample the next token.
sampled = self.sample(logits, sampling_metadata) sampled = self.sample(logits, sampling_metadata)
# Gather the logprobs of the topk and sampled token (if requested).
# Get logprobs and rank tensors (if requested)
logprobs_tensors = None if num_logprobs is None else \
self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
# Use int32 to reduce the tensor size. # Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32) sampled = sampled.to(torch.int32)
# These are GPU tensors.
sampler_output = SamplerOutput( sampler_output = SamplerOutput(
sampled_token_ids=sampled, sampled_token_ids=sampled,
logprob_token_ids=topk_indices, logprobs_tensors=logprobs_tensors,
logprobs=topk_logprobs,
prompt_logprob_token_ids=None,
prompt_logprobs=None,
) )
return sampler_output return sampler_output
...@@ -103,19 +102,52 @@ class Sampler(nn.Module): ...@@ -103,19 +102,52 @@ class Sampler(nn.Module):
) )
return sampled return sampled
def get_topk_logprobs( def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)
def gather_logprobs(
self, self,
logits: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, num_logprobs: int,
) -> Tuple[torch.Tensor, torch.Tensor]: token_ids: torch.Tensor,
logprobs = logits.log_softmax(dim=-1, dtype=torch.float32) ) -> LogprobsTensors:
# FIXME: Mask the sampled token_id, get topk logprobs, """
# and concatenate the topk with the sampled token_id. Gather logprobs for topk and sampled/prompt token.
topk_logprobs, topk_indices = torch.topk(
logprobs, sampling_metadata.max_num_logprobs, dim=-1) Args:
logits: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
# Find the topK values.
topk_logprobs, topk_indices = torch.topk(logprobs,
num_logprobs,
dim=-1)
# Get with the logprob of the prompt or sampled token.
token_ids = token_ids.unsqueeze(-1)
token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
# Concatenate together with the topk.
indices = torch.cat((token_ids, topk_indices), dim=1)
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
# Use int32 to reduce the tensor size. # Use int32 to reduce the tensor size.
topk_indices = topk_indices.to(torch.int32) indices = indices.to(torch.int32)
return topk_logprobs, topk_indices
return LogprobsTensors(indices, logprobs, token_ranks)
def apply_penalties( def apply_penalties(
self, self,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pickle import pickle
from typing import Any
import torch
from msgspec import msgpack
CUSTOM_TYPE_CODE_PICKLE = 1
class PickleEncoder: class PickleEncoder:
def encode(self, obj): def encode(self, obj: Any):
return pickle.dumps(obj) return pickle.dumps(obj)
def decode(self, data): def decode(self, data: Any):
return pickle.loads(data) return pickle.loads(data)
class MsgpackEncoder:
"""Encoder with custom torch tensor serialization."""
def __init__(self):
self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)
def encode(self, obj: Any) -> bytes:
return self.encoder.encode(obj)
def encode_into(self, obj: Any, buf: bytearray) -> None:
self.encoder.encode_into(obj, buf)
class MsgpackDecoder:
"""Decoder with custom torch tensor serialization."""
def __init__(self, t: Any):
self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)
def decode(self, obj: Any):
return self.decoder.decode(obj)
def custom_enc_hook(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy()))
raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
def custom_ext_hook(code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_CODE_PICKLE:
return torch.from_numpy(pickle.loads(data))
raise NotImplementedError(f"Extension type code {code} is not supported")
...@@ -176,7 +176,9 @@ class InputBatch: ...@@ -176,7 +176,9 @@ class InputBatch:
self.generators: Dict[int, torch.Generator] = {} self.generators: Dict[int, torch.Generator] = {}
self.num_logprobs: Dict[str, int] = {} self.num_logprobs: Dict[str, int] = {}
self.prompt_logprob_reqs: Set[str] = set() # NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: Dict[str, int] = {}
def add_request( def add_request(
self, self,
...@@ -238,11 +240,10 @@ class InputBatch: ...@@ -238,11 +240,10 @@ class InputBatch:
if request.generator is not None: if request.generator is not None:
self.generators[req_index] = request.generator self.generators[req_index] = request.generator
num_logprobs = sampling_params.logprobs if sampling_params.logprobs is not None:
if num_logprobs is not None and num_logprobs > 0: self.num_logprobs[req_id] = sampling_params.logprobs
self.num_logprobs[req_id] = num_logprobs if sampling_params.prompt_logprobs is not None:
if sampling_params.prompt_logprobs: self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
self.prompt_logprob_reqs.add(req_id)
# Add request lora ID # Add request lora ID
if request.lora_request: if request.lora_request:
...@@ -272,7 +273,7 @@ class InputBatch: ...@@ -272,7 +273,7 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id) self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None) self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None) self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id) self.num_prompt_logprobs.pop(req_id, None)
# LoRA # LoRA
lora_id = self.request_lora_mapping[req_index] lora_id = self.request_lora_mapping[req_index]
...@@ -297,7 +298,7 @@ class InputBatch: ...@@ -297,7 +298,7 @@ class InputBatch:
self.repetition_penalties_reqs.clear() self.repetition_penalties_reqs.clear()
self.generators.clear() self.generators.clear()
self.num_logprobs.clear() self.num_logprobs.clear()
self.prompt_logprob_reqs.clear() self.num_prompt_logprobs.clear()
self.request_lora_mapping.fill(0) self.request_lora_mapping.fill(0)
self.lora_id_to_lora_request.clear() self.lora_id_to_lora_request.clear()
self.lora_id_to_request_ids.clear() self.lora_id_to_request_ids.clear()
...@@ -489,13 +490,9 @@ class InputBatch: ...@@ -489,13 +490,9 @@ class InputBatch:
and len(self.repetition_penalties_reqs) == 0) and len(self.repetition_penalties_reqs) == 0)
@property @property
def max_num_logprobs(self) -> int: def max_num_logprobs(self) -> Optional[int]:
return max(self.num_logprobs.values()) if self.num_logprobs else 0 return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_logprob(self) -> bool:
return len(self.num_logprobs) == 0
@property @property
def no_prompt_logprob(self) -> bool: def no_prompt_logprob(self) -> bool:
return len(self.prompt_logprob_reqs) == 0 return not self.num_prompt_logprobs
...@@ -29,7 +29,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget ...@@ -29,7 +29,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -804,8 +804,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -804,8 +804,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
hidden_states = hidden_states[:num_scheduled_tokens] hidden_states = hidden_states[:num_scheduled_tokens]
hidden_states = hidden_states[logits_indices] sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
# Sample the next token and get logprobs if needed. # Sample the next token and get logprobs if needed.
sampling_metadata = self._prepare_sampling(batch_changed) sampling_metadata = self._prepare_sampling(batch_changed)
...@@ -818,7 +818,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -818,7 +818,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the requests one by one. Optimize. # the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): for i, req_id in enumerate( # type: ignore[assignment]
self.input_batch.req_ids[:num_reqs]):
assert req_id is not None assert req_id is not None
req_state = self.requests[req_id] req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens + seq_len = (req_state.num_computed_tokens +
...@@ -847,27 +848,28 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -847,27 +848,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE: GPU -> CPU Sync happens here. # NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point. # Move as many CPU operations as possible before this sync point.
sampled_token_ids = sampler_output.sampled_token_ids.tolist() sampled_token_ids = sampler_output.sampled_token_ids.tolist()
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states,
scheduler_output,
)
# Update with the actual token ids # Update with the actual token ids
for i, req_state, seq_len in request_seq_lens: for i, req_state, seq_len in request_seq_lens:
token_id = sampled_token_ids[i] token_id = sampled_token_ids[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids[-1] = token_id req_state.output_token_ids[-1] = token_id
if sampler_output.logprob_token_ids is None:
logprob_token_ids = None
else:
logprob_token_ids = sampler_output.logprob_token_ids.cpu()
if sampler_output.logprobs is None:
logprobs = None
else:
logprobs = sampler_output.logprobs.cpu()
model_runner_output = ModelRunnerOutput( model_runner_output = ModelRunnerOutput(
req_ids=req_ids, req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index, req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
logprob_token_ids_cpu=logprob_token_ids, logprobs=logprobs_lists,
logprobs_cpu=logprobs, prompt_logprobs_dict=prompt_logprobs_dict,
) )
return model_runner_output return model_runner_output
...@@ -886,6 +888,76 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -886,6 +888,76 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB", logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30)) self.model_memory_usage / float(2**30))
def _get_prompt_logprobs_dict(
self,
hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput",
) -> Dict[str, LogprobsTensors]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}
prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
# Since prompt logprobs are a rare feature, prioritize simple,
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# Get metadata for this request.
request = self.requests[req_id]
num_prompt_tokens = len(request.prompt_token_ids)
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True)
# Determine number of logits to retrieve.
start_tok = request.num_computed_tokens + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens < num_remaining_tokens:
# This is a chunk, more tokens remain.
num_logits = num_tokens
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
completed_prefill_reqs.append(req_id)
# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
# then there is prompt logprob generated for each index.
req_idx = self.input_batch.req_id_to_index[req_id]
offset = self.query_start_loc_np[req_idx].item()
prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None)
# Get the "target" tokens for each index. For prompt at index i,
# the token at prompt index i+1 is the "sampled" token we want
# to gather the logprob for.
tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
# Compute prompt logprobs.
logprobs = self.model.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
logprobs, num_prompt_logprobs, tgt_token_ids)
# Transfer GPU->CPU async.
prompt_logprobs_dict[req_id] = LogprobsTensors(
token_ids.to("cpu", non_blocking=True),
logprobs.to("cpu", non_blocking=True),
ranks.to("cpu", non_blocking=True),
)
# Remove requests that have completed prefill from the batch
# num_prompt_logprobs_dict.
for req_id in completed_prefill_reqs:
del num_prompt_logprobs_dict[req_id]
# Must synchronize the non-blocking GPU->CPU transfers.
torch.cuda.synchronize()
return prompt_logprobs_dict
@torch.inference_mode() @torch.inference_mode()
def _dummy_run( def _dummy_run(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment