Unverified Commit 799397ee authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

Support embedding models in V1 (#16188)


Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
Signed-off-by: default avatarMax de Bayser <maxdebayser@gmail.com>
Signed-off-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
parent 49599150
......@@ -15,7 +15,7 @@ from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
......@@ -221,7 +221,7 @@ class LLMEngine:
# Add the request to EngineCore.
self.engine_core.add_request(child_request)
def step(self) -> list[RequestOutput]:
def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:
if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
......
......@@ -38,6 +38,7 @@ class LogprobsProcessor:
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "LogprobsProcessor":
assert request.sampling_params is not None
num_logprobs = request.sampling_params.logprobs
num_prompt_logprobs = request.sampling_params.prompt_logprobs
return cls(
......
......@@ -4,9 +4,12 @@
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
from vllm.outputs import CompletionOutput, RequestOutput
import torch
from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
......@@ -29,20 +32,22 @@ class RequestOutputCollector:
def __init__(self, output_kind: RequestOutputKind):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.output: Optional[Union[RequestOutput, Exception]] = None
self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
Exception]] = None
self.ready = asyncio.Event()
def put(self, output: Union[RequestOutput, Exception]) -> None:
def put(self, output: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
"""Non-blocking put operation."""
if self.output is None or isinstance(output, Exception):
self.output = output
self.ready.set()
elif isinstance(self.output, RequestOutput):
elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
# This ensures that request outputs with different request indexes
# (if n > 1) do not override each other.
self.output.add(output, aggregate=self.aggregate)
async def get(self) -> RequestOutput:
async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
"""Get operation blocks on put event."""
while (output := self.output) is None:
await self.ready.wait()
......@@ -52,7 +57,8 @@ class RequestOutputCollector:
raise output
return output
def get_nowait(self) -> Optional[RequestOutput]:
def get_nowait(
self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
"""Non-blocking get operation."""
output = self.output
if output is not None:
......@@ -66,7 +72,7 @@ class RequestOutputCollector:
@dataclass
class OutputProcessorOutput:
request_outputs: list[RequestOutput]
request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
reqs_to_abort: list[str]
......@@ -81,8 +87,8 @@ class RequestState:
output_kind: RequestOutputKind,
prompt: Optional[str],
prompt_token_ids: list[int],
logprobs_processor: LogprobsProcessor,
detokenizer: IncrementalDetokenizer,
logprobs_processor: Optional[LogprobsProcessor],
detokenizer: Optional[IncrementalDetokenizer],
max_tokens_param: Optional[int],
arrival_time: float,
queue: Optional[RequestOutputCollector],
......@@ -116,27 +122,39 @@ class RequestState:
queue: Optional[RequestOutputCollector],
log_stats: bool,
) -> "RequestState":
if not request.sampling_params.detokenize:
tokenizer = None
if sampling_params := request.sampling_params:
if not sampling_params.detokenize:
tokenizer = None
output_kind = sampling_params.output_kind
logprobs_processor = LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
)
detokenizer = IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
)
max_tokens_param = sampling_params.max_tokens
else:
logprobs_processor = None
detokenizer = None
max_tokens_param = None
assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind
return cls(
request_id=request.request_id,
parent_req=parent_req,
request_index=request_index,
lora_name=(request.lora_request.name
if request.lora_request is not None else None),
output_kind=request.sampling_params.output_kind,
output_kind=output_kind,
prompt=prompt,
prompt_token_ids=request.prompt_token_ids,
logprobs_processor=LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
),
detokenizer=IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
),
max_tokens_param=(request.sampling_params.max_tokens if
request.sampling_params is not None else None),
logprobs_processor=logprobs_processor,
detokenizer=detokenizer,
max_tokens_param=max_tokens_param,
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
......@@ -145,11 +163,12 @@ class RequestState:
def make_request_output(
self,
new_token_ids: list[int],
pooling_output: Optional[torch.Tensor],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> Optional[RequestOutput]:
) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
......@@ -158,15 +177,20 @@ class RequestState:
# Only the final output is required in FINAL_ONLY mode.
return None
completion_output = self._new_completion_output(
new_token_ids, finish_reason, stop_reason)
request_id = self.request_id
if pooling_output is not None:
return self._new_request_output(
request_id, [self._new_pooling_output(pooling_output)],
finished)
output = self._new_completion_output(new_token_ids, finish_reason,
stop_reason)
if self.parent_req is None:
outputs = [completion_output]
outputs = [output]
else:
request_id, outputs, finished = self.parent_req.get_outputs(
request_id, completion_output)
request_id, output)
if not outputs:
return None
......@@ -176,12 +200,21 @@ class RequestState:
def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
outputs: Union[list[CompletionOutput], list[PoolingOutput]],
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> RequestOutput:
) -> Union[RequestOutput, PoolingRequestOutput]:
if isinstance(outputs[0], PoolingOutput):
assert len(outputs) == 1
return PoolingRequestOutput(
request_id=request_id,
outputs=outputs[0],
prompt_token_ids=self.prompt_token_ids,
finished=finished,
)
assert self.logprobs_processor is not None
if self.output_kind == RequestOutputKind.DELTA:
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
......@@ -193,7 +226,7 @@ class RequestState:
prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids,
prompt_logprobs=prompt_logprobs,
outputs=outputs,
outputs=cast(list[CompletionOutput], outputs),
finished=finished,
kv_transfer_params=kv_transfer_params,
num_cached_tokens=num_cached_tokens,
......@@ -206,6 +239,8 @@ class RequestState:
stop_reason: Union[int, str, None],
) -> CompletionOutput:
assert self.detokenizer is not None
assert self.logprobs_processor is not None
finished = finish_reason is not None
delta = self.output_kind == RequestOutputKind.DELTA
......@@ -228,6 +263,13 @@ class RequestState:
finish_reason=str(finish_reason) if finished else None,
stop_reason=stop_reason if finished else None)
def _new_pooling_output(
self,
pooling_output: torch.Tensor,
) -> PoolingOutput:
return PoolingOutput(data=pooling_output)
class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""
......@@ -326,7 +368,8 @@ class OutputProcessor:
within the loop below.
"""
request_outputs: list[RequestOutput] = []
request_outputs: Union[list[RequestOutput],
list[PoolingRequestOutput]] = []
reqs_to_abort: list[str] = []
for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id
......@@ -341,25 +384,31 @@ class OutputProcessor:
iteration_stats)
new_token_ids = engine_core_output.new_token_ids
pooling_output = engine_core_output.pooling_output
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP)
if stop_string:
finish_reason = FinishReason.STOP
stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request, if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
if pooling_output is None:
assert req_state.detokenizer is not None
assert req_state.logprobs_processor is not None
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP)
if stop_string:
finish_reason = FinishReason.STOP
stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request,
# if required.
req_state.logprobs_processor.update_from_output(
engine_core_output)
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason,
new_token_ids, pooling_output, finish_reason, stop_reason,
kv_transfer_params, num_cached_tokens):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
......
......@@ -136,8 +136,8 @@ class Processor:
Should raise ValueError if unsupported for API Server.
"""
if not isinstance(params, SamplingParams):
raise ValueError("V1 does not yet support Pooling models.")
if isinstance(params, PoolingParams):
return
self._validate_logprobs(params)
self._validate_sampling_params(params, lora_request)
......@@ -263,18 +263,22 @@ class Processor:
if encoder_inputs is not None:
raise NotImplementedError
assert isinstance(params, SamplingParams)
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
sampling_params = None
pooling_params = None
if isinstance(params, SamplingParams):
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
else:
pooling_params = params.clone()
# Multimodal related.
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
......@@ -331,6 +335,7 @@ class Processor:
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params,
pooling_params=pooling_params,
eos_token_id=eos_token_id,
arrival_time=arrival_time,
lora_request=lora_request,
......
......@@ -481,8 +481,9 @@ class PrometheusStatLogger(StatLoggerBase):
finished_request.num_prompt_tokens)
self.histogram_num_generation_tokens_request.observe(
finished_request.num_generation_tokens)
self.histogram_max_tokens_request.observe(
finished_request.max_tokens_param)
if finished_request.max_tokens_param:
self.histogram_max_tokens_request.observe(
finished_request.max_tokens_param)
if self.gauge_lora_info is not None:
running_lora_adapters = \
......
......@@ -106,7 +106,6 @@ class IterationStats:
self.num_generation_tokens += num_new_generation_tokens
if is_prefilling:
assert num_new_generation_tokens > 0
self.num_prompt_tokens += prompt_len
first_token_latency = self._time_since(req_stats.arrival_time)
......
......@@ -101,6 +101,9 @@ class ModelRunnerOutput:
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
# [num_reqs, hidden_size]
pooler_output: list[Optional[torch.Tensor]]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
......@@ -112,5 +115,6 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
finished_sending=None,
finished_recving=None)
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.pooling_params import PoolingParams
@dataclass
class PoolingMetadata:
"""Tensors for pooling."""
prompt_lens: torch.Tensor
prompt_token_ids: Optional[torch.Tensor]
pooling_params: list[PoolingParams]
......@@ -5,6 +5,7 @@ import enum
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
......@@ -25,7 +26,8 @@ class Request:
multi_modal_inputs: Optional[list[MultiModalKwargs]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams,
sampling_params: Optional[SamplingParams],
pooling_params: Optional[PoolingParams],
eos_token_id: Optional[int],
client_index: int = 0,
lora_request: Optional["LoRARequest"] = None,
......@@ -35,18 +37,35 @@ class Request:
self.request_id = request_id
self.client_index = client_index
self.sampling_params = sampling_params
self.pooling_params = pooling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.structured_output_request = structured_output_request
self.status = (RequestStatus.WAITING_FOR_FSM
if sampling_params.guided_decoding is not None else
RequestStatus.WAITING)
self.status = RequestStatus.WAITING
if sampling_params and sampling_params.guided_decoding is not None:
self.status = RequestStatus.WAITING_FOR_FSM
self.events: list[EngineCoreEvent] = []
self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
# P/D: Connector-specific KV transfer parameters.
self.kv_transfer_params: Optional[dict[str, Any]] = None
if pooling_params is not None:
self.max_tokens = 1
elif sampling_params is not None:
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
if sampling_params.guided_decoding is not None:
self.status = RequestStatus.WAITING_FOR_FSM
if sampling_params.extra_args is not None:
self.kv_transfer_params = \
sampling_params.extra_args.get("kv_transfer_params")
else:
raise ValueError(
"sampling_params and pooling_params can't both be unset")
self.prompt_token_ids = prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids)
......@@ -63,11 +82,6 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# P/D: Connector-specific KV transfer parameters.
kv_params = (None if sampling_params.extra_args is None else
sampling_params.extra_args.get("kv_transfer_params"))
self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
if self.mm_hashes:
......@@ -98,10 +112,12 @@ class Request:
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id,
lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params),
sampling_params=request.sampling_params) \
if request.sampling_params else None,
cache_salt=request.cache_salt,
)
......@@ -141,7 +157,8 @@ class Request:
@property
def use_structured_output(self) -> bool:
return self.sampling_params.guided_decoding is not None
return self.sampling_params is not None and \
self.sampling_params.guided_decoding is not None
def record_event(
self,
......
......@@ -62,13 +62,15 @@ class StructuredOutputManager:
return
if TYPE_CHECKING:
assert request.sampling_params.guided_decoding is not None
assert request.sampling_params is not None and \
request.sampling_params.guided_decoding is not None
# Initialize the backend the first time it is needed.
#
# NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...).
if self.backend is None:
assert request.sampling_params is not None
backend = request.sampling_params.guided_decoding.backend
vocab_size = self.vllm_config.model_config.get_vocab_size()
if backend == "xgrammar":
......
......@@ -10,9 +10,11 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
......@@ -27,7 +29,8 @@ class CachedRequestState:
prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]
block_ids: tuple[list[int], ...]
......@@ -226,6 +229,8 @@ class InputBatch:
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
self.pooling_params: dict[str, PoolingParams] = {}
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
......@@ -269,77 +274,83 @@ class InputBatch:
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
sampling_params = request.sampling_params
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.min_tokens:
self.min_tokens[req_index] = (sampling_params.min_tokens,
sampling_params.all_stop_token_ids)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
if sampling_params := request.sampling_params:
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.min_tokens:
self.min_tokens[req_index] = (
sampling_params.min_tokens,
sampling_params.all_stop_token_ids)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
else:
assert request.pooling_params is not None
self.pooling_params[req_id] = request.pooling_params
# Add request lora ID
if request.lora_request:
......@@ -392,6 +403,7 @@ class InputBatch:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
self.pooling_params.pop(req_id, None)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
......@@ -602,6 +614,25 @@ class InputBatch:
bad_words_token_ids=self.bad_words_token_ids,
)
@property
def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0:
pooling_params = []
else:
# Note, for now this assumes that all request in the batch
# are either sampling or pooling requests
assert len(self.req_ids) == len(self.pooling_params)
pooling_params = [
self.pooling_params[req_id] for req_id in self.req_ids
]
return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
......
......@@ -36,6 +36,7 @@ from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
......@@ -51,6 +52,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
......@@ -119,6 +121,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cache_config.cache_dtype]
self.is_multimodal_model = model_config.is_multimodal_model
self.is_pooling_model = model_config.pooler_config is not None
self.max_model_len = model_config.max_model_len
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
......@@ -394,7 +397,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
pooling_params = new_req_data.pooling_params
if sampling_params and \
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
......@@ -406,6 +411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=pooling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
......@@ -563,7 +569,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, Any], bool, torch.Tensor,
Optional[SpecDecodeMetadata]]:
Optional[SpecDecodeMetadata], np.ndarray]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
......@@ -750,7 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return (attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata)
spec_decode_metadata, num_scheduled_tokens)
def _compute_cascade_attn_prefix_len(
self,
......@@ -1197,6 +1203,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def _pool(
self,
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
num_scheduled_tokens_np: np.ndarray,
finished_sending: Optional[set[str]],
finished_recving: Optional[set[str]],
) -> ModelRunnerOutput:
assert self.input_batch.num_reqs ==\
len(self.input_batch.pooling_params), \
"Either all or none of the requests in" \
" a batch must be pooling request"
extracted_hidden_states = list(
torch.split(hidden_states[:num_scheduled_tokens],
num_scheduled_tokens_np.tolist()))
pooling_metadata = self.input_batch.pooling_metadata
raw_pooler_output = self.model.pooler(
hidden_states=extracted_hidden_states,
pooling_metadata=pooling_metadata)
pooler_output: list[Optional[torch.Tensor]] = []
seq_lens = self.seq_lens[:self.input_batch.num_reqs]
for raw_output, seq_len, prompt_len in zip(
raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
if seq_len == prompt_len:
pooler_output.append(raw_output.data.cpu())
else:
pooler_output.append(None)
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
finished_sending=finished_sending,
finished_recving=finished_recving,
)
@torch.inference_mode()
def execute_model(
self,
......@@ -1214,7 +1265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata) = (self._prepare_inputs(scheduler_output))
spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
......@@ -1284,7 +1336,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
# Run the decoder.
# Run the model.
# Use persistent buffers for CUDA graphs.
with set_forward_context(
attn_metadata,
......@@ -1326,6 +1378,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
all_gather_group=get_tp_group())
logits = None
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, finished_sending,
finished_recving)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
......@@ -1541,6 +1598,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
)
......@@ -1802,7 +1860,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
num_tokens: int,
capture_attn_cudagraph: bool = False,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
......@@ -1899,7 +1957,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.drafter.dummy_run(num_tokens)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states[logit_indices]
return hidden_states, hidden_states[logit_indices]
@torch.inference_mode()
def _dummy_sampler_run(
......@@ -1978,6 +2036,48 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return sampler_output
@torch.inference_mode()
def _dummy_pooler_run(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
hidden_states_list = list(
torch.split(hidden_states, num_scheduled_tokens_list))
req_num_tokens = num_tokens // num_reqs
dummy_metadata = PoolingMetadata(
prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
device=self.device),
prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
dtype=torch.int32,
device=self.device),
pooling_params=[PoolingParams()] * num_reqs)
try:
pooler_output = self.model.pooler(hidden_states=hidden_states_list,
pooling_metadata=dummy_metadata)
except RuntimeError as e:
if 'out of memory' in str(e):
raise RuntimeError(
"CUDA out of memory occurred when warming up pooler with "
f"{num_reqs} dummy requests. Please try lowering "
"`max_num_seqs` or `gpu_memory_utilization` when "
"initializing the engine.") from e
else:
raise e
return pooler_output
def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them.
......@@ -2048,13 +2148,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
hidden_states = self._dummy_run(self.max_num_tokens)
hidden_states, last_hidden_states \
= self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
sampler_output = self._dummy_sampler_run(hidden_states)
if self.is_pooling_model:
output = self._dummy_pooler_run(hidden_states)
else:
output = self._dummy_sampler_run(last_hidden_states)
else:
sampler_output = None
output = None
self._sync_device()
del hidden_states, sampler_output
del hidden_states, output
self.encoder_cache.clear()
gc.collect()
......
......@@ -273,9 +273,14 @@ class Worker(WorkerBase):
if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
self.model_runner._dummy_sampler_run(
hidden_states=self.model_runner._dummy_run(
num_tokens=max_num_reqs))
hidden_states, last_hidden_states = \
self.model_runner._dummy_run(num_tokens=max_num_reqs)
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
else:
self.model_runner._dummy_sampler_run(
hidden_states=last_hidden_states)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
......
......@@ -231,6 +231,7 @@ class InputBatch:
self.block_table.add_row(request.block_ids, req_index)
sampling_params = request.sampling_params
assert sampling_params is not None, "pooling requests not supported yet"
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
......
......@@ -386,6 +386,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add: list[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.sampling_params is not None,\
"Pooling is not supported in TPU yet"
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
......@@ -395,6 +397,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=None,
generator=None,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
......@@ -956,6 +959,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=None,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
)
# Check there are no new graphs compiled - all the graphs should be
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment