Unverified Commit 2554b27b authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

[V0 Deprecation] Remove pooling model support in V0 (#23434)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 934bebf1
...@@ -347,7 +347,7 @@ class MQLLMEngine: ...@@ -347,7 +347,7 @@ class MQLLMEngine:
def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest): def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
try: try:
self.engine.add_lora(request.lora_request) lora_loaded = self.engine.add_lora(request.lora_request)
except BaseException as e: except BaseException as e:
# Send back an error if the adater fails to load # Send back an error if the adater fails to load
rpc_err = RPCError(request_id=request.request_id, rpc_err = RPCError(request_id=request.request_id,
...@@ -357,7 +357,8 @@ class MQLLMEngine: ...@@ -357,7 +357,8 @@ class MQLLMEngine:
return return
# Otherwise, send back the successful load message # Otherwise, send back the successful load message
self._send_outputs( self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id)) RPCAdapterLoadedResponse(request_id=request.request_id,
lora_loaded=lora_loaded))
def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest): def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
is_sleeping = self.is_sleeping() is_sleeping = self.is_sleeping()
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import AsyncGenerator, Iterable, Mapping, Optional, Union from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.config import DecodingConfig, ModelConfig, VllmConfig
...@@ -224,6 +224,7 @@ class EngineClient(ABC): ...@@ -224,6 +224,7 @@ class EngineClient(ABC):
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.""" """Generate outputs for a request from a pooling model."""
... ...
...@@ -320,7 +321,7 @@ class EngineClient(ABC): ...@@ -320,7 +321,7 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests.""" """Load a new LoRA adapter into the engine for future requests."""
... ...
......
...@@ -1156,8 +1156,7 @@ class LLM: ...@@ -1156,8 +1156,7 @@ class LLM:
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop( if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
"token_type_ids", None)):
params = pooling_params.clone() params = pooling_params.clone()
compressed = compress_token_type_ids(token_type_ids) compressed = compress_token_type_ids(token_type_ids)
params.extra_kwargs = {"compressed_token_type_ids": compressed} params.extra_kwargs = {"compressed_token_type_ids": compressed}
......
...@@ -7,7 +7,6 @@ from typing import Any, Optional, Union ...@@ -7,7 +7,6 @@ from typing import Any, Optional, Union
from fastapi import Request from fastapi import Request
from vllm import envs
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
...@@ -229,8 +228,7 @@ class ServingScores(OpenAIServing): ...@@ -229,8 +228,7 @@ class ServingScores(OpenAIServing):
params=default_pooling_params, params=default_pooling_params,
lora_request=lora_request) lora_request=lora_request)
if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop( if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
"token_type_ids", None)):
pooling_params = default_pooling_params.clone() pooling_params = default_pooling_params.clone()
compressed = compress_token_type_ids(token_type_ids) compressed = compress_token_type_ids(token_type_ids)
pooling_params.extra_kwargs = { pooling_params.extra_kwargs = {
......
...@@ -174,9 +174,6 @@ class TokenInputs(TypedDict): ...@@ -174,9 +174,6 @@ class TokenInputs(TypedDict):
prompt_token_ids: list[int] prompt_token_ids: list[int]
"""The token IDs of the prompt.""" """The token IDs of the prompt."""
token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""
prompt: NotRequired[str] prompt: NotRequired[str]
""" """
The original prompt text corresponding to the token IDs, if available. The original prompt text corresponding to the token IDs, if available.
...@@ -190,7 +187,6 @@ class TokenInputs(TypedDict): ...@@ -190,7 +187,6 @@ class TokenInputs(TypedDict):
def token_inputs( def token_inputs(
prompt_token_ids: list[int], prompt_token_ids: list[int],
token_type_ids: Optional[list[int]] = None,
prompt: Optional[str] = None, prompt: Optional[str] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
) -> TokenInputs: ) -> TokenInputs:
...@@ -200,8 +196,6 @@ def token_inputs( ...@@ -200,8 +196,6 @@ def token_inputs(
if prompt is not None: if prompt is not None:
inputs["prompt"] = prompt inputs["prompt"] = prompt
if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids
if cache_salt is not None: if cache_salt is not None:
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
......
...@@ -355,7 +355,6 @@ class InputPreprocessor: ...@@ -355,7 +355,6 @@ class InputPreprocessor:
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"] prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
inputs: Union[TokenInputs, MultiModalInputs] inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"): if multi_modal_data := parsed_content.get("multi_modal_data"):
...@@ -368,10 +367,7 @@ class InputPreprocessor: ...@@ -368,10 +367,7 @@ class InputPreprocessor:
mm_hash_overrides=mm_hash_overrides, mm_hash_overrides=mm_hash_overrides,
) )
else: else:
inputs = token_inputs( inputs = token_inputs(prompt_token_ids=prompt_token_ids)
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
if cache_salt := parsed_content.get("cache_salt"): if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
...@@ -387,7 +383,6 @@ class InputPreprocessor: ...@@ -387,7 +383,6 @@ class InputPreprocessor:
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"] prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
inputs: Union[TokenInputs, MultiModalInputs] inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"): if multi_modal_data := parsed_content.get("multi_modal_data"):
...@@ -400,10 +395,7 @@ class InputPreprocessor: ...@@ -400,10 +395,7 @@ class InputPreprocessor:
mm_hash_overrides=mm_hash_overrides, mm_hash_overrides=mm_hash_overrides,
) )
else: else:
inputs = token_inputs( inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
if cache_salt := parsed_content.get("cache_salt"): if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
......
...@@ -13,17 +13,12 @@ import torch.nn.functional as F ...@@ -13,17 +13,12 @@ import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig, PoolerConfig from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.utils import current_stream, resolve_obj_by_qualname from vllm.utils import current_stream, resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingCursor from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
PoolingFn = Callable[ PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]] Union[torch.Tensor, list[torch.Tensor]]]
...@@ -127,36 +122,23 @@ def get_prompt_lens( ...@@ -127,36 +122,23 @@ def get_prompt_lens(
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata): return pooling_metadata.prompt_lens
return pooling_metadata.prompt_lens
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states[0].device).prompt_lens
def get_prompt_token_ids( def get_prompt_token_ids(
pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
if isinstance(pooling_metadata, V1PoolingMetadata): assert pooling_metadata.prompt_token_ids is not None, (
assert pooling_metadata.prompt_token_ids is not None, ( "Please set `requires_token_ids=True` in `get_pooling_updates`")
"Please set `requires_token_ids=True` in `get_pooling_updates`")
return [
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]
return [ return [
torch.tensor(seq_data_i.prompt_token_ids) pooling_metadata.prompt_token_ids[i, :num]
for seq_data_i in pooling_metadata.seq_data.values() for i, num in enumerate(pooling_metadata.prompt_lens)
] ]
def get_pooling_params( def get_pooling_params(
pooling_metadata: PoolingMetadata) -> list[PoolingParams]: pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
if isinstance(pooling_metadata, V0PoolingMetadata): pooling_params = pooling_metadata.pooling_params
pooling_params = [p for _, p in pooling_metadata.seq_groups]
else:
pooling_params = pooling_metadata.pooling_params
return pooling_params return pooling_params
......
...@@ -24,9 +24,9 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, ...@@ -24,9 +24,9 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding, SupportsQuant from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
......
...@@ -15,10 +15,10 @@ from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, ...@@ -15,10 +15,10 @@ from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
build_output, get_prompt_lens, build_output, get_prompt_lens,
get_prompt_token_ids) get_prompt_token_ids)
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput from vllm.sequence import PoolerOutput
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
......
...@@ -22,9 +22,9 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding ...@@ -22,9 +22,9 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
import torch
from vllm.pooling_params import PoolingParams
from vllm.utils import is_pin_memory_available
from vllm.v1.pool.metadata import PoolingCursor, build_pooling_cursor
class PoolingMetadata:
"""Metadata for pooling operations in the Pooler layer.
This class holds the necessary information for pooling operations,
providing context for how to perform pooling and other related operations.
Attributes:
seq_groups: List of (seq_ids, pooling_params).
seq_data: A mapping of sequence ID to additional sequence data.
prompt_lens: List of the lengths of each prompt.
"""
def __init__(
self,
seq_groups: list[tuple[list[int], PoolingParams]],
seq_data: dict[int, Any], # Specific data related to sequences
prompt_lens: list[int],
pooling_cursor: Optional[PoolingCursor] = None) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.pooling_cursor: Optional[PoolingCursor] = pooling_cursor
def __repr__(self) -> str:
return ("PoolingMetadata("
f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens})")
def __getitem__(self, indices: slice):
return PoolingMetadata(
seq_groups=self.seq_groups[indices],
seq_data=dict(list(self.seq_data.items())[indices]),
prompt_lens=self.prompt_lens[indices],
pooling_cursor=None
if self.pooling_cursor is None else self.pooling_cursor[indices],
)
def build_pooling_cursor(self, num_scheduled_tokens: list[int],
device: torch.device):
prompt_lens = torch.tensor(self.prompt_lens, device="cpu")
self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
prompt_lens,
device=device)
@dataclass
class PoolingTensors:
"""Tensors for pooling."""
prompt_lens: torch.Tensor
@classmethod
def from_pooling_metadata(
cls,
pooling_metadata: "PoolingMetadata",
device: torch.device,
) -> "PoolingTensors":
"""
Create PoolingTensors from PoolingMetadata.
Args:
pooling_metadata: PoolingMetadata instance to convert.
device: Device to store the tensors.
"""
# Convert prompt lengths to tensor
pin_memory = is_pin_memory_available()
prompt_lens_t = torch.tensor(
pooling_metadata.prompt_lens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
return cls(prompt_lens=prompt_lens_t.to(device=device,
non_blocking=True), )
...@@ -913,9 +913,6 @@ class MultiModalInputs(TypedDict): ...@@ -913,9 +913,6 @@ class MultiModalInputs(TypedDict):
prompt_token_ids: list[int] prompt_token_ids: list[int]
"""The processed token IDs which includes placeholder tokens.""" """The processed token IDs which includes placeholder tokens."""
token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""
mm_kwargs: MultiModalKwargsOptionalItems mm_kwargs: MultiModalKwargsOptionalItems
"""Keyword arguments to be directly passed to the model after batching.""" """Keyword arguments to be directly passed to the model after batching."""
...@@ -946,6 +943,3 @@ class MultiModalEncDecInputs(MultiModalInputs): ...@@ -946,6 +943,3 @@ class MultiModalEncDecInputs(MultiModalInputs):
encoder_prompt_token_ids: list[int] encoder_prompt_token_ids: list[int]
"""The processed token IDs of the encoder prompt.""" """The processed token IDs of the encoder prompt."""
encoder_token_type_ids: NotRequired[list[int]]
"""The token type IDs of the encoder prompt."""
...@@ -508,12 +508,6 @@ class Sequence: ...@@ -508,12 +508,6 @@ class Sequence:
return [0] * len(self.inputs["prompt_embeds"]) return [0] * len(self.inputs["prompt_embeds"])
return self.inputs["prompt_token_ids"] return self.inputs["prompt_token_ids"]
@property
def token_type_ids(self) -> list[int]:
if self.inputs["type"] == "embeds":
return []
return self.inputs.get("token_type_ids", [])
@property @property
def multi_modal_data(self) -> MultiModalKwargs: def multi_modal_data(self) -> MultiModalKwargs:
if self.inputs["type"] == "multimodal": if self.inputs["type"] == "multimodal":
...@@ -765,10 +759,6 @@ class SequenceGroup: ...@@ -765,10 +759,6 @@ class SequenceGroup:
return (self.encoder_seq.prompt_token_ids return (self.encoder_seq.prompt_token_ids
if self.encoder_seq is not None else None) if self.encoder_seq is not None else None)
@property
def token_type_ids(self) -> Optional[list[int]]:
return self.first_seq.token_type_ids
@property @property
def multi_modal_data(self) -> MultiModalKwargs: def multi_modal_data(self) -> MultiModalKwargs:
if self.first_seq.multi_modal_data: if self.first_seq.multi_modal_data:
...@@ -972,7 +962,6 @@ class SequenceGroupMetadata( ...@@ -972,7 +962,6 @@ class SequenceGroupMetadata(
computed_block_nums: Optional[list[int]] = None computed_block_nums: Optional[list[int]] = None
state: Optional[SequenceGroupState] = msgspec.field( state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState()) default_factory=lambda: SequenceGroupState())
token_type_ids: Optional[list[int]] = None
multi_modal_data: Optional[MultiModalKwargs] = None multi_modal_data: Optional[MultiModalKwargs] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
encoder_seq_data: Optional[SequenceData] = None encoder_seq_data: Optional[SequenceData] = None
......
...@@ -24,8 +24,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, ...@@ -24,8 +24,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry) MultiModalRegistry)
from vllm.platforms import _Backend from vllm.platforms import _Backend
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (GPUModelRunnerBase, from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder, ModelInputForGPUBuilder,
...@@ -161,7 +160,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -161,7 +160,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[PoolerOutput]]: ) -> Optional[List[SamplerOutput]]:
if num_steps > 1: if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in " raise ValueError("num_steps > 1 is not supported in "
"EncoderDecoderModelRunner") "EncoderDecoderModelRunner")
......
...@@ -86,7 +86,6 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -86,7 +86,6 @@ class ModelInputForGPU(ModelRunnerInputBase):
input_tokens: Optional[torch.Tensor] = None input_tokens: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None inputs_embeds: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None lora_mapping: Optional["LoRAMapping"] = None
...@@ -192,7 +191,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -192,7 +191,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens[0].clear() # type: ignore self.input_tokens[0].clear() # type: ignore
self.inputs_embeds = None # type: ignore self.inputs_embeds = None # type: ignore
self.input_positions[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore
...@@ -219,7 +217,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -219,7 +217,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
input_tokens: Optional[List[List[int]]] = None, input_tokens: Optional[List[List[int]]] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
input_positions: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None,
# The sequence length (may be capped to the sliding window). # The sequence length (may be capped to the sliding window).
...@@ -284,12 +281,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -284,12 +281,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)): for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear() self.input_positions[seq_id].clear()
if token_types:
self.token_types = token_types
else:
for seq_id in range(len(self.seq_ids)):
self.token_types[seq_id].clear()
self.mrope_input_positions = None self.mrope_input_positions = None
if seq_lens: if seq_lens:
...@@ -348,7 +339,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -348,7 +339,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = input_tokens or [] self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds self.inputs_embeds = inputs_embeds
self.input_positions = input_positions or [] self.input_positions = input_positions or []
self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or [] self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or [] self.orig_seq_lens = orig_seq_lens or []
...@@ -376,7 +366,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -376,7 +366,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)]
self.token_types = [[] for _ in range(self.n_seqs)]
self.mrope_input_positions = None self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs
...@@ -400,7 +389,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -400,7 +389,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
f"inputs_embeds.shape=" f"inputs_embeds.shape="
f"{getattr(self.inputs_embeds, 'shape', None)}, " f"{getattr(self.inputs_embeds, 'shape', None)}, "
f"input_positions={self.input_positions}, " f"input_positions={self.input_positions}, "
f"token_types={self.token_types}, "
f"mrope_input_positions={self.mrope_input_positions}, " f"mrope_input_positions={self.mrope_input_positions}, "
f"seq_lens={self.seq_lens}, " f"seq_lens={self.seq_lens}, "
f"orig_seq_lens={self.orig_seq_lens}, " f"orig_seq_lens={self.orig_seq_lens}, "
...@@ -522,8 +510,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -522,8 +510,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_embeds = seq_data.get_token_embeddings( prompt_embeds = seq_data.get_token_embeddings(
)[context_len:seq_len] )[context_len:seq_len]
token_types = seq_group_metadata.token_type_ids
inter_data.seq_lens[seq_idx] = seq_len inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len() inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len()
...@@ -531,8 +517,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -531,8 +517,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds inter_data.inputs_embeds = prompt_embeds
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.token_types[seq_idx].extend(
token_types if token_types else [])
inter_data.query_lens[seq_idx] = seq_len - context_len inter_data.query_lens[seq_idx] = seq_len - context_len
if seq_data.mrope_position_delta is not None: if seq_data.mrope_position_delta is not None:
...@@ -590,8 +574,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -590,8 +574,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
seq_idx][uncomputed_start:] seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[ inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][uncomputed_start:] seq_idx][uncomputed_start:]
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
uncomputed_start:]
context_len = prefix_cache_len context_len = prefix_cache_len
inter_data.context_lens[seq_idx] = context_len inter_data.context_lens[seq_idx] = context_len
...@@ -606,8 +588,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -606,8 +588,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
seq_idx][-1:] seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[ inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:] seq_idx][-1:]
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
-1:]
inter_data.query_lens[seq_idx] = 1 inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
...@@ -802,12 +782,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -802,12 +782,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Combine and flatten intermediate data. # Combine and flatten intermediate data.
input_tokens = list[int]() input_tokens = list[int]()
inputs_embeds_list = list[torch.Tensor]() inputs_embeds_list = list[torch.Tensor]()
token_types = list[int]()
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens: for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens) input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None: if inter_data.inputs_embeds is not None:
inputs_embeds_list.append( inputs_embeds_list.append(
inter_data.inputs_embeds.to( inter_data.inputs_embeds.to(
...@@ -890,11 +867,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -890,11 +867,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
token_types_tensor = async_tensor_h2d(token_types, torch.long,
self.runner.device,
self.runner.pin_memory) \
if token_types else None
if mrope_input_positions is not None: if mrope_input_positions is not None:
for idx in range(3): for idx in range(3):
mrope_input_positions[idx].extend( mrope_input_positions[idx].extend(
...@@ -951,7 +923,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -951,7 +923,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
input_positions=input_positions_tensor, input_positions=input_positions_tensor,
token_types=token_types_tensor,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
seq_lens=seq_lens, seq_lens=seq_lens,
query_lens=query_lens, query_lens=query_lens,
......
...@@ -13,10 +13,9 @@ from vllm.config import VllmConfig ...@@ -13,10 +13,9 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.interfaces import supports_transcription from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import ( from vllm.model_executor.models.interfaces_base import is_text_generation_model
is_pooling_model, is_text_generation_model)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.tasks import GenerationTask, SupportedTask
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
...@@ -241,20 +240,11 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -241,20 +240,11 @@ class ModelRunnerBase(ABC, Generic[T]):
return supported_tasks return supported_tasks
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
return []
return list(model.pooler.get_supported_tasks())
def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks = list[SupportedTask]() tasks = list[SupportedTask]()
if self.model_config.runner_type == "generate": if self.model_config.runner_type == "generate":
tasks.extend(self.get_supported_generation_tasks()) tasks.extend(self.get_supported_generation_tasks())
if self.model_config.runner_type == "pooling":
tasks.extend(self.get_supported_pooling_tasks())
return tuple(tasks) return tuple(tasks)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
import torch
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
ModelInputForGPUBuilder)
logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
"""
Used by the PoolingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class PoolingModelRunner(
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
ModelInputForGPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
):
super().__init__(vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"PoolingModelRunner does not support multi-step execution.")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
virtual_engine = model_input.virtual_engine
# Pooling models are (ab-)used also to integrate non text models that
# are not autoregressive (PrithviGeosaptialMAE).
# These model might not use attention and do not really have a prefill
# and decode phase. The model input is processed in one shot and both
# decode_metadata and prefill_metadata would be None for such models.
# See the PlaceholderAttentionMetadata class.
# TODO: Figure out if cuda_graph is of any use for these models and
# explore how to leverage it.
if (prefill_meta is None and decode_meta is not None
and decode_meta.use_cuda_graph):
if model_input.inputs_embeds is None:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
else:
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start = torch.cuda.Event(enable_timing=True)
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
cross_enc_kwargs = {}
if model_input.token_types is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_types
with set_forward_context(model_input.attn_metadata, self.vllm_config,
virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
device=self.device,
),
**cross_enc_kwargs,
**seqlen_agnostic_kwargs,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
# Only perform pooling in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
and hidden_or_intermediate_states is not None
and isinstance(hidden_or_intermediate_states,
IntermediateTensors)
and self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
hidden_or_intermediate_states.tensors["model_forward_time"] = (
torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states
# Only perform pooling in the driver worker.
if not self.is_driver_worker:
return []
pooling_metadata = model_input.pooling_metadata
assert pooling_metadata is not None
pooling_metadata.build_pooling_cursor(
num_scheduled_tokens=pooling_metadata.prompt_lens,
device=hidden_or_intermediate_states.device)
return [
self.model.pooler(hidden_states=hidden_or_intermediate_states,
pooling_metadata=pooling_metadata)
]
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForGPUWithPoolingMetadata:
return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
model_input.seq_lens)
return dataclasses.replace(model_input,
pooling_metadata=pooling_metadata)
def _prepare_pooling(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> PoolingMetadata:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
assert pooling_params is not None
task = pooling_params.task
assert task is not None, "You did not set `task` in the API"
model = cast(VllmModelForPooling, self.model)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
pooling_metadata = PoolingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
)
return pooling_metadata
...@@ -30,7 +30,6 @@ from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache, ...@@ -30,7 +30,6 @@ from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput) WorkerInput)
...@@ -83,9 +82,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -83,9 +82,7 @@ class Worker(LocalOrDistributedWorkerBase):
else {"return_hidden_states": True} else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_config.runner_type == "pooling": if self.model_config.is_encoder_decoder:
ModelRunnerClass = PoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass( self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
...@@ -99,7 +96,6 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -99,7 +96,6 @@ class Worker(LocalOrDistributedWorkerBase):
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
self.cache_engine: List[CacheEngine] self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as pooling models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment