Unverified commit 799397ee, authored by Maximilien de Bayser and committed by GitHub
Browse files

Support embedding models in V1 (#16188)


Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
parent 49599150
...@@ -15,7 +15,7 @@ from vllm.inputs import PromptType ...@@ -15,7 +15,7 @@ from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -221,7 +221,7 @@ class LLMEngine: ...@@ -221,7 +221,7 @@ class LLMEngine:
# Add the request to EngineCore. # Add the request to EngineCore.
self.engine_core.add_request(child_request) self.engine_core.add_request(child_request)
def step(self) -> list[RequestOutput]: def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:
if self.should_execute_dummy_batch: if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False self.should_execute_dummy_batch = False
......
...@@ -38,6 +38,7 @@ class LogprobsProcessor: ...@@ -38,6 +38,7 @@ class LogprobsProcessor:
tokenizer: Optional[AnyTokenizer], tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest, request: EngineCoreRequest,
) -> "LogprobsProcessor": ) -> "LogprobsProcessor":
assert request.sampling_params is not None
num_logprobs = request.sampling_params.logprobs num_logprobs = request.sampling_params.logprobs
num_prompt_logprobs = request.sampling_params.prompt_logprobs num_prompt_logprobs = request.sampling_params.prompt_logprobs
return cls( return cls(
......
...@@ -4,9 +4,12 @@ ...@@ -4,9 +4,12 @@
import asyncio import asyncio
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Union from typing import Any, Optional, Union, cast
from vllm.outputs import CompletionOutput, RequestOutput import torch
from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
...@@ -29,20 +32,22 @@ class RequestOutputCollector: ...@@ -29,20 +32,22 @@ class RequestOutputCollector:
def __init__(self, output_kind: RequestOutputKind): def __init__(self, output_kind: RequestOutputKind):
self.aggregate = output_kind == RequestOutputKind.DELTA self.aggregate = output_kind == RequestOutputKind.DELTA
self.output: Optional[Union[RequestOutput, Exception]] = None self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
Exception]] = None
self.ready = asyncio.Event() self.ready = asyncio.Event()
def put(self, output: Union[RequestOutput, Exception]) -> None: def put(self, output: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
"""Non-blocking put operation.""" """Non-blocking put operation."""
if self.output is None or isinstance(output, Exception): if self.output is None or isinstance(output, Exception):
self.output = output self.output = output
self.ready.set() self.ready.set()
elif isinstance(self.output, RequestOutput): elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
# This ensures that request outputs with different request indexes # This ensures that request outputs with different request indexes
# (if n > 1) do not override each other. # (if n > 1) do not override each other.
self.output.add(output, aggregate=self.aggregate) self.output.add(output, aggregate=self.aggregate)
async def get(self) -> RequestOutput: async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
"""Get operation blocks on put event.""" """Get operation blocks on put event."""
while (output := self.output) is None: while (output := self.output) is None:
await self.ready.wait() await self.ready.wait()
...@@ -52,7 +57,8 @@ class RequestOutputCollector: ...@@ -52,7 +57,8 @@ class RequestOutputCollector:
raise output raise output
return output return output
def get_nowait(self) -> Optional[RequestOutput]: def get_nowait(
self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
"""Non-blocking get operation.""" """Non-blocking get operation."""
output = self.output output = self.output
if output is not None: if output is not None:
...@@ -66,7 +72,7 @@ class RequestOutputCollector: ...@@ -66,7 +72,7 @@ class RequestOutputCollector:
@dataclass @dataclass
class OutputProcessorOutput: class OutputProcessorOutput:
request_outputs: list[RequestOutput] request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
reqs_to_abort: list[str] reqs_to_abort: list[str]
...@@ -81,8 +87,8 @@ class RequestState: ...@@ -81,8 +87,8 @@ class RequestState:
output_kind: RequestOutputKind, output_kind: RequestOutputKind,
prompt: Optional[str], prompt: Optional[str],
prompt_token_ids: list[int], prompt_token_ids: list[int],
logprobs_processor: LogprobsProcessor, logprobs_processor: Optional[LogprobsProcessor],
detokenizer: IncrementalDetokenizer, detokenizer: Optional[IncrementalDetokenizer],
max_tokens_param: Optional[int], max_tokens_param: Optional[int],
arrival_time: float, arrival_time: float,
queue: Optional[RequestOutputCollector], queue: Optional[RequestOutputCollector],
...@@ -116,27 +122,39 @@ class RequestState: ...@@ -116,27 +122,39 @@ class RequestState:
queue: Optional[RequestOutputCollector], queue: Optional[RequestOutputCollector],
log_stats: bool, log_stats: bool,
) -> "RequestState": ) -> "RequestState":
if not request.sampling_params.detokenize:
tokenizer = None if sampling_params := request.sampling_params:
if not sampling_params.detokenize:
tokenizer = None
output_kind = sampling_params.output_kind
logprobs_processor = LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
)
detokenizer = IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
)
max_tokens_param = sampling_params.max_tokens
else:
logprobs_processor = None
detokenizer = None
max_tokens_param = None
assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind
return cls( return cls(
request_id=request.request_id, request_id=request.request_id,
parent_req=parent_req, parent_req=parent_req,
request_index=request_index, request_index=request_index,
lora_name=(request.lora_request.name lora_name=(request.lora_request.name
if request.lora_request is not None else None), if request.lora_request is not None else None),
output_kind=request.sampling_params.output_kind, output_kind=output_kind,
prompt=prompt, prompt=prompt,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
logprobs_processor=LogprobsProcessor.from_new_request( logprobs_processor=logprobs_processor,
tokenizer=tokenizer, detokenizer=detokenizer,
request=request, max_tokens_param=max_tokens_param,
),
detokenizer=IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
),
max_tokens_param=(request.sampling_params.max_tokens if
request.sampling_params is not None else None),
arrival_time=request.arrival_time, arrival_time=request.arrival_time,
queue=queue, queue=queue,
log_stats=log_stats, log_stats=log_stats,
...@@ -145,11 +163,12 @@ class RequestState: ...@@ -145,11 +163,12 @@ class RequestState:
def make_request_output( def make_request_output(
self, self,
new_token_ids: list[int], new_token_ids: list[int],
pooling_output: Optional[torch.Tensor],
finish_reason: Optional[FinishReason], finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None], stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None, kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0, num_cached_tokens: int = 0,
) -> Optional[RequestOutput]: ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
finished = finish_reason is not None finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
...@@ -158,15 +177,20 @@ class RequestState: ...@@ -158,15 +177,20 @@ class RequestState:
# Only the final output is required in FINAL_ONLY mode. # Only the final output is required in FINAL_ONLY mode.
return None return None
completion_output = self._new_completion_output(
new_token_ids, finish_reason, stop_reason)
request_id = self.request_id request_id = self.request_id
if pooling_output is not None:
return self._new_request_output(
request_id, [self._new_pooling_output(pooling_output)],
finished)
output = self._new_completion_output(new_token_ids, finish_reason,
stop_reason)
if self.parent_req is None: if self.parent_req is None:
outputs = [completion_output] outputs = [output]
else: else:
request_id, outputs, finished = self.parent_req.get_outputs( request_id, outputs, finished = self.parent_req.get_outputs(
request_id, completion_output) request_id, output)
if not outputs: if not outputs:
return None return None
...@@ -176,12 +200,21 @@ class RequestState: ...@@ -176,12 +200,21 @@ class RequestState:
def _new_request_output( def _new_request_output(
self, self,
request_id: str, request_id: str,
outputs: list[CompletionOutput], outputs: Union[list[CompletionOutput], list[PoolingOutput]],
finished: bool, finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None, kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0, num_cached_tokens: int = 0,
) -> RequestOutput: ) -> Union[RequestOutput, PoolingRequestOutput]:
if isinstance(outputs[0], PoolingOutput):
assert len(outputs) == 1
return PoolingRequestOutput(
request_id=request_id,
outputs=outputs[0],
prompt_token_ids=self.prompt_token_ids,
finished=finished,
)
assert self.logprobs_processor is not None
if self.output_kind == RequestOutputKind.DELTA: if self.output_kind == RequestOutputKind.DELTA:
# Side effect: logprobs processor forgets prompt logprobs # Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs() prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
...@@ -193,7 +226,7 @@ class RequestState: ...@@ -193,7 +226,7 @@ class RequestState:
prompt=self.prompt, prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids, prompt_token_ids=self.prompt_token_ids,
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
outputs=outputs, outputs=cast(list[CompletionOutput], outputs),
finished=finished, finished=finished,
kv_transfer_params=kv_transfer_params, kv_transfer_params=kv_transfer_params,
num_cached_tokens=num_cached_tokens, num_cached_tokens=num_cached_tokens,
...@@ -206,6 +239,8 @@ class RequestState: ...@@ -206,6 +239,8 @@ class RequestState:
stop_reason: Union[int, str, None], stop_reason: Union[int, str, None],
) -> CompletionOutput: ) -> CompletionOutput:
assert self.detokenizer is not None
assert self.logprobs_processor is not None
finished = finish_reason is not None finished = finish_reason is not None
delta = self.output_kind == RequestOutputKind.DELTA delta = self.output_kind == RequestOutputKind.DELTA
...@@ -228,6 +263,13 @@ class RequestState: ...@@ -228,6 +263,13 @@ class RequestState:
finish_reason=str(finish_reason) if finished else None, finish_reason=str(finish_reason) if finished else None,
stop_reason=stop_reason if finished else None) stop_reason=stop_reason if finished else None)
def _new_pooling_output(
    self,
    pooling_output: torch.Tensor,
) -> PoolingOutput:
    """Wrap a pooled tensor in a ``PoolingOutput``.

    Args:
        pooling_output: Tensor passed through unchanged as the ``data``
            field of the returned object.

    Returns:
        A new ``PoolingOutput`` holding ``pooling_output``.
    """
    wrapped = PoolingOutput(data=pooling_output)
    return wrapped
class OutputProcessor: class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs.""" """Process EngineCoreOutputs into RequestOutputs."""
...@@ -326,7 +368,8 @@ class OutputProcessor: ...@@ -326,7 +368,8 @@ class OutputProcessor:
within the loop below. within the loop below.
""" """
request_outputs: list[RequestOutput] = [] request_outputs: Union[list[RequestOutput],
list[PoolingRequestOutput]] = []
reqs_to_abort: list[str] = [] reqs_to_abort: list[str] = []
for engine_core_output in engine_core_outputs: for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id req_id = engine_core_output.request_id
...@@ -341,25 +384,31 @@ class OutputProcessor: ...@@ -341,25 +384,31 @@ class OutputProcessor:
iteration_stats) iteration_stats)
new_token_ids = engine_core_output.new_token_ids new_token_ids = engine_core_output.new_token_ids
pooling_output = engine_core_output.pooling_output
finish_reason = engine_core_output.finish_reason finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params kv_transfer_params = engine_core_output.kv_transfer_params
num_cached_tokens = engine_core_output.num_cached_tokens num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False req_state.is_prefilling = False
# 2) Detokenize the token ids into text and perform stop checks. if pooling_output is None:
stop_string = req_state.detokenizer.update( assert req_state.detokenizer is not None
new_token_ids, finish_reason == FinishReason.STOP) assert req_state.logprobs_processor is not None
if stop_string: # 2) Detokenize the token ids into text and perform stop checks.
finish_reason = FinishReason.STOP stop_string = req_state.detokenizer.update(
stop_reason = stop_string new_token_ids, finish_reason == FinishReason.STOP)
if stop_string:
# 3) Compute sample and prompt logprobs for request, if required. finish_reason = FinishReason.STOP
req_state.logprobs_processor.update_from_output(engine_core_output) stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request,
# if required.
req_state.logprobs_processor.update_from_output(
engine_core_output)
# 4) Create and handle RequestOutput objects. # 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output( if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason, new_token_ids, pooling_output, finish_reason, stop_reason,
kv_transfer_params, num_cached_tokens): kv_transfer_params, num_cached_tokens):
if req_state.queue is not None: if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate(). # AsyncLLM: put into queue for handling by generate().
......
...@@ -136,8 +136,8 @@ class Processor: ...@@ -136,8 +136,8 @@ class Processor:
Should raise ValueError if unsupported for API Server. Should raise ValueError if unsupported for API Server.
""" """
if not isinstance(params, SamplingParams): if isinstance(params, PoolingParams):
raise ValueError("V1 does not yet support Pooling models.") return
self._validate_logprobs(params) self._validate_logprobs(params)
self._validate_sampling_params(params, lora_request) self._validate_sampling_params(params, lora_request)
...@@ -263,18 +263,22 @@ class Processor: ...@@ -263,18 +263,22 @@ class Processor:
if encoder_inputs is not None: if encoder_inputs is not None:
raise NotImplementedError raise NotImplementedError
assert isinstance(params, SamplingParams) sampling_params = None
# TODO: can we avoid cloning here in multiproc case? pooling_params = None
sampling_params = params.clone() if isinstance(params, SamplingParams):
# If unset max tokens, then generate up to the max_model_len. # TODO: can we avoid cloning here in multiproc case?
if sampling_params.max_tokens is None: sampling_params = params.clone()
sampling_params.max_tokens = ( # If unset max tokens, then generate up to the max_model_len.
self.model_config.max_model_len - if sampling_params.max_tokens is None:
len(decoder_inputs["prompt_token_ids"])) sampling_params.max_tokens = (
sampling_params.update_from_generation_config( self.model_config.max_model_len -
self.generation_config_fields, eos_token_id) len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_tokenizer( sampling_params.update_from_generation_config(
self.tokenizer.get_lora_tokenizer(lora_request)) self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
else:
pooling_params = params.clone()
# Multimodal related. # Multimodal related.
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
...@@ -331,6 +335,7 @@ class Processor: ...@@ -331,6 +335,7 @@ class Processor:
mm_hashes=sorted_mm_hashes, mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions, mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=pooling_params,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
......
...@@ -481,8 +481,9 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -481,8 +481,9 @@ class PrometheusStatLogger(StatLoggerBase):
finished_request.num_prompt_tokens) finished_request.num_prompt_tokens)
self.histogram_num_generation_tokens_request.observe( self.histogram_num_generation_tokens_request.observe(
finished_request.num_generation_tokens) finished_request.num_generation_tokens)
self.histogram_max_tokens_request.observe( if finished_request.max_tokens_param:
finished_request.max_tokens_param) self.histogram_max_tokens_request.observe(
finished_request.max_tokens_param)
if self.gauge_lora_info is not None: if self.gauge_lora_info is not None:
running_lora_adapters = \ running_lora_adapters = \
......
...@@ -106,7 +106,6 @@ class IterationStats: ...@@ -106,7 +106,6 @@ class IterationStats:
self.num_generation_tokens += num_new_generation_tokens self.num_generation_tokens += num_new_generation_tokens
if is_prefilling: if is_prefilling:
assert num_new_generation_tokens > 0
self.num_prompt_tokens += prompt_len self.num_prompt_tokens += prompt_len
first_token_latency = self._time_since(req_stats.arrival_time) first_token_latency = self._time_since(req_stats.arrival_time)
......
...@@ -101,6 +101,9 @@ class ModelRunnerOutput: ...@@ -101,6 +101,9 @@ class ModelRunnerOutput:
# [prompt_len] # [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
# [num_reqs, hidden_size]
pooler_output: list[Optional[torch.Tensor]]
# [req_ids] # [req_ids]
finished_sending: Optional[set[str]] = None finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None finished_recving: Optional[set[str]] = None
...@@ -112,5 +115,6 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], ...@@ -112,5 +115,6 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
finished_sending=None, finished_sending=None,
finished_recving=None) finished_recving=None)
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.pooling_params import PoolingParams
@dataclass
class PoolingMetadata:
    """Bundle of tensors and parameters used for pooling.

    NOTE(review): the fields look like they describe one batch with one
    entry per request, but nothing visible here shows how they are
    populated -- confirm against the code that builds this object.
    """

    # Prompt lengths as a tensor; presumably one length per request
    # -- TODO confirm.
    prompt_lens: torch.Tensor
    # Prompt token ids as a tensor, or None when they are not supplied.
    prompt_token_ids: Optional[torch.Tensor]
    # Pooling parameters; each list element is a PoolingParams object.
    pooling_params: list[PoolingParams]
...@@ -5,6 +5,7 @@ import enum ...@@ -5,6 +5,7 @@ import enum
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
...@@ -25,7 +26,8 @@ class Request: ...@@ -25,7 +26,8 @@ class Request:
multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_inputs: Optional[list[MultiModalKwargs]],
multi_modal_hashes: Optional[list[str]], multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list[PlaceholderRange]], multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams, sampling_params: Optional[SamplingParams],
pooling_params: Optional[PoolingParams],
eos_token_id: Optional[int], eos_token_id: Optional[int],
client_index: int = 0, client_index: int = 0,
lora_request: Optional["LoRARequest"] = None, lora_request: Optional["LoRARequest"] = None,
...@@ -35,18 +37,35 @@ class Request: ...@@ -35,18 +37,35 @@ class Request:
self.request_id = request_id self.request_id = request_id
self.client_index = client_index self.client_index = client_index
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.pooling_params = pooling_params
# Because of LoRA, the eos token id can be different for each request. # Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.structured_output_request = structured_output_request self.structured_output_request = structured_output_request
self.status = (RequestStatus.WAITING_FOR_FSM self.status = RequestStatus.WAITING
if sampling_params.guided_decoding is not None else if sampling_params and sampling_params.guided_decoding is not None:
RequestStatus.WAITING) self.status = RequestStatus.WAITING_FOR_FSM
self.events: list[EngineCoreEvent] = [] self.events: list[EngineCoreEvent] = []
self.stop_reason: Union[int, str, None] = None self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens # P/D: Connector-specific KV transfer parameters.
self.kv_transfer_params: Optional[dict[str, Any]] = None
if pooling_params is not None:
self.max_tokens = 1
elif sampling_params is not None:
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
if sampling_params.guided_decoding is not None:
self.status = RequestStatus.WAITING_FOR_FSM
if sampling_params.extra_args is not None:
self.kv_transfer_params = \
sampling_params.extra_args.get("kv_transfer_params")
else:
raise ValueError(
"sampling_params and pooling_params can't both be unset")
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids) self.num_prompt_tokens = len(self.prompt_token_ids)
...@@ -63,11 +82,6 @@ class Request: ...@@ -63,11 +82,6 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs) self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0 self.has_encoder_inputs = self.num_encoder_inputs > 0
# P/D: Connector-specific KV transfer parameters.
kv_params = (None if sampling_params.extra_args is None else
sampling_params.extra_args.get("kv_transfer_params"))
self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
# Sanity check # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions) assert len(self.mm_inputs) == len(self.mm_positions)
if self.mm_hashes: if self.mm_hashes:
...@@ -98,10 +112,12 @@ class Request: ...@@ -98,10 +112,12 @@ class Request:
multi_modal_hashes=request.mm_hashes, multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders, multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id, eos_token_id=request.eos_token_id,
lora_request=request.lora_request, lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest( structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params), sampling_params=request.sampling_params) \
if request.sampling_params else None,
cache_salt=request.cache_salt, cache_salt=request.cache_salt,
) )
...@@ -141,7 +157,8 @@ class Request: ...@@ -141,7 +157,8 @@ class Request:
@property @property
def use_structured_output(self) -> bool: def use_structured_output(self) -> bool:
return self.sampling_params.guided_decoding is not None return self.sampling_params is not None and \
self.sampling_params.guided_decoding is not None
def record_event( def record_event(
self, self,
......
...@@ -62,13 +62,15 @@ class StructuredOutputManager: ...@@ -62,13 +62,15 @@ class StructuredOutputManager:
return return
if TYPE_CHECKING: if TYPE_CHECKING:
assert request.sampling_params.guided_decoding is not None assert request.sampling_params is not None and \
request.sampling_params.guided_decoding is not None
# Initialize the backend the first time it is needed. # Initialize the backend the first time it is needed.
# #
# NOTE: We only support a single backend. We do NOT support different # NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...). # backends on a per-request basis in V1 (for now, anyway...).
if self.backend is None: if self.backend is None:
assert request.sampling_params is not None
backend = request.sampling_params.guided_decoding.backend backend = request.sampling_params.guided_decoding.backend
vocab_size = self.vllm_config.model_config.get_vocab_size() vocab_size = self.vllm_config.model_config.get_vocab_size()
if backend == "xgrammar": if backend == "xgrammar":
......
...@@ -10,9 +10,11 @@ import torch ...@@ -10,9 +10,11 @@ import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable from vllm.v1.worker.block_table import MultiGroupBlockTable
...@@ -27,7 +29,8 @@ class CachedRequestState: ...@@ -27,7 +29,8 @@ class CachedRequestState:
prompt_token_ids: list[int] prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs] mm_inputs: list[MultiModalKwargs]
mm_positions: list[PlaceholderRange] mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator] generator: Optional[torch.Generator]
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
...@@ -226,6 +229,8 @@ class InputBatch: ...@@ -226,6 +229,8 @@ class InputBatch:
# This is updated each time the batch constituents change. # This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata() self.sampling_metadata = self._make_sampling_metadata()
self.pooling_params: dict[str, PoolingParams] = {}
@property @property
def req_ids(self) -> list[str]: def req_ids(self) -> list[str]:
# None elements should only be present transiently # None elements should only be present transiently
...@@ -269,77 +274,83 @@ class InputBatch: ...@@ -269,77 +274,83 @@ class InputBatch:
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index) self.block_table.add_row(request.block_ids, req_index)
sampling_params = request.sampling_params if sampling_params := request.sampling_params:
if sampling_params.sampling_type == SamplingType.GREEDY: if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero. # Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0 self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id) self.greedy_reqs.add(req_id)
else: else:
self.temperature_cpu[req_index] = sampling_params.temperature self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id) self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1: if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id) self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size: if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id) self.top_k_reqs.add(req_id)
else: else:
top_k = self.vocab_size top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k self.top_k_cpu[req_index] = top_k
self.min_p_cpu[req_index] = sampling_params.min_p self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[ self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS: if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id) self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0: if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id) self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[ self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0: if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id) self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[ self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0: if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id) self.repetition_penalties_reqs.add(req_id)
if sampling_params.min_tokens: if sampling_params.min_tokens:
self.min_tokens[req_index] = (sampling_params.min_tokens, self.min_tokens[req_index] = (
sampling_params.all_stop_token_ids) sampling_params.min_tokens,
sampling_params.all_stop_token_ids)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator. # NOTE(woosuk): self.generators should not include the requests that
if request.generator is not None: # do not have their own generator.
self.generators[req_index] = request.generator if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs if sampling_params.logprobs is not None:
if sampling_params.prompt_logprobs is not None: self.num_logprobs[req_id] = sampling_params.logprobs
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs if sampling_params.prompt_logprobs is not None:
if sampling_params.logit_bias is not None: self.num_prompt_logprobs[
self.logit_bias[req_index] = sampling_params.logit_bias req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
if sampling_params.allowed_token_ids: self.logit_bias[req_index] = sampling_params.logit_bias
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None: if sampling_params.allowed_token_ids:
# Lazy allocation for this tensor, which can be large. self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf. # False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs, self.allowed_token_ids_mask_cpu_tensor[req_index][
self.vocab_size, sampling_params.allowed_token_ids] = False
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False
if sampling_params.bad_words_token_ids: if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[ self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids req_index] = sampling_params.bad_words_token_ids
else:
assert request.pooling_params is not None
self.pooling_params[req_id] = request.pooling_params
# Add request lora ID # Add request lora ID
if request.lora_request: if request.lora_request:
...@@ -392,6 +403,7 @@ class InputBatch: ...@@ -392,6 +403,7 @@ class InputBatch:
# False means we don't fill with -inf. # False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None) self.bad_words_token_ids.pop(req_index, None)
self.pooling_params.pop(req_id, None)
return req_index return req_index
def swap_states(self, i1: int, i2: int) -> None: def swap_states(self, i1: int, i2: int) -> None:
...@@ -602,6 +614,25 @@ class InputBatch: ...@@ -602,6 +614,25 @@ class InputBatch:
bad_words_token_ids=self.bad_words_token_ids, bad_words_token_ids=self.bad_words_token_ids,
) )
@property
def pooling_metadata(self) -> PoolingMetadata:
    """Assemble the pooling metadata for the requests currently batched.

    Returns:
        A ``PoolingMetadata`` holding the prompt lengths of the first
        ``self.num_reqs`` requests (moved to ``self.device``), the prompt
        token ids taken from ``self.sampling_metadata``, and one pooling
        parameter object per request, ordered like ``self.req_ids``. The
        parameter list is empty when no pooling request is registered.
    """
    per_request_params = []
    if len(self.pooling_params) != 0:
        # For now a batch is assumed to contain either only sampling
        # requests or only pooling requests, so every request id in the
        # batch must have a pooling-params entry.
        assert len(self.req_ids) == len(self.pooling_params)
        per_request_params = [
            self.pooling_params[rid] for rid in self.req_ids
        ]

    # Wrap the host-side prompt-length array as a tensor before moving it.
    prompt_lens = torch.from_numpy(self.num_prompt_tokens[:self.num_reqs])
    return PoolingMetadata(
        prompt_lens=prompt_lens.to(self.device),
        prompt_token_ids=self.sampling_metadata.prompt_token_ids,
        pooling_params=per_request_params,
    )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor: def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty( prompt_token_ids_cpu_tensor = torch.empty(
......
...@@ -36,6 +36,7 @@ from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader ...@@ -36,6 +36,7 @@ from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
...@@ -51,6 +52,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, ...@@ -51,6 +52,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
SlidingWindowSpec) SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput) ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
...@@ -119,6 +121,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -119,6 +121,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cache_config.cache_dtype] cache_config.cache_dtype]
self.is_multimodal_model = model_config.is_multimodal_model self.is_multimodal_model = model_config.is_multimodal_model
self.is_pooling_model = model_config.pooler_config is not None
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs self.max_num_reqs = scheduler_config.max_num_seqs
...@@ -394,7 +397,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -394,7 +397,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED: pooling_params = new_req_data.pooling_params
if sampling_params and \
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device) generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed) generator.manual_seed(sampling_params.seed)
else: else:
...@@ -406,6 +411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -406,6 +411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
mm_inputs=new_req_data.mm_inputs, mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions, mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=pooling_params,
generator=generator, generator=generator,
block_ids=new_req_data.block_ids, block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens, num_computed_tokens=new_req_data.num_computed_tokens,
...@@ -563,7 +569,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -563,7 +569,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, Any], bool, torch.Tensor, ) -> tuple[dict[str, Any], bool, torch.Tensor,
Optional[SpecDecodeMetadata]]: Optional[SpecDecodeMetadata], np.ndarray]:
""" """
:return: tuple[ :return: tuple[
attn_metadata: layer-to-attention_metadata mapping, attn_metadata: layer-to-attention_metadata mapping,
...@@ -750,7 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -750,7 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.set_active_loras(self.input_batch, num_scheduled_tokens) self.set_active_loras(self.input_batch, num_scheduled_tokens)
return (attn_metadata, attention_cuda_graphs, logits_indices, return (attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata) spec_decode_metadata, num_scheduled_tokens)
def _compute_cascade_attn_prefix_len( def _compute_cascade_attn_prefix_len(
self, self,
...@@ -1197,6 +1203,51 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1197,6 +1203,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32) dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def _pool(
    self,
    hidden_states: torch.Tensor,
    num_scheduled_tokens: int,
    num_scheduled_tokens_np: np.ndarray,
    finished_sending: Optional[set[str]],
    finished_recving: Optional[set[str]],
) -> ModelRunnerOutput:
    """Run the model's pooler over this step's hidden states.

    Args:
        hidden_states: Hidden states produced by the forward pass; only
            the first ``num_scheduled_tokens`` rows are used.
        num_scheduled_tokens: Total number of tokens scheduled this step.
        num_scheduled_tokens_np: Per-request scheduled token counts, used
            as the split sizes for ``hidden_states``.
        finished_sending: Passed through unchanged to the output.
        finished_recving: Passed through unchanged to the output.

    Returns:
        A ``ModelRunnerOutput`` with no sampled tokens or logprobs whose
        ``pooler_output`` has one entry per pooled request: the pooled
        tensor copied to CPU, or ``None`` when the request's sequence
        length differs from its prompt length.
    """
    batch = self.input_batch
    assert batch.num_reqs == len(batch.pooling_params), (
        "Either all or none of the requests in"
        " a batch must be pooling request")

    # Cut the flat hidden-state tensor into one chunk per request.
    split_sizes = num_scheduled_tokens_np.tolist()
    per_request_states = list(
        torch.split(hidden_states[:num_scheduled_tokens], split_sizes))

    metadata = batch.pooling_metadata
    raw_outputs = self.model.pooler(hidden_states=per_request_states,
                                    pooling_metadata=metadata)

    # Only emit a result when seq_len equals prompt_len.
    # NOTE(review): presumably this means the whole prompt has been
    # processed (e.g. with chunked prefill) -- confirm.
    seq_lens = self.seq_lens[:batch.num_reqs]
    pooler_output: list[Optional[torch.Tensor]] = [
        raw.data.cpu() if seq_len == prompt_len else None
        for raw, seq_len, prompt_len in zip(raw_outputs, seq_lens,
                                            metadata.prompt_lens)
    ]

    return ModelRunnerOutput(
        req_ids=batch.req_ids,
        req_id_to_index=batch.req_id_to_index,
        sampled_token_ids=[],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=pooler_output,
        finished_sending=finished_sending,
        finished_recving=finished_recving,
    )
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
...@@ -1214,7 +1265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1214,7 +1265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Prepare the decoder inputs. # Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices, (attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata) = (self._prepare_inputs(scheduler_output)) spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
...@@ -1284,7 +1336,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1284,7 +1336,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
# Run the decoder. # Run the model.
# Use persistent buffers for CUDA graphs. # Use persistent buffers for CUDA graphs.
with set_forward_context( with set_forward_context(
attn_metadata, attn_metadata,
...@@ -1326,6 +1378,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1326,6 +1378,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
all_gather_group=get_tp_group()) all_gather_group=get_tp_group())
logits = None logits = None
else: else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, finished_sending,
finished_recving)
sample_hidden_states = hidden_states[logits_indices] sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output: if broadcast_pp_output:
...@@ -1541,6 +1598,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1541,6 +1598,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=spec_token_ids, spec_token_ids=spec_token_ids,
logprobs=logprobs_lists, logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict, prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
) )
...@@ -1802,7 +1860,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1802,7 +1860,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self, self,
num_tokens: int, num_tokens: int,
capture_attn_cudagraph: bool = False, capture_attn_cudagraph: bool = False,
) -> torch.Tensor: ) -> tuple[torch.Tensor, torch.Tensor]:
# Padding for DP # Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
...@@ -1899,7 +1957,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1899,7 +1957,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.drafter.dummy_run(num_tokens) self.drafter.dummy_run(num_tokens)
logit_indices = np.cumsum(num_scheduled_tokens) - 1 logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states[logit_indices] return hidden_states, hidden_states[logit_indices]
@torch.inference_mode() @torch.inference_mode()
def _dummy_sampler_run( def _dummy_sampler_run(
...@@ -1978,6 +2036,48 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1978,6 +2036,48 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
return sampler_output return sampler_output
@torch.inference_mode()
def _dummy_pooler_run(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Exercise the model's pooler once with dummy requests.

    The rows of ``hidden_states`` are divided among
    ``min(num_tokens, scheduler_config.max_num_seqs)`` dummy requests and
    passed to ``self.model.pooler`` with placeholder metadata.

    Args:
        hidden_states: Hidden states whose first dimension is the number
            of dummy tokens to distribute.

    Returns:
        Whatever ``self.model.pooler`` returns for the dummy batch.

    Raises:
        RuntimeError: With a hint about ``max_num_seqs`` and
            ``gpu_memory_utilization`` when the pooler fails with an
            "out of memory" error; any other ``RuntimeError`` is
            re-raised as is.
    """
    total_tokens = hidden_states.shape[0]
    num_reqs = min(total_tokens, self.scheduler_config.max_num_seqs)

    # Spread the tokens evenly; the last request absorbs the remainder.
    tokens_per_req, leftover = divmod(total_tokens, num_reqs)
    split_sizes = [tokens_per_req] * num_reqs
    split_sizes[-1] += leftover
    assert sum(split_sizes) == total_tokens
    assert len(split_sizes) == num_reqs

    chunks = list(torch.split(hidden_states, split_sizes))

    dummy_metadata = PoolingMetadata(
        prompt_lens=torch.tensor([chunk.shape[0] for chunk in chunks],
                                 device=self.device),
        # Placeholder token ids: all zeros, tokens_per_req columns per row.
        prompt_token_ids=torch.zeros((num_reqs, tokens_per_req),
                                     dtype=torch.int32,
                                     device=self.device),
        pooling_params=[PoolingParams()] * num_reqs)

    try:
        return self.model.pooler(hidden_states=chunks,
                                 pooling_metadata=dummy_metadata)
    except RuntimeError as err:
        if 'out of memory' not in str(err):
            raise err
        raise RuntimeError(
            "CUDA out of memory occurred when warming up pooler with "
            f"{num_reqs} dummy requests. Please try lowering "
            "`max_num_seqs` or `gpu_memory_utilization` when "
            "initializing the engine.") from err
def profile_run(self) -> None: def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them. # TODO: handle encoder-decoder models once we support them.
...@@ -2048,13 +2148,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2048,13 +2148,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Cache the dummy encoder outputs. # Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
hidden_states = self._dummy_run(self.max_num_tokens) hidden_states, last_hidden_states \
= self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
sampler_output = self._dummy_sampler_run(hidden_states) if self.is_pooling_model:
output = self._dummy_pooler_run(hidden_states)
else:
output = self._dummy_sampler_run(last_hidden_states)
else: else:
sampler_output = None output = None
self._sync_device() self._sync_device()
del hidden_states, sampler_output del hidden_states, output
self.encoder_cache.clear() self.encoder_cache.clear()
gc.collect() gc.collect()
......
...@@ -273,9 +273,14 @@ class Worker(WorkerBase): ...@@ -273,9 +273,14 @@ class Worker(WorkerBase):
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs, max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
self.model_runner._dummy_sampler_run(
hidden_states=self.model_runner._dummy_run( hidden_states, last_hidden_states = \
num_tokens=max_num_reqs)) self.model_runner._dummy_run(num_tokens=max_num_reqs)
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
else:
self.model_runner._dummy_sampler_run(
hidden_states=last_hidden_states)
# Reset the seed to ensure that the random state is not affected by # Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling. # the model initialization and profiling.
......
...@@ -231,6 +231,7 @@ class InputBatch: ...@@ -231,6 +231,7 @@ class InputBatch:
self.block_table.add_row(request.block_ids, req_index) self.block_table.add_row(request.block_ids, req_index)
sampling_params = request.sampling_params sampling_params = request.sampling_params
assert sampling_params is not None, "pooling requests not supported yet"
if sampling_params.sampling_type == SamplingType.GREEDY: if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero. # Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0 self.temperature_cpu[req_index] = -1.0
......
...@@ -386,6 +386,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -386,6 +386,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add: list[str] = [] req_ids_to_add: list[str] = []
# Add new requests to the cached states. # Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.sampling_params is not None,\
"Pooling is not supported in TPU yet"
req_id = new_req_data.req_id req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params sampling_params = new_req_data.sampling_params
...@@ -395,6 +397,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -395,6 +397,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
mm_inputs=new_req_data.mm_inputs, mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions, mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None,
generator=None, generator=None,
block_ids=new_req_data.block_ids, block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens, num_computed_tokens=new_req_data.num_computed_tokens,
...@@ -956,6 +959,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -956,6 +959,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=None, spec_token_ids=None,
logprobs=logprobs_lists, logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict, prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
) )
# Check there are no new graphs compiled - all the graphs should be # Check there are no new graphs compiled - all the graphs should be
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment