"vscode:/vscode.git/clone" did not exist on "314af8617c628be1eeb89626bff1be5bc8f81e6b"
Unverified Commit 799397ee authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

Support embedding models in V1 (#16188)


Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
Signed-off-by: default avatarMax de Bayser <maxdebayser@gmail.com>
Signed-off-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
parent 49599150
...@@ -150,6 +150,7 @@ def create_request( ...@@ -150,6 +150,7 @@ def create_request(
request_id=f"id-{request_id}", request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=None, multi_modal_inputs=None,
multi_modal_placeholders=None, multi_modal_placeholders=None,
multi_modal_hashes=None, multi_modal_hashes=None,
...@@ -183,6 +184,7 @@ def create_model_runner_output( ...@@ -183,6 +184,7 @@ def create_model_runner_output(
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=None,
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
) )
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -46,7 +47,7 @@ def _compare_objs(obj1, obj2): ...@@ -46,7 +47,7 @@ def _compare_objs(obj1, obj2):
for a_i, b_i in zip(a.block_tables, b.block_tables): for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i) _compare_objs(a_i, b_i)
is_same = True is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata)): elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
_compare_objs(a, b) _compare_objs(a, b)
is_same = True # if we make it here must be same is_same = True # if we make it here must be same
elif a == b: elif a == b:
...@@ -201,6 +202,7 @@ def _construct_cached_request_state(req_id_suffix: int): ...@@ -201,6 +202,7 @@ def _construct_cached_request_state(req_id_suffix: int):
req_id=f"req_id_{req_id_suffix}", req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(), sampling_params=_create_sampling_params(),
pooling_params=None,
mm_inputs=[], mm_inputs=[],
mm_positions=[], mm_positions=[],
block_ids=([], ), block_ids=([], ),
......
...@@ -122,6 +122,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -122,6 +122,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[], mm_hashes=[],
mm_positions=[], mm_positions=[],
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
pooling_params=None,
block_ids=([0], ), block_ids=([0], ),
num_computed_tokens=0, num_computed_tokens=0,
lora_request=None, lora_request=None,
......
...@@ -4496,11 +4496,31 @@ class VllmConfig: ...@@ -4496,11 +4496,31 @@ class VllmConfig:
if self.compilation_config.full_cuda_graph and \ if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn: not self.model_config.disable_cascade_attn:
logger.warning_once( logger.info("full_cuda_graph is not supported with "
"full_cuda_graph is not supported with " "cascade attention. Disabling cascade attention.")
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True self.model_config.disable_cascade_attn = True
disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons:
logger.info(reason)
self.scheduler_config.chunked_prefill_enabled = False
self.scheduler_config.long_prefill_token_threshold = 0
self.scheduler_config.max_num_batched_tokens = max(
self.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
if self.cache_config is not None:
self.cache_config.enable_prefix_caching = False
if (self.kv_events_config is not None if (self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events and self.kv_events_config.enable_kv_cache_events
and not self.cache_config.enable_prefix_caching): and not self.cache_config.enable_prefix_caching):
......
...@@ -1041,7 +1041,7 @@ class EngineArgs: ...@@ -1041,7 +1041,7 @@ class EngineArgs:
# Set default arguments for V0 or V1 Engine. # Set default arguments for V0 or V1 Engine.
if use_v1: if use_v1:
self._set_default_args_v1(usage_context) self._set_default_args_v1(usage_context, model_config)
else: else:
self._set_default_args_v0(model_config) self._set_default_args_v0(model_config)
...@@ -1349,13 +1349,7 @@ class EngineArgs: ...@@ -1349,13 +1349,7 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# No Embedding Models so far. # No Mamba or Encoder-Decoder so far.
if model_config.task not in ["generate"]:
_raise_or_fallback(feature_name=f"--task {model_config.task}",
recommend_to_remove=False)
return False
# No Encoder-Decoder, not all Mamba so far.
if not model_config.is_v1_compatible: if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures, _raise_or_fallback(feature_name=model_config.architectures,
recommend_to_remove=False) recommend_to_remove=False)
...@@ -1523,15 +1517,38 @@ class EngineArgs: ...@@ -1523,15 +1517,38 @@ class EngineArgs:
if self.max_num_seqs is None: if self.max_num_seqs is None:
self.max_num_seqs = 256 self.max_num_seqs = 256
def _set_default_args_v1(self, usage_context: UsageContext) -> None: def _set_default_args_v1(self, usage_context: UsageContext,
model_config: ModelConfig) -> None:
"""Set Default Arguments for V1 Engine.""" """Set Default Arguments for V1 Engine."""
# V1 always uses chunked prefills. # V1 always uses chunked prefills and prefix caching
self.enable_chunked_prefill = True # for non-pooling tasks.
# For pooling tasks the default is False
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
else:
pooling_type = model_config.pooler_config.pooling_type
# TODO: when encoder models are supported we'll have to
# check for causal attention here.
incremental_prefill_supported = (pooling_type is not None and
pooling_type.lower() == "last")
# V1 enables prefix caching by default. action = "Enabling" if \
if self.enable_prefix_caching is None: incremental_prefill_supported else "Disabling"
self.enable_prefix_caching = True
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = incremental_prefill_supported
logger.info("(%s) chunked prefill by default", action)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = incremental_prefill_supported
logger.info("(%s) prefix caching by default", action)
if not self.enable_chunked_prefill:
self.max_num_batched_tokens = model_config.max_model_len
# V1 should use the new scheduler by default. # V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default # Swap it only if this arg is set to the original V0 default
......
...@@ -1266,7 +1266,7 @@ class LLM: ...@@ -1266,7 +1266,7 @@ class LLM:
# the tokenizer for models such as # the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs # lists of tokens to the `text` and `text_pair` kwargs
tokenizer = self.llm_engine.get_tokenizer() tokenizer = self.get_tokenizer()
def ensure_str(prompt: SingletonPrompt): def ensure_str(prompt: SingletonPrompt):
if isinstance(prompt, dict): if isinstance(prompt, dict):
......
...@@ -9,6 +9,7 @@ from typing import Final, Literal, Optional, Union, cast ...@@ -9,6 +9,7 @@ from typing import Final, Literal, Optional, Union, cast
import jinja2 import jinja2
import numpy as np import numpy as np
import torch
from fastapi import Request from fastapi import Request
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -39,7 +40,8 @@ def _get_data( ...@@ -39,7 +40,8 @@ def _get_data(
elif encoding_format == "base64": elif encoding_format == "base64":
# Force to use float32 for base64 encoding # Force to use float32 for base64 encoding
# to match the OpenAI python client behavior # to match the OpenAI python client behavior
pooling_bytes = np.array(output.data, dtype="float32").tobytes() pt_float32 = output.data.to(dtype=torch.float32)
pooling_bytes = np.array(pt_float32, dtype="float32").tobytes()
return base64.b64encode(pooling_bytes).decode("utf-8") return base64.b64encode(pooling_bytes).decode("utf-8")
assert_never(encoding_format) assert_never(encoding_format)
......
...@@ -10,11 +10,15 @@ import torch.nn.functional as F ...@@ -10,11 +10,15 @@ import torch.nn.functional as F
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import ModelConfig, PoolerConfig from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import (PoolingMetadata, from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingTensors) PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
get_cross_encoder_activation_function) get_cross_encoder_activation_function)
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
class PoolingType(IntEnum): class PoolingType(IntEnum):
...@@ -75,15 +79,18 @@ class SimplePooler(nn.Module): ...@@ -75,15 +79,18 @@ class SimplePooler(nn.Module):
def get_prompt_lens( def get_prompt_lens(
self, self,
hidden_states: torch.Tensor, hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
assert isinstance(hidden_states, torch.Tensor)
return PoolingTensors.from_pooling_metadata( return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens pooling_metadata, hidden_states.device).prompt_lens
def extract_states( def extract_states(
self, self,
hidden_states: torch.Tensor, hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
raise NotImplementedError raise NotImplementedError
...@@ -93,7 +100,7 @@ class SimplePooler(nn.Module): ...@@ -93,7 +100,7 @@ class SimplePooler(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata) pooled_data = self.extract_states(hidden_states, pooling_metadata)
...@@ -106,11 +113,19 @@ class CLSPool(SimplePooler): ...@@ -106,11 +113,19 @@ class CLSPool(SimplePooler):
def extract_states( def extract_states(
self, self,
hidden_states: torch.Tensor, hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
if isinstance(hidden_states, list):
result = []
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with CLS pooling"
result.append(req_state[0])
return result
first_token_flat_indices = torch.zeros_like(prompt_lens) first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
return hidden_states[first_token_flat_indices] return hidden_states[first_token_flat_indices]
...@@ -120,9 +135,12 @@ class LastPool(SimplePooler): ...@@ -120,9 +135,12 @@ class LastPool(SimplePooler):
def extract_states( def extract_states(
self, self,
hidden_states: torch.Tensor, hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
if isinstance(hidden_states, list):
return [h[-1] for h in hidden_states]
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
...@@ -133,11 +151,17 @@ class AllPool(SimplePooler): ...@@ -133,11 +151,17 @@ class AllPool(SimplePooler):
def extract_states( def extract_states(
self, self,
hidden_states: torch.Tensor, hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with ALL pooling"
return hidden_states
offset = 0 offset = 0
pooled_data = list[torch.Tensor]() pooled_data = list[torch.Tensor]()
for prompt_len in prompt_lens: for prompt_len in prompt_lens:
...@@ -151,11 +175,20 @@ class MeanPool(SimplePooler): ...@@ -151,11 +175,20 @@ class MeanPool(SimplePooler):
def extract_states( def extract_states(
self, self,
hidden_states: torch.Tensor, hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
if isinstance(hidden_states, list):
result = []
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with mean pooling"
result.append(torch.mean(req_state, dim=0,
dtype=torch.float32))
return result
# Use float32 for torch.cumsum in MeanPool, # Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly. # otherwise precision will be lost significantly.
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
...@@ -184,30 +217,53 @@ class StepPool(SimplePooler): ...@@ -184,30 +217,53 @@ class StepPool(SimplePooler):
self.step_tag_id = step_tag_id self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids self.returned_token_ids = returned_token_ids
def get_prompt_token_ids(
self,
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor]:
if isinstance(pooling_metadata, V1PoolingMetadata):
return [
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]
return [
torch.tensor(seq_data_i.prompt_token_ids)
for seq_data_i in pooling_metadata.seq_data.values()
]
def extract_states( def extract_states(
self, self,
hidden_states: torch.Tensor, hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
returned_token_ids = self.returned_token_ids pooled_data: list[torch.Tensor] = []
if returned_token_ids is not None and len(returned_token_ids) > 0:
hidden_states = hidden_states[:, returned_token_ids]
if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with mean pooling"
pooled_data = hidden_states
else:
offset = 0
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
offset += prompt_len
pooled_data.append(pooled_data_i)
pooled_data = []
returned_token_ids = self.returned_token_ids
step_tag_id = self.step_tag_id step_tag_id = self.step_tag_id
offset = 0 for data, token_id in zip(pooled_data, prompt_token_ids):
pooled_data = list[torch.Tensor]() if returned_token_ids is not None and len(returned_token_ids) > 0:
for prompt_len, seq_data_i in zip(prompt_lens, data = data[:, returned_token_ids]
pooling_metadata.seq_data.values()):
pooled_data_i = hidden_states[offset:offset + prompt_len]
if step_tag_id is not None:
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
offset += prompt_len if step_tag_id is not None:
pooled_data.append(pooled_data_i) data = data[token_id == step_tag_id]
pooled_data.append(data)
return pooled_data return pooled_data
...@@ -230,10 +286,17 @@ class PoolerHead(nn.Module): ...@@ -230,10 +286,17 @@ class PoolerHead(nn.Module):
else: else:
pooled_data = pooled_data.to(torch.float32) pooled_data = pooled_data.to(torch.float32)
dimensions_list = [ if isinstance(pooling_metadata, V0PoolingMetadata):
pooling_param.dimensions dimensions_list = [
for _, pooling_param in pooling_metadata.seq_groups pooling_param.dimensions
] for _, pooling_param in pooling_metadata.seq_groups
]
else:
assert isinstance(pooled_data, list)
dimensions_list = [
pooling_param.dimensions
for pooling_param in pooling_metadata.pooling_params
]
if any(d is not None for d in dimensions_list): if any(d is not None for d in dimensions_list):
# change the output dimension # change the output dimension
assert len(pooled_data) == len(dimensions_list) assert len(pooled_data) == len(dimensions_list)
...@@ -325,20 +388,41 @@ class ClassifierPooler(nn.Module): ...@@ -325,20 +388,41 @@ class ClassifierPooler(nn.Module):
raise NotImplementedError(f"task={config.task!r} is not supported" raise NotImplementedError(f"task={config.task!r} is not supported"
" with the classification pooler") " with the classification pooler")
def get_prompt_lens(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
assert isinstance(hidden_states, torch.Tensor)
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> PoolerOutput:
"""Pools sentence pair scores from the hidden_states.""" """Pools sentence pair scores from the hidden_states."""
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_lens = PoolingTensors.from_pooling_metadata( pooled_data = list[torch.Tensor]()
pooling_metadata, hidden_states.device).prompt_lens if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with classifier"
pooled_data = hidden_states
else:
offset = 0
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
offset += prompt_len
pooled_data.append(pooled_data_i)
offset = 0 offset = 0
pooled_data_lst = [] pooled_data_lst = []
for prompt_len in prompt_lens: for pooled_data_i in pooled_data:
pooled_data_i = hidden_states[offset:offset + prompt_len]
if self.pooler is not None: if self.pooler is not None:
final_shape_tensor = self.pooler(pooled_data_i) final_shape_tensor = self.pooler(pooled_data_i)
...@@ -346,7 +430,6 @@ class ClassifierPooler(nn.Module): ...@@ -346,7 +430,6 @@ class ClassifierPooler(nn.Module):
final_shape_tensor = self.classifier(pooled_data_i) final_shape_tensor = self.classifier(pooled_data_i)
pooled_data_lst.append(final_shape_tensor) pooled_data_lst.append(final_shape_tensor)
offset += prompt_len
pooled_output = torch.stack(pooled_data_lst) pooled_output = torch.stack(pooled_data_lst)
......
...@@ -446,8 +446,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -446,8 +446,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
softmax=False) softmax=False)
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, class BertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsQuant): SupportsCrossEncoding, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities. """A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for This class encapsulates the BertModel and provides an interface for
......
...@@ -21,7 +21,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -21,7 +21,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsCrossEncoding from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix from .utils import WeightsMapper, maybe_prefix
...@@ -270,7 +270,8 @@ class ModernBertPooler(nn.Module): ...@@ -270,7 +270,8 @@ class ModernBertPooler(nn.Module):
return pooled_output return pooled_output
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -375,7 +375,12 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, ...@@ -375,7 +375,12 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
hidden_states = self._pooler.extract_states(hidden_states, hidden_states = self._pooler.extract_states(hidden_states,
pooling_metadata) pooling_metadata)
logits, _ = self.score(hidden_states)
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata) pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [ pooled_outputs = [
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
......
...@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Optional
import msgspec import msgspec
from vllm.sampling_params import RequestOutputKind
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -23,6 +25,7 @@ class PoolingParams( ...@@ -23,6 +25,7 @@ class PoolingParams(
dimensions: Optional[int] = None dimensions: Optional[int] = None
additional_data: Optional[Any] = None additional_data: Optional[Any] = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
def clone(self) -> "PoolingParams": def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance.""" """Returns a deep copy of the PoolingParams instance."""
...@@ -52,3 +55,7 @@ class PoolingParams( ...@@ -52,3 +55,7 @@ class PoolingParams(
return (f"PoolingParams(" return (f"PoolingParams("
f"dimensions={self.dimensions}, " f"dimensions={self.dimensions}, "
f"additional_metadata={self.additional_data})") f"additional_metadata={self.additional_data})")
def __post_init__(self) -> None:
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
"For pooling output_kind has to be FINAL_ONLY"
...@@ -146,7 +146,8 @@ class KVCacheManager: ...@@ -146,7 +146,8 @@ class KVCacheManager:
# Prefix caching is disabled or # Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching. # When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching if (not self.enable_caching
or request.sampling_params.prompt_logprobs is not None): or (request.sampling_params is not None
and request.sampling_params.prompt_logprobs is not None)):
return self.create_empty_block_list(), 0 return self.create_empty_block_list(), 0
# The block hashes for the request may already be computed # The block hashes for the request may already be computed
......
...@@ -14,6 +14,7 @@ if TYPE_CHECKING: ...@@ -14,6 +14,7 @@ if TYPE_CHECKING:
KVConnectorMetadata) KVConnectorMetadata)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -26,7 +27,8 @@ class NewRequestData: ...@@ -26,7 +27,8 @@ class NewRequestData:
mm_inputs: list[MultiModalKwargs] mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str] mm_hashes: list[str]
mm_positions: list[PlaceholderRange] mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
...@@ -44,6 +46,7 @@ class NewRequestData: ...@@ -44,6 +46,7 @@ class NewRequestData:
mm_hashes=request.mm_hashes, mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions, mm_positions=request.mm_positions,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids, block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens, num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request, lora_request=request.lora_request,
......
...@@ -402,6 +402,15 @@ class Scheduler(SchedulerInterface): ...@@ -402,6 +402,15 @@ class Scheduler(SchedulerInterface):
< num_new_tokens): < num_new_tokens):
num_new_tokens = ( num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold) self.scheduler_config.long_prefill_token_threshold)
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0 assert num_new_tokens > 0
...@@ -707,6 +716,7 @@ class Scheduler(SchedulerInterface): ...@@ -707,6 +716,7 @@ class Scheduler(SchedulerInterface):
logprobs = model_runner_output.logprobs logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
new_running: list[Request] = [] new_running: list[Request] = []
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
...@@ -724,7 +734,8 @@ class Scheduler(SchedulerInterface): ...@@ -724,7 +734,8 @@ class Scheduler(SchedulerInterface):
continue continue
req_index = model_runner_output.req_id_to_index[req_id] req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index] generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else []
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id)) scheduler_output.scheduled_spec_decode_tokens.get(req_id))
...@@ -776,8 +787,17 @@ class Scheduler(SchedulerInterface): ...@@ -776,8 +787,17 @@ class Scheduler(SchedulerInterface):
del new_token_ids[num_new:] # Trim new tokens if needed. del new_token_ids[num_new:] # Trim new tokens if needed.
break break
pooler_output = None
if pooler_outputs:
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, self.max_model_len,
pooler_output)
if stopped:
kv_transfer_params = self._free_request(request)
# Extract sample logprobs if needed. # Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None and logprobs: if request.sampling_params is not None \
and request.sampling_params.logprobs is not None and logprobs:
# NOTE: once we support N tokens per step (spec decode), # NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1. # the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1) new_logprobs = logprobs.slice(req_index, req_index + 1)
...@@ -802,7 +822,8 @@ class Scheduler(SchedulerInterface): ...@@ -802,7 +822,8 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request. # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or kv_transfer_params: if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request. # Add EngineCoreOutput for this Request.
outputs[request.client_index].append( outputs[request.client_index].append(
...@@ -812,6 +833,7 @@ class Scheduler(SchedulerInterface): ...@@ -812,6 +833,7 @@ class Scheduler(SchedulerInterface):
finish_reason=request.get_finished_reason(), finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs, new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors, new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason, stop_reason=request.stop_reason,
events=request.take_events(), events=request.take_events(),
kv_transfer_params=kv_transfer_params, kv_transfer_params=kv_transfer_params,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
def check_stop(request: Request, max_model_len: int) -> bool: def check_stop(request: Request,
max_model_len: int,
pooler_output: Optional[torch.Tensor] = None) -> bool:
if (request.num_tokens >= max_model_len if (request.num_tokens >= max_model_len
or request.num_output_tokens >= request.max_tokens): or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True return True
if request.pooling_params:
if pooler_output is not None:
request.status = RequestStatus.FINISHED_STOPPED
return True
return False
sampling_params = request.sampling_params sampling_params = request.sampling_params
assert sampling_params is not None
last_token_id = request.output_token_ids[-1] last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id): and last_token_id == request.eos_token_id):
......
...@@ -7,10 +7,12 @@ from collections.abc import Sequence ...@@ -7,10 +7,12 @@ from collections.abc import Sequence
from typing import Any, Optional, Union from typing import Any, Optional, Union
import msgspec import msgspec
import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors from vllm.v1.outputs import LogprobsLists, LogprobsTensors
...@@ -50,7 +52,8 @@ class EngineCoreRequest( ...@@ -50,7 +52,8 @@ class EngineCoreRequest(
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]] mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]] mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
eos_token_id: Optional[int] eos_token_id: Optional[int]
arrival_time: float arrival_time: float
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
...@@ -104,6 +107,8 @@ class EngineCoreOutput( ...@@ -104,6 +107,8 @@ class EngineCoreOutput(
new_logprobs: Optional[LogprobsLists] = None new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
pooling_output: Optional[torch.Tensor] = None
finish_reason: Optional[FinishReason] = None finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None events: Optional[list[EngineCoreEvent]] = None
......
...@@ -17,7 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor ...@@ -17,7 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -228,8 +228,7 @@ class AsyncLLM(EngineClient): ...@@ -228,8 +228,7 @@ class AsyncLLM(EngineClient):
if self.errored: if self.errored:
raise EngineDeadError() raise EngineDeadError()
assert isinstance(params, SamplingParams), \ is_pooling = isinstance(params, PoolingParams)
"Pooling is not supported in V1"
# Create a new output collector for the request. # Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind) queue = RequestOutputCollector(output_kind=params.output_kind)
...@@ -240,7 +239,7 @@ class AsyncLLM(EngineClient): ...@@ -240,7 +239,7 @@ class AsyncLLM(EngineClient):
tokenization_kwargs, trace_headers, prompt_adapter_request, tokenization_kwargs, trace_headers, prompt_adapter_request,
priority, data_parallel_rank) priority, data_parallel_rank)
if params.n == 1: if is_pooling or params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue) await self._add_request(request, prompt_str, None, 0, queue)
return queue return queue
...@@ -443,7 +442,7 @@ class AsyncLLM(EngineClient): ...@@ -443,7 +442,7 @@ class AsyncLLM(EngineClient):
stat_logger.record(scheduler_stats=scheduler_stats, stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats) iteration_stats=iteration_stats)
def encode( async def encode(
self, self,
prompt: PromptType, prompt: PromptType,
pooling_params: PoolingParams, pooling_params: PoolingParams,
...@@ -451,8 +450,75 @@ class AsyncLLM(EngineClient): ...@@ -451,8 +450,75 @@ class AsyncLLM(EngineClient):
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
): ) -> AsyncGenerator[PoolingRequestOutput, None]:
raise ValueError("Not Supported on V1 yet.") """
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
* 2) Processing the Input.
* 3) Adding the Request to the EngineCore (separate process).
A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the
per-request AsyncStream.
The caller of generate() iterates the returned AsyncGenerator,
returning the RequestOutput back to the caller.
"""
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
q = await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
finished = False
while not finished:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() or await q.get()
assert isinstance(out, PoolingRequestOutput)
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
finished = out.finished
yield out
# If the request is disconnected by the client, generate()
# is cancelled. So, we abort the request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
# Engine is dead. Do not abort since we shut down.
except EngineDeadError:
if self.log_requests:
logger.info("Request %s failed (engine dead).", request_id)
raise
# Request validation error.
except ValueError:
if self.log_requests:
logger.info("Request %s failed (bad request).", request_id)
raise
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
async def get_vllm_config(self) -> VllmConfig: async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config return self.vllm_config
......
...@@ -60,7 +60,6 @@ class EngineCore: ...@@ -60,7 +60,6 @@ class EngineCore:
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
executor_fail_callback: Optional[Callable] = None): executor_fail_callback: Optional[Callable] = None):
assert vllm_config.model_config.runner_type != "pooling"
# plugins need to be loaded at the engine/scheduler level too # plugins need to be loaded at the engine/scheduler level too
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
......
...@@ -50,6 +50,8 @@ class IncrementalDetokenizer: ...@@ -50,6 +50,8 @@ class IncrementalDetokenizer:
request: EngineCoreRequest, request: EngineCoreRequest,
) -> "IncrementalDetokenizer": ) -> "IncrementalDetokenizer":
assert request.sampling_params is not None
if tokenizer is None: if tokenizer is None:
# No tokenizer => skipping detokenization. # No tokenizer => skipping detokenization.
return IncrementalDetokenizer() return IncrementalDetokenizer()
...@@ -70,6 +72,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): ...@@ -70,6 +72,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
# Stop strings # Stop strings
params = request.sampling_params params = request.sampling_params
assert params is not None
self.stop = stop = params.stop self.stop = stop = params.stop
self.include_stop_str_in_output = params.include_stop_str_in_output self.include_stop_str_in_output = params.include_stop_str_in_output
...@@ -164,6 +167,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): ...@@ -164,6 +167,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
super().__init__(request) super().__init__(request)
sampling_params = request.sampling_params sampling_params = request.sampling_params
assert sampling_params is not None
self.request_id = request.request_id self.request_id = request.request_id
self.skip_special_tokens = sampling_params.skip_special_tokens self.skip_special_tokens = sampling_params.skip_special_tokens
...@@ -245,20 +249,20 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): ...@@ -245,20 +249,20 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
super().__init__(request) super().__init__(request)
self.tokenizer = tokenizer self.tokenizer = tokenizer
params = request.sampling_params
assert params is not None
# Metadata for incremental detokenization. # Metadata for incremental detokenization.
self.tokens, self.prefix_offset, self.read_offset = ( self.tokens, self.prefix_offset, self.read_offset = (
convert_prompt_ids_to_tokens( convert_prompt_ids_to_tokens(
tokenizer=tokenizer, tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids, prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.sampling_params. skip_special_tokens=params.skip_special_tokens,
skip_special_tokens,
)) ))
self.token_ids.extend(request.prompt_token_ids) self.token_ids.extend(request.prompt_token_ids)
self.prompt_len = len(request.prompt_token_ids) self.prompt_len = len(request.prompt_token_ids)
params = request.sampling_params
self.skip_special_tokens = params.skip_special_tokens self.skip_special_tokens = params.skip_special_tokens
self.spaces_between_special_tokens = ( self.spaces_between_special_tokens = (
params.spaces_between_special_tokens) params.spaces_between_special_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment