Unverified Commit 0630d453 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[V1] Logprobs and prompt logprobs support (#9880)



This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.

New behavior:

- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: default avatarAndrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Signed-off-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 538fab93
# SPDX-License-Identifier: Apache-2.0
import itertools
from dataclasses import dataclass
from typing import Dict, List, Optional
from vllm.logger import init_logger
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_ids_list_to_tokens)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
logger = init_logger(__name__)
@dataclass
class LogprobsProcessor:
# Tokenizer for this request
tokenizer: AnyTokenizer
# Logprobs for this request
logprobs: Optional[SampleLogprobs]
prompt_logprobs: Optional[PromptLogprobs]
cumulative_logprob: Optional[float]
num_logprobs: Optional[int]
num_prompt_logprobs: Optional[int]
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
request: EngineCoreRequest,
) -> "LogprobsProcessor":
num_logprobs = request.sampling_params.logprobs
num_prompt_logprobs = request.sampling_params.prompt_logprobs
return cls(
tokenizer=tokenizer,
cumulative_logprob=(None if num_logprobs is None else 0.),
logprobs=(None if num_logprobs is None else []),
# NOTE: logprob of first prompt token is None.
prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
num_prompt_logprobs=num_prompt_logprobs,
num_logprobs=num_logprobs,
)
def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
"""Update with sample logprobs from EngineCore.
Outer lists are only of len > 1 if EngineCore made
>1 tokens in prior step (e.g. in spec decoding).
Args:
logprobs_lists: the lists of logprob tokens, logprobs, and ranks.
"""
assert self.num_logprobs is not None
assert self.logprobs is not None
assert self.cumulative_logprob is not None
token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
token_ids_lst):
# Detokenize (non-incrementally).
decoded_tokens = convert_ids_list_to_tokens(
self.tokenizer, token_ids)
# Sampler puts the sampled logprob in first.
sampled_token_logprob = logprobs[0]
self.cumulative_logprob += sampled_token_logprob
# Update with the Logprob dictionary for this pos.
self.logprobs.append(
self._make_logprob_dict(
logprobs,
token_ids,
decoded_tokens,
rank,
self.num_logprobs,
))
def _update_prompt_logprobs(
self,
prompt_logprobs_tensors: LogprobsTensors,
) -> None:
"""Update with prompt logprobs from EngineCore.
Args:
prompt_logprobs_tensors: tuple containing the prompt logprobs
tensors.
"""
# Prompt logprobs are enabled.
assert self.num_prompt_logprobs is not None
assert self.prompt_logprobs is not None
token_ids, logprobs, ranks = prompt_logprobs_tensors
# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
decoded_tokens = convert_ids_list_to_tokens(
self.tokenizer,
token_ids.flatten().tolist())
# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape
# Pythonize the torch tensors.
# TODO(rob): experiment with doing this in EngineCore?
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
# Make Logprob for each position.
for pos in range(num_prompt_tokens):
# Handle flattening.
offset = pos * num_logprobs
offset_end = offset + num_logprobs
decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
# Update with the Logprob dictionary for this pos.
self.prompt_logprobs.append(
self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
decoded_tokens_for_pos,
prompt_token_ranks[pos],
self.num_prompt_logprobs))
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
"""Pop and return all request prompt logprobs
The logprobs processor aggregates prompt chunk logprobs
over one or more prefill chunks. This method returns
all prompt logprobs at once and then forgets them.
Ensures correct RequestOutputKind.DELTA semantics
wherein all prompt logprobs are returned at once at
the end of prefill.
Returns:
None if prompt logprobs are disabled for this request.
List of all prompt logprobs, otherwise.
"""
plp = self.prompt_logprobs
if plp:
self.prompt_logprobs = []
return plp
@staticmethod
def _make_logprob_dict(
logprobs: List[float],
logprob_token_ids: List[int],
decoded_tokens: List[str],
rank: int,
num_logprobs: int,
) -> Dict[int, Logprob]:
"""Make a Logprob dictionary for a position.
Args:
logprobs: list of log probabilities
logprob_token_ids: list of top token ids
decoded_tokens: list of decoded top tokens
rank: rank of the sampled token
num_logprobs: number of logprobs requested
by the user (in addition to sampled logprob)
Returns:
Dict[token id, Logprob]
"""
# We do not need a special case for the sampled token
# being in the topk, since inserting duplicated data
# into a dictionary twice is the same as doing it once.
topk_ranks = range(1, num_logprobs + 1)
ranks = itertools.chain((rank, ), topk_ranks)
return {
token_id: Logprob(
logprob=logprob,
rank=rank,
decoded_token=token,
)
for token_id, logprob, rank, token in zip(
logprob_token_ids, logprobs, ranks, decoded_tokens)
}
def update_from_output(self, output: EngineCoreOutput) -> None:
if output.new_logprobs is not None:
self._update_sample_logprobs(output.new_logprobs)
if output.new_prompt_logprobs_tensors is not None:
self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)
......@@ -5,11 +5,12 @@ from dataclasses import dataclass
from typing import Dict, List, Optional
from vllm.outputs import RequestOutput
from vllm.transformers_utils.detokenizer_utils import AnyTokenizer
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.engine.detokenizer import (DetokenizerOutput,
IncrementalDetokenizer)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.metrics.stats import IterationStats, RequestStateStats
......@@ -26,16 +27,20 @@ class RequestState:
def __init__(
self,
request_id: str,
output_kind: RequestOutputKind,
prompt: Optional[str],
prompt_token_ids: List[int],
logprobs_processor: LogprobsProcessor,
detokenizer: IncrementalDetokenizer,
arrival_time: float,
queue: Optional[asyncio.Queue[RequestOutput]],
):
self.request_id = request_id
self.output_kind = output_kind
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_len = len(prompt_token_ids)
self.logprobs_processor = logprobs_processor
self.detokenizer = detokenizer
self.is_prefilling = True
self.queue = queue
......@@ -51,8 +56,13 @@ class RequestState:
) -> "RequestState":
return cls(
request_id=request.request_id,
output_kind=request.sampling_params.output_kind,
prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids,
logprobs_processor=LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
),
detokenizer=IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
......@@ -127,13 +137,8 @@ class OutputProcessor:
batch to ensure system overheads are minimized. This is the
only function that should loop over EngineCoreOutputs.
If you need to touch every element of the batch, implement a
method called XXXClass.update_from_output() to be called
within the loop below. For examples, see:
* IterationStats.update_from_output()
* Detokenizer.update_from_output()
TODO(rob): add Protocol makes update_from_output explicit.
If you need to touch every element of the batch, do it from
within the loop below.
**********************************************************
"""
......@@ -154,17 +159,37 @@ class OutputProcessor:
req_state.is_prefilling,
req_state.prompt_len,
req_state.stats)
req_state.is_prefilling = False
# 2) Detokenize the token ids into text.
detokenizer_output = req_state.detokenizer.update_from_output(
engine_core_output)
# 3) Create and handle RequestOutput objects.
if detokenizer_output is not None:
request_output = self._make_request_output(
req_state, detokenizer_output)
new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
# TODO(andy): prompt logprobs + chunked prefill can
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state.is_prefilling = not new_token_ids
# 2) Detokenize the token ids into text and check for stop
# strings.
stop_reason = req_state.detokenizer.update(new_token_ids)
if stop_reason:
finish_reason = FinishReason.STOP
# 3) Compute sample and prompt logprobs for request,
# if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
# 4) Create and handle RequestOutput objects.
if request_output := self._make_request_output(
req_state, new_token_ids, finish_reason, stop_reason):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put_nowait(request_output)
......@@ -174,18 +199,16 @@ class OutputProcessor:
# Free completed requests.
if request_output.finished:
assert detokenizer_output.finish_reason is not None
self.request_states.pop(req_id)
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
# Track per-request stats
# Track per-request stats.
assert finish_reason is not None
iteration_stats.update_from_finished_request(
detokenizer_output.finish_reason, request_output,
req_state.stats)
finish_reason, request_output, req_state.stats)
return OutputProcessorOutput(
request_outputs=request_outputs,
......@@ -196,20 +219,47 @@ class OutputProcessor:
@staticmethod
def _make_request_output(
request_state: RequestState,
detokenizer_output: DetokenizerOutput,
) -> RequestOutput:
new_token_ids: List[int],
finish_reason: Optional[FinishReason],
stop_reason: Optional[str],
) -> Optional[RequestOutput]:
finished = finish_reason is not None
output_kind = request_state.output_kind
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if not finished and (request_state.is_prefilling
or output_kind == RequestOutputKind.FINAL_ONLY):
# Only the final output is required in FINAL_ONLY mode.
return None
detokenizer = request_state.detokenizer
logprobs_processor = request_state.logprobs_processor
delta = output_kind == RequestOutputKind.DELTA
logprobs = logprobs_processor.logprobs
if delta:
if logprobs:
logprobs = logprobs[-len(new_token_ids):]
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = logprobs_processor.pop_prompt_logprobs()
else:
prompt_logprobs = logprobs_processor.prompt_logprobs
request_output = RequestOutput.new(
request_state.request_id,
request_state.prompt,
request_state.prompt_token_ids,
detokenizer_output.output_text,
detokenizer_output.token_ids,
detokenizer_output.finished,
request_id=request_state.request_id,
prompt=request_state.prompt,
prompt_token_ids=request_state.prompt_token_ids,
text=detokenizer.get_next_output_text(finished, delta),
token_ids=new_token_ids if delta else detokenizer.output_token_ids,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs,
cumulative_logprob=logprobs_processor.cumulative_logprob,
finished=finished,
)
if detokenizer_output.finished:
if finished:
completion_output = request_output.outputs[0]
completion_output.finish_reason = str(
detokenizer_output.finish_reason)
completion_output.stop_reason = detokenizer_output.stop_reason
completion_output.finish_reason = str(finish_reason)
completion_output.stop_reason = stop_reason
return request_output
......@@ -33,6 +33,7 @@ class Processor:
):
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.tokenizer = tokenizer
......@@ -51,6 +52,37 @@ class Processor:
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
def _validate_logprobs(
self,
params: Union[SamplingParams, PoolingParams],
) -> None:
if not isinstance(params, SamplingParams):
return
max_logprobs = self.model_config.max_logprobs
# Validate sample logprobs.
if params.logprobs and params.logprobs > max_logprobs:
raise ValueError(
f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# Validate prompt logprobs.
if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
raise ValueError(
f"Requested prompt logprobs of {params.prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# TODO(andy): enable this in follow up by recomputing.
if (params.prompt_logprobs is not None
and self.cache_config.enable_prefix_caching):
raise ValueError("Prefix caching with prompt logprobs not yet "
"supported on VLLM V1.")
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
def process_inputs(
self,
request_id: str,
......@@ -64,12 +96,11 @@ class Processor:
) -> EngineCoreRequest:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Check max_logprobs
# TODO(woosuk): Support encoder-decoder models.
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
self._validate_logprobs(params)
self._validate_lora(lora_request)
if arrival_time is None:
arrival_time = time.time()
assert priority == 0, "vLLM V1 does not support priority at the moment."
......
......@@ -60,14 +60,17 @@ class IterationStats:
self.num_generation_tokens += num_new_generation_tokens
if is_prefilling:
# This relies on the invariant that EngineCore does
# not stream outputs for partially completed prefills
# (scheduler.update_from_output makes EngineCoreOutput
# iff num_computed_tokens == num_tokens).
assert (num_new_generation_tokens > 0)
self.num_prompt_tokens += prompt_len
self.time_to_first_tokens_iter.append(last_token_latency)
# TODO(andy): we used to assert that num_new_generation_tokens
# > 0 with an invariant that EngineCore does not stream outputs
# for partially completed prefills (scheduler.update_from_output
# makes EngineCoreOutput iff num_computed_tokens == num_tokens).
# When prompt logprobs are enabled, we currently stream out the
# partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
# this assertion / invariant.
if num_new_generation_tokens > 0:
self.num_prompt_tokens += prompt_len
self.time_to_first_tokens_iter.append(last_token_latency)
else:
self.time_per_output_tokens_iter.append(last_token_latency)
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Dict, List, NamedTuple, Optional
import torch
@dataclass
class SamplerOutput:
class LogprobsLists(NamedTuple):
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: List[List[int]]
# [num_reqs, max_num_logprobs + 1]
logprobs: List[List[float]]
# [num_reqs]
sampled_token_ids: torch.Tensor
sampled_token_ranks: List[int]
def slice(self, start: int, end: int):
return LogprobsLists(
self.logprob_token_ids[start:end],
self.logprobs[start:end],
self.sampled_token_ranks[start:end],
)
class LogprobsTensors(NamedTuple):
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: Optional[torch.Tensor]
logprob_token_ids: torch.Tensor
# [num_reqs, max_num_logprobs + 1]
logprobs: Optional[torch.Tensor]
logprobs: torch.Tensor
# [num_reqs]
selected_token_ranks: torch.Tensor
# TODO: Support prompt logprobs.
prompt_logprob_token_ids: Optional[torch.Tensor]
prompt_logprobs: Optional[torch.Tensor]
def tolists(self):
return LogprobsLists(
self.logprob_token_ids.tolist(),
self.logprobs.tolist(),
self.selected_token_ranks.tolist(),
)
@dataclass
class SamplerOutput:
# [num_reqs]
sampled_token_ids: torch.Tensor
logprobs_tensors: Optional[LogprobsTensors]
# ModelRunnerOutput is serialized and sent to the scheduler process.
......@@ -36,6 +62,12 @@ class ModelRunnerOutput:
sampled_token_ids: List[int]
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids_cpu: Optional[torch.Tensor]
# [num_reqs, max_num_logprobs + 1]
logprobs_cpu: Optional[torch.Tensor]
# [num_reqs]
logprobs: Optional[LogprobsLists]
# req_id -> (token_ids, logprobs, ranks)
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict: Dict[str, LogprobsTensors]
......@@ -20,7 +20,8 @@ class SamplingMetadata:
generators: Dict[int, torch.Generator]
max_num_logprobs: int
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: Optional[int]
no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]
......
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
from typing import Tuple
import torch
import torch.nn as nn
from vllm.v1.outputs import SamplerOutput
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties)
......@@ -25,20 +24,16 @@ class Sampler(nn.Module):
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
needs_logprobs = sampling_metadata.max_num_logprobs > 0
if needs_logprobs:
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
# NOTE: We compute logprobs first because the below ops may
# modify the logits tensor in-place (and we don't want to clone
# the logits tensor for memory efficiency).
topk_logprobs, topk_indices = self.get_topk_logprobs(
logits, sampling_metadata)
else:
topk_logprobs = None
topk_indices = None
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
# TODO(rob): provide option for logprobs post sampling.
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
raw_logprobs = self.compute_logprobs(logits)
# Use float32 for the logits.
logits = logits.to(torch.float32)
......@@ -48,15 +43,19 @@ class Sampler(nn.Module):
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)
# Gather the logprobs of the topk and sampled token (if requested).
# Get logprobs and rank tensors (if requested)
logprobs_tensors = None if num_logprobs is None else \
self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
# These are GPU tensors.
sampler_output = SamplerOutput(
sampled_token_ids=sampled,
logprob_token_ids=topk_indices,
logprobs=topk_logprobs,
prompt_logprob_token_ids=None,
prompt_logprobs=None,
logprobs_tensors=logprobs_tensors,
)
return sampler_output
......@@ -103,19 +102,52 @@ class Sampler(nn.Module):
)
return sampled
def get_topk_logprobs(
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)
def gather_logprobs(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
topk_logprobs, topk_indices = torch.topk(
logprobs, sampling_metadata.max_num_logprobs, dim=-1)
logprobs: torch.Tensor,
num_logprobs: int,
token_ids: torch.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logits: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
# Find the topK values.
topk_logprobs, topk_indices = torch.topk(logprobs,
num_logprobs,
dim=-1)
# Get with the logprob of the prompt or sampled token.
token_ids = token_ids.unsqueeze(-1)
token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
# Concatenate together with the topk.
indices = torch.cat((token_ids, topk_indices), dim=1)
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
# Use int32 to reduce the tensor size.
topk_indices = topk_indices.to(torch.int32)
return topk_logprobs, topk_indices
indices = indices.to(torch.int32)
return LogprobsTensors(indices, logprobs, token_ranks)
def apply_penalties(
self,
......
# SPDX-License-Identifier: Apache-2.0
import pickle
from typing import Any
import torch
from msgspec import msgpack
CUSTOM_TYPE_CODE_PICKLE = 1
class PickleEncoder:
def encode(self, obj):
def encode(self, obj: Any):
return pickle.dumps(obj)
def decode(self, data):
def decode(self, data: Any):
return pickle.loads(data)
class MsgpackEncoder:
"""Encoder with custom torch tensor serialization."""
def __init__(self):
self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)
def encode(self, obj: Any) -> bytes:
return self.encoder.encode(obj)
def encode_into(self, obj: Any, buf: bytearray) -> None:
self.encoder.encode_into(obj, buf)
class MsgpackDecoder:
"""Decoder with custom torch tensor serialization."""
def __init__(self, t: Any):
self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)
def decode(self, obj: Any):
return self.decoder.decode(obj)
def custom_enc_hook(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy()))
raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
def custom_ext_hook(code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_CODE_PICKLE:
return torch.from_numpy(pickle.loads(data))
raise NotImplementedError(f"Extension type code {code} is not supported")
......@@ -176,7 +176,9 @@ class InputBatch:
self.generators: Dict[int, torch.Generator] = {}
self.num_logprobs: Dict[str, int] = {}
self.prompt_logprob_reqs: Set[str] = set()
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: Dict[str, int] = {}
def add_request(
self,
......@@ -238,11 +240,10 @@ class InputBatch:
if request.generator is not None:
self.generators[req_index] = request.generator
num_logprobs = sampling_params.logprobs
if num_logprobs is not None and num_logprobs > 0:
self.num_logprobs[req_id] = num_logprobs
if sampling_params.prompt_logprobs:
self.prompt_logprob_reqs.add(req_id)
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
# Add request lora ID
if request.lora_request:
......@@ -272,7 +273,7 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id)
self.num_prompt_logprobs.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]
......@@ -297,7 +298,7 @@ class InputBatch:
self.repetition_penalties_reqs.clear()
self.generators.clear()
self.num_logprobs.clear()
self.prompt_logprob_reqs.clear()
self.num_prompt_logprobs.clear()
self.request_lora_mapping.fill(0)
self.lora_id_to_lora_request.clear()
self.lora_id_to_request_ids.clear()
......@@ -489,13 +490,9 @@ class InputBatch:
and len(self.repetition_penalties_reqs) == 0)
@property
def max_num_logprobs(self) -> int:
return max(self.num_logprobs.values()) if self.num_logprobs else 0
@property
def no_logprob(self) -> bool:
return len(self.num_logprobs) == 0
def max_num_logprobs(self) -> Optional[int]:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_prompt_logprob(self) -> bool:
return len(self.prompt_logprob_reqs) == 0
return not self.num_prompt_logprobs
......@@ -29,7 +29,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
......@@ -804,8 +804,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds=inputs_embeds,
)
hidden_states = hidden_states[:num_scheduled_tokens]
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
# Sample the next token and get logprobs if needed.
sampling_metadata = self._prepare_sampling(batch_changed)
......@@ -818,7 +818,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
for i, req_id in enumerate( # type: ignore[assignment]
self.input_batch.req_ids[:num_reqs]):
assert req_id is not None
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
......@@ -847,27 +848,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
sampled_token_ids = sampler_output.sampled_token_ids.tolist()
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states,
scheduler_output,
)
# Update with the actual token ids
for i, req_state, seq_len in request_seq_lens:
token_id = sampled_token_ids[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids[-1] = token_id
if sampler_output.logprob_token_ids is None:
logprob_token_ids = None
else:
logprob_token_ids = sampler_output.logprob_token_ids.cpu()
if sampler_output.logprobs is None:
logprobs = None
else:
logprobs = sampler_output.logprobs.cpu()
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprob_token_ids_cpu=logprob_token_ids,
logprobs_cpu=logprobs,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)
return model_runner_output
......@@ -886,6 +888,76 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
def _get_prompt_logprobs_dict(
self,
hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput",
) -> Dict[str, LogprobsTensors]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}
prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
# Since prompt logprobs are a rare feature, prioritize simple,
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# Get metadata for this request.
request = self.requests[req_id]
num_prompt_tokens = len(request.prompt_token_ids)
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True)
# Determine number of logits to retrieve.
start_tok = request.num_computed_tokens + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens < num_remaining_tokens:
# This is a chunk, more tokens remain.
num_logits = num_tokens
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
completed_prefill_reqs.append(req_id)
# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
# then there is prompt logprob generated for each index.
req_idx = self.input_batch.req_id_to_index[req_id]
offset = self.query_start_loc_np[req_idx].item()
prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None)
# Get the "target" tokens for each index. For prompt at index i,
# the token at prompt index i+1 is the "sampled" token we want
# to gather the logprob for.
tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
# Compute prompt logprobs.
logprobs = self.model.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
logprobs, num_prompt_logprobs, tgt_token_ids)
# Transfer GPU->CPU async.
prompt_logprobs_dict[req_id] = LogprobsTensors(
token_ids.to("cpu", non_blocking=True),
logprobs.to("cpu", non_blocking=True),
ranks.to("cpu", non_blocking=True),
)
# Remove requests that have completed prefill from the batch
# num_prompt_logprobs_dict.
for req_id in completed_prefill_reqs:
del num_prompt_logprobs_dict[req_id]
# Must synchronize the non-blocking GPU->CPU transfers.
torch.cuda.synchronize()
return prompt_logprobs_dict
@torch.inference_mode()
def _dummy_run(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment