Commit 799397ee (unverified), authored by Maximilien de Bayser, committed by GitHub
Browse files

Support embedding models in V1 (#16188)


Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
parent 49599150
......@@ -150,6 +150,7 @@ def create_request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
......@@ -183,6 +184,7 @@ def create_model_runner_output(
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
finished_sending=finished_sending,
finished_recving=finished_recving,
)
......@@ -10,6 +10,7 @@ import torch
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
......@@ -46,7 +47,7 @@ def _compare_objs(obj1, obj2):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata)):
elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
_compare_objs(a, b)
is_same = True # if we make it here must be same
elif a == b:
......@@ -201,6 +202,7 @@ def _construct_cached_request_state(req_id_suffix: int):
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(),
pooling_params=None,
mm_inputs=[],
mm_positions=[],
block_ids=([], ),
......
......@@ -122,6 +122,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
pooling_params=None,
block_ids=([0], ),
num_computed_tokens=0,
lora_request=None,
......
......@@ -4496,11 +4496,31 @@ class VllmConfig:
if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn:
logger.warning_once(
"full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
logger.info("full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons:
logger.info(reason)
self.scheduler_config.chunked_prefill_enabled = False
self.scheduler_config.long_prefill_token_threshold = 0
self.scheduler_config.max_num_batched_tokens = max(
self.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
if self.cache_config is not None:
self.cache_config.enable_prefix_caching = False
if (self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events
and not self.cache_config.enable_prefix_caching):
......
......@@ -1041,7 +1041,7 @@ class EngineArgs:
# Set default arguments for V0 or V1 Engine.
if use_v1:
self._set_default_args_v1(usage_context)
self._set_default_args_v1(usage_context, model_config)
else:
self._set_default_args_v0(model_config)
......@@ -1349,13 +1349,7 @@ class EngineArgs:
recommend_to_remove=False)
return False
# No Embedding Models so far.
if model_config.task not in ["generate"]:
_raise_or_fallback(feature_name=f"--task {model_config.task}",
recommend_to_remove=False)
return False
# No Encoder-Decoder, not all Mamba so far.
# No Mamba or Encoder-Decoder so far.
if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures,
recommend_to_remove=False)
......@@ -1523,15 +1517,38 @@ class EngineArgs:
if self.max_num_seqs is None:
self.max_num_seqs = 256
def _set_default_args_v1(self, usage_context: UsageContext) -> None:
def _set_default_args_v1(self, usage_context: UsageContext,
model_config: ModelConfig) -> None:
"""Set Default Arguments for V1 Engine."""
# V1 always uses chunked prefills.
self.enable_chunked_prefill = True
# V1 always uses chunked prefills and prefix caching
# for non-pooling tasks.
# For pooling tasks the default is False
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
else:
pooling_type = model_config.pooler_config.pooling_type
# TODO: when encoder models are supported we'll have to
# check for causal attention here.
incremental_prefill_supported = (pooling_type is not None and
pooling_type.lower() == "last")
# V1 enables prefix caching by default.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
action = "Enabling" if \
incremental_prefill_supported else "Disabling"
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = incremental_prefill_supported
logger.info("(%s) chunked prefill by default", action)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = incremental_prefill_supported
logger.info("(%s) prefix caching by default", action)
if not self.enable_chunked_prefill:
self.max_num_batched_tokens = model_config.max_model_len
# V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default
......
......@@ -1266,7 +1266,7 @@ class LLM:
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
tokenizer = self.llm_engine.get_tokenizer()
tokenizer = self.get_tokenizer()
def ensure_str(prompt: SingletonPrompt):
if isinstance(prompt, dict):
......
......@@ -9,6 +9,7 @@ from typing import Final, Literal, Optional, Union, cast
import jinja2
import numpy as np
import torch
from fastapi import Request
from typing_extensions import assert_never
......@@ -39,7 +40,8 @@ def _get_data(
elif encoding_format == "base64":
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
pooling_bytes = np.array(output.data, dtype="float32").tobytes()
pt_float32 = output.data.to(dtype=torch.float32)
pooling_bytes = np.array(pt_float32, dtype="float32").tobytes()
return base64.b64encode(pooling_bytes).decode("utf-8")
assert_never(encoding_format)
......
......@@ -10,11 +10,15 @@ import torch.nn.functional as F
from typing_extensions import assert_never
from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
class PoolingType(IntEnum):
......@@ -75,15 +79,18 @@ class SimplePooler(nn.Module):
def get_prompt_lens(
    self,
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the prompt lengths described by ``pooling_metadata``.

    Args:
        hidden_states: Either a single tensor or a list of tensors. It is
            only consulted for its ``device`` when the metadata is not a
            ``V1PoolingMetadata``, in which case it must be a tensor.
        pooling_metadata: V0 or V1 pooling metadata.

    Returns:
        The ``prompt_lens`` tensor.
    """
    # V1 metadata already carries the prompt lengths.
    if isinstance(pooling_metadata, V1PoolingMetadata):
        return pooling_metadata.prompt_lens

    # Otherwise the lengths are derived through PoolingTensors, which needs
    # the device of the hidden-states tensor.
    assert isinstance(hidden_states, torch.Tensor)
    return PoolingTensors.from_pooling_metadata(
        pooling_metadata, hidden_states.device).prompt_lens
def extract_states(
    self,
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
    """Extract the states to pool from ``hidden_states``.

    This base implementation is a placeholder; the pooler subclasses in
    this file provide the actual behavior.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
......@@ -93,7 +100,7 @@ class SimplePooler(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
......@@ -106,11 +113,19 @@ class CLSPool(SimplePooler):
def extract_states(
    self,
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
    """Select the first-token hidden state of every prompt.

    Args:
        hidden_states: Either a list with one tensor per request, or a
            single tensor in which the requests are laid out back to back
            along dim 0.
        pooling_metadata: Metadata from which the prompt lengths are read.

    Returns:
        A list of first rows (list input) or a tensor gathered at the first
        row of each prompt (tensor input).
    """
    prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

    if isinstance(hidden_states, list):
        result = []
        for req_state, prompt_len in zip(hidden_states, prompt_lens):
            # Each per-request tensor must span the whole prompt.
            assert prompt_len == req_state.shape[0], \
                "partial prefill not supported with CLS pooling"
            result.append(req_state[0])
        return result

    # The first token of prompt i sits at the sum of the lengths of all
    # earlier prompts; prompt 0 starts at index 0.
    first_token_flat_indices = torch.zeros_like(prompt_lens)
    first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
    return hidden_states[first_token_flat_indices]
......@@ -120,9 +135,12 @@ class LastPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
if isinstance(hidden_states, list):
return [h[-1] for h in hidden_states]
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
......@@ -133,11 +151,17 @@ class AllPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with ALL pooling"
return hidden_states
offset = 0
pooled_data = list[torch.Tensor]()
for prompt_len in prompt_lens:
......@@ -151,11 +175,20 @@ class MeanPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
if isinstance(hidden_states, list):
result = []
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with mean pooling"
result.append(torch.mean(req_state, dim=0,
dtype=torch.float32))
return result
# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
......@@ -184,30 +217,53 @@ class StepPool(SimplePooler):
self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids
def get_prompt_token_ids(
    self,
    pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor]:
    """Return one tensor of prompt token ids per request.

    For ``V1PoolingMetadata`` each row of ``prompt_token_ids`` is cut to the
    matching entry of ``prompt_lens``. For any other metadata a new tensor
    is built from every ``seq_data`` value's ``prompt_token_ids``.
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        token_id_lists = (seq.prompt_token_ids
                          for seq in pooling_metadata.seq_data.values())
        return [torch.tensor(ids) for ids in token_id_lists]

    all_token_ids = pooling_metadata.prompt_token_ids
    return [
        all_token_ids[row, :length]
        for row, length in enumerate(pooling_metadata.prompt_lens)
    ]
def extract_states(
    self,
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
    """Return, per request, the hidden states selected by step pooling.

    Args:
        hidden_states: Either a list with one tensor per request, or a
            single tensor in which the requests are laid out back to back
            along dim 0.
        pooling_metadata: Metadata from which prompt lengths and prompt
            token ids are read.

    Returns:
        One tensor per request. When ``self.returned_token_ids`` is
        non-empty, only those columns are kept; when ``self.step_tag_id``
        is set, only rows whose prompt token equals it are kept.
    """
    prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
    prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)

    # Split the input into one tensor per request. This list is kept
    # separate from the output list: resetting or appending to the list
    # being iterated would make the filtering loop below a no-op.
    per_request_states: list[torch.Tensor] = []
    if isinstance(hidden_states, list):
        for req_state, prompt_len in zip(hidden_states, prompt_lens):
            # Each per-request tensor must span the whole prompt.
            assert prompt_len == req_state.shape[0], \
                "partial prefill not supported with step pooling"
        per_request_states = hidden_states
    else:
        offset = 0
        for prompt_len in prompt_lens:
            per_request_states.append(
                hidden_states[offset:offset + prompt_len])
            offset += prompt_len

    returned_token_ids = self.returned_token_ids
    step_tag_id = self.step_tag_id

    pooled_data: list[torch.Tensor] = []
    for data, token_id in zip(per_request_states, prompt_token_ids):
        if returned_token_ids is not None and len(returned_token_ids) > 0:
            data = data[:, returned_token_ids]

        if step_tag_id is not None:
            data = data[token_id == step_tag_id]
        pooled_data.append(data)

    return pooled_data
......@@ -230,10 +286,17 @@ class PoolerHead(nn.Module):
else:
pooled_data = pooled_data.to(torch.float32)
dimensions_list = [
pooling_param.dimensions
for _, pooling_param in pooling_metadata.seq_groups
]
if isinstance(pooling_metadata, V0PoolingMetadata):
dimensions_list = [
pooling_param.dimensions
for _, pooling_param in pooling_metadata.seq_groups
]
else:
assert isinstance(pooled_data, list)
dimensions_list = [
pooling_param.dimensions
for pooling_param in pooling_metadata.pooling_params
]
if any(d is not None for d in dimensions_list):
# change the output dimension
assert len(pooled_data) == len(dimensions_list)
......@@ -325,20 +388,41 @@ class ClassifierPooler(nn.Module):
raise NotImplementedError(f"task={config.task!r} is not supported"
" with the classification pooler")
def get_prompt_lens(
    self,
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Look up the prompt lengths for the given pooling metadata.

    ``V1PoolingMetadata`` exposes ``prompt_lens`` directly. For any other
    metadata the lengths come from ``PoolingTensors``, built on the device
    of ``hidden_states`` (which must then be a single tensor).
    """
    if not isinstance(pooling_metadata, V1PoolingMetadata):
        assert isinstance(hidden_states, torch.Tensor)
        pooling_tensors = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device)
        return pooling_tensors.prompt_lens

    return pooling_metadata.prompt_lens
def forward(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools sentence pair scores from the hidden_states."""
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_lens = PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
pooled_data = list[torch.Tensor]()
if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with classifier"
pooled_data = hidden_states
else:
offset = 0
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
offset += prompt_len
pooled_data.append(pooled_data_i)
offset = 0
pooled_data_lst = []
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
for pooled_data_i in pooled_data:
if self.pooler is not None:
final_shape_tensor = self.pooler(pooled_data_i)
......@@ -346,7 +430,6 @@ class ClassifierPooler(nn.Module):
final_shape_tensor = self.classifier(pooled_data_i)
pooled_data_lst.append(final_shape_tensor)
offset += prompt_len
pooled_output = torch.stack(pooled_data_lst)
......
......@@ -446,8 +446,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
softmax=False)
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
class BertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
......
......@@ -21,7 +21,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsCrossEncoding
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix
......@@ -270,7 +270,8 @@ class ModernBertPooler(nn.Module):
return pooled_output
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -375,7 +375,12 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
) -> Optional[PoolerOutput]:
hidden_states = self._pooler.extract_states(hidden_states,
pooling_metadata)
logits, _ = self.score(hidden_states)
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
......
......@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Optional
import msgspec
from vllm.sampling_params import RequestOutputKind
if TYPE_CHECKING:
from vllm.config import ModelConfig
......@@ -23,6 +25,7 @@ class PoolingParams(
dimensions: Optional[int] = None
additional_data: Optional[Any] = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
......@@ -52,3 +55,7 @@ class PoolingParams(
return (f"PoolingParams("
f"dimensions={self.dimensions}, "
f"additional_metadata={self.additional_data})")
def __post_init__(self) -> None:
    """Check that ``output_kind`` is ``RequestOutputKind.FINAL_ONLY``."""
    is_final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
    assert is_final_only, "For pooling output_kind has to be FINAL_ONLY"
......@@ -146,7 +146,8 @@ class KVCacheManager:
# Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching
or request.sampling_params.prompt_logprobs is not None):
or (request.sampling_params is not None
and request.sampling_params.prompt_logprobs is not None)):
return self.create_empty_block_list(), 0
# The block hashes for the request may already be computed
......
......@@ -14,6 +14,7 @@ if TYPE_CHECKING:
KVConnectorMetadata)
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
......@@ -26,7 +27,8 @@ class NewRequestData:
mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
......@@ -44,6 +46,7 @@ class NewRequestData:
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
......
......@@ -402,6 +402,15 @@ class Scheduler(SchedulerInterface):
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
......@@ -707,6 +716,7 @@ class Scheduler(SchedulerInterface):
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
new_running: list[Request] = []
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
......@@ -724,7 +734,8 @@ class Scheduler(SchedulerInterface):
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index]
generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else []
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
......@@ -776,8 +787,17 @@ class Scheduler(SchedulerInterface):
del new_token_ids[num_new:] # Trim new tokens if needed.
break
pooler_output = None
if pooler_outputs:
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, self.max_model_len,
pooler_output)
if stopped:
kv_transfer_params = self._free_request(request)
# Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None and logprobs:
if request.sampling_params is not None \
and request.sampling_params.logprobs is not None and logprobs:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
......@@ -802,7 +822,8 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or kv_transfer_params:
if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
......@@ -812,6 +833,7 @@ class Scheduler(SchedulerInterface):
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.v1.request import Request, RequestStatus
def check_stop(request: Request, max_model_len: int) -> bool:
def check_stop(request: Request,
max_model_len: int,
pooler_output: Optional[torch.Tensor] = None) -> bool:
if (request.num_tokens >= max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
if request.pooling_params:
if pooler_output is not None:
request.status = RequestStatus.FINISHED_STOPPED
return True
return False
sampling_params = request.sampling_params
assert sampling_params is not None
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
......
......@@ -7,10 +7,12 @@ from collections.abc import Sequence
from typing import Any, Optional, Union
import msgspec
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
......@@ -50,7 +52,8 @@ class EngineCoreRequest(
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]
......@@ -104,6 +107,8 @@ class EngineCoreOutput(
new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
pooling_output: Optional[torch.Tensor] = None
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None
......
......@@ -17,7 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
......@@ -228,8 +228,7 @@ class AsyncLLM(EngineClient):
if self.errored:
raise EngineDeadError()
assert isinstance(params, SamplingParams), \
"Pooling is not supported in V1"
is_pooling = isinstance(params, PoolingParams)
# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)
......@@ -240,7 +239,7 @@ class AsyncLLM(EngineClient):
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority, data_parallel_rank)
if params.n == 1:
if is_pooling or params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue)
return queue
......@@ -443,7 +442,7 @@ class AsyncLLM(EngineClient):
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
def encode(
async def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
......@@ -451,8 +450,75 @@ class AsyncLLM(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
):
raise ValueError("Not Supported on V1 yet.")
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
* 2) Processing the Input.
* 3) Adding the Request to the EngineCore (separate process).
A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the
per-request AsyncStream.
The caller of generate() iterates the returned AsyncGenerator,
returning the RequestOutput back to the caller.
"""
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
q = await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
finished = False
while not finished:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() or await q.get()
assert isinstance(out, PoolingRequestOutput)
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
finished = out.finished
yield out
# If the request is disconnected by the client, generate()
# is cancelled. So, we abort the request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
# Engine is dead. Do not abort since we shut down.
except EngineDeadError:
if self.log_requests:
logger.info("Request %s failed (engine dead).", request_id)
raise
# Request validation error.
except ValueError:
if self.log_requests:
logger.info("Request %s failed (bad request).", request_id)
raise
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config
......
......@@ -60,7 +60,6 @@ class EngineCore:
executor_class: type[Executor],
log_stats: bool,
executor_fail_callback: Optional[Callable] = None):
assert vllm_config.model_config.runner_type != "pooling"
# plugins need to be loaded at the engine/scheduler level too
from vllm.plugins import load_general_plugins
......
......@@ -50,6 +50,8 @@ class IncrementalDetokenizer:
request: EngineCoreRequest,
) -> "IncrementalDetokenizer":
assert request.sampling_params is not None
if tokenizer is None:
# No tokenizer => skipping detokenization.
return IncrementalDetokenizer()
......@@ -70,6 +72,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
# Stop strings
params = request.sampling_params
assert params is not None
self.stop = stop = params.stop
self.include_stop_str_in_output = params.include_stop_str_in_output
......@@ -164,6 +167,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
super().__init__(request)
sampling_params = request.sampling_params
assert sampling_params is not None
self.request_id = request.request_id
self.skip_special_tokens = sampling_params.skip_special_tokens
......@@ -245,20 +249,20 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
super().__init__(request)
self.tokenizer = tokenizer
params = request.sampling_params
assert params is not None
# Metadata for incremental detokenization.
self.tokens, self.prefix_offset, self.read_offset = (
convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.sampling_params.
skip_special_tokens,
skip_special_tokens=params.skip_special_tokens,
))
self.token_ids.extend(request.prompt_token_ids)
self.prompt_len = len(request.prompt_token_ids)
params = request.sampling_params
self.skip_special_tokens = params.skip_special_tokens
self.spaces_between_special_tokens = (
params.spaces_between_special_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment