Unverified commit d2f058e7, authored by Cyrus Leung, committed by GitHub
Browse files

[Misc] Rename embedding classes to pooling (#10801)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f877a7d1
...@@ -10,7 +10,7 @@ prompts = [ ...@@ -10,7 +10,7 @@ prompts = [
# Create an LLM. # Create an LLM.
model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs. # Generate embedding. The output is a list of PoolingRequestOutputs.
outputs = model.encode(prompts) outputs = model.encode(prompts)
# Print the outputs. # Print the outputs.
for output in outputs: for output in outputs:
......
...@@ -3,7 +3,7 @@ from typing import List ...@@ -3,7 +3,7 @@ from typing import List
import pytest import pytest
from vllm import LLM, EmbeddingRequestOutput, PoolingParams from vllm import LLM, PoolingParams, PoolingRequestOutput
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
MODEL_NAME = "intfloat/e5-mistral-7b-instruct" MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
...@@ -43,8 +43,8 @@ def llm(): ...@@ -43,8 +43,8 @@ def llm():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: List[EmbeddingRequestOutput], def assert_outputs_equal(o1: List[PoolingRequestOutput],
o2: List[EmbeddingRequestOutput]): o2: List[PoolingRequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2] assert [o.outputs for o in o1] == [o.outputs for o in o2]
......
...@@ -3,7 +3,7 @@ import warnings ...@@ -3,7 +3,7 @@ import warnings
import pytest import pytest
import torch.cuda import torch.cuda
from vllm.model_executor.models import (is_embedding_model, from vllm.model_executor.models import (is_pooling_model,
is_text_generation_model, is_text_generation_model,
supports_multimodal) supports_multimodal)
from vllm.model_executor.models.adapters import as_embedding_model from vllm.model_executor.models.adapters import as_embedding_model
...@@ -31,7 +31,7 @@ def test_registry_imports(model_arch): ...@@ -31,7 +31,7 @@ def test_registry_imports(model_arch):
# All vLLM models should be convertible to an embedding model # All vLLM models should be convertible to an embedding model
embed_model = as_embedding_model(model_cls) embed_model = as_embedding_model(model_cls)
assert is_embedding_model(embed_model) assert is_pooling_model(embed_model)
if model_arch in _MULTIMODAL_MODELS: if model_arch in _MULTIMODAL_MODELS:
assert supports_multimodal(model_cls) assert supports_multimodal(model_cls)
......
...@@ -8,10 +8,10 @@ from vllm.attention.backends.abstract import AttentionBackend ...@@ -8,10 +8,10 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.multi_step_model_runner import StatefulModelInput from vllm.worker.multi_step_model_runner import StatefulModelInput
from vllm.worker.pooling_model_runner import (
ModelInputForGPUWithPoolingMetadata)
class MockAttentionBackend(AttentionBackend): class MockAttentionBackend(AttentionBackend):
......
...@@ -7,8 +7,8 @@ from vllm.entrypoints.llm import LLM ...@@ -7,8 +7,8 @@ from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput, from vllm.outputs import (CompletionOutput, PoolingOutput,
EmbeddingRequestOutput, RequestOutput) PoolingRequestOutput, RequestOutput)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -25,8 +25,8 @@ __all__ = [ ...@@ -25,8 +25,8 @@ __all__ = [
"SamplingParams", "SamplingParams",
"RequestOutput", "RequestOutput",
"CompletionOutput", "CompletionOutput",
"EmbeddingOutput", "PoolingOutput",
"EmbeddingRequestOutput", "PoolingRequestOutput",
"LLMEngine", "LLMEngine",
"EngineArgs", "EngineArgs",
"AsyncLLMEngine", "AsyncLLMEngine",
...@@ -34,3 +34,26 @@ __all__ = [ ...@@ -34,3 +34,26 @@ __all__ = [
"initialize_ray_cluster", "initialize_ray_cluster",
"PoolingParams", "PoolingParams",
] ]
def __getattr__(name: str):
    """Resolve the deprecated ``Embedding*`` names to their ``Pooling*`` successors.

    Looks like the module-level attribute hook (PEP 562): it is only
    consulted for names that normal lookup did not find.

    Args:
        name: The attribute name being looked up.

    Returns:
        ``PoolingOutput`` for ``"EmbeddingOutput"`` and
        ``PoolingRequestOutput`` for ``"EmbeddingRequestOutput"``.

    Raises:
        AttributeError: For any other name.
    """
    import warnings

    # Anything that is not one of the two legacy aliases is a genuine miss.
    if name not in ("EmbeddingOutput", "EmbeddingRequestOutput"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    if name == "EmbeddingOutput":
        replacement = PoolingOutput
        notice = ("EmbeddingOutput has been renamed to PoolingOutput. "
                  "The original name will be removed in an upcoming version.")
    else:
        replacement = PoolingRequestOutput
        notice = ("EmbeddingRequestOutput has been renamed to "
                  "PoolingRequestOutput. "
                  "The original name will be removed in an upcoming version.")

    # stacklevel=2 attributes the warning to the code that accessed the
    # deprecated name rather than to this hook.
    warnings.warn(DeprecationWarning(notice), stacklevel=2)
    return replacement
...@@ -359,7 +359,7 @@ class ModelConfig: ...@@ -359,7 +359,7 @@ class ModelConfig:
# NOTE: Listed from highest to lowest priority, # NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them # in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures), "generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_embedding_model(architectures), "embedding": ModelRegistry.is_pooling_model(architectures),
} }
supported_tasks_lst: List[_Task] = [ supported_tasks_lst: List[_Task] = [
task for task, is_supported in task_support.items() if is_supported task for task, is_supported in task_support.items() if is_supported
......
...@@ -25,7 +25,7 @@ from vllm.lora.request import LoRARequest ...@@ -25,7 +25,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor) get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -74,7 +74,7 @@ STOP_ITERATION = Exception() # Sentinel ...@@ -74,7 +74,7 @@ STOP_ITERATION = Exception() # Sentinel
class AsyncStream: class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request """A stream of RequestOutputs or PoolingRequestOutputs for a request
that can be iterated over asynchronously via an async generator.""" that can be iterated over asynchronously via an async generator."""
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
...@@ -83,7 +83,7 @@ class AsyncStream: ...@@ -83,7 +83,7 @@ class AsyncStream:
self._queue: asyncio.Queue = asyncio.Queue() self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False self._finished = False
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, def put(self, item: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None: Exception]) -> None:
if not self._finished: if not self._finished:
self._queue.put_nowait(item) self._queue.put_nowait(item)
...@@ -103,7 +103,7 @@ class AsyncStream: ...@@ -103,7 +103,7 @@ class AsyncStream:
async def generator( async def generator(
self self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
try: try:
while True: while True:
result = await self._queue.get() result = await self._queue.get()
...@@ -154,7 +154,7 @@ class RequestTracker: ...@@ -154,7 +154,7 @@ class RequestTracker:
def process_request_output(self, def process_request_output(self,
request_output: Union[RequestOutput, request_output: Union[RequestOutput,
EmbeddingRequestOutput], PoolingRequestOutput],
*, *,
verbose: bool = False) -> None: verbose: bool = False) -> None:
"""Process a request output from the engine.""" """Process a request output from the engine."""
...@@ -265,7 +265,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -265,7 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
async def step_async( async def step_async(
self, virtual_engine: int self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible. The workers are ran asynchronously if possible.
...@@ -907,7 +907,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -907,7 +907,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[ ) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]: RequestOutput, PoolingRequestOutput], None]]:
... ...
@overload @overload
...@@ -922,7 +922,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -922,7 +922,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[ ) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]: RequestOutput, PoolingRequestOutput], None]]:
... ...
@deprecate_kwargs( @deprecate_kwargs(
...@@ -941,7 +941,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -941,7 +941,7 @@ class AsyncLLMEngine(EngineClient):
priority: int = 0, priority: int = 0,
*, *,
inputs: Optional[PromptType] = None, # DEPRECATED inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
if inputs is not None: if inputs is not None:
prompt = inputs prompt = inputs
assert prompt is not None and params is not None assert prompt is not None and params is not None
...@@ -1070,7 +1070,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -1070,7 +1070,7 @@ class AsyncLLMEngine(EngineClient):
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the Generate outputs for a request. This method is a coroutine. It adds the
...@@ -1088,7 +1088,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -1088,7 +1088,7 @@ class AsyncLLMEngine(EngineClient):
Only applicable with priority scheduling. Only applicable with priority scheduling.
Yields: Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine The output `PoolingRequestOutput` objects from the LLMEngine
for the request. for the request.
Details: Details:
...@@ -1141,7 +1141,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -1141,7 +1141,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
): ):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput) yield LLMEngine.validate_output(output, PoolingRequestOutput)
async def abort(self, request_id: str) -> None: async def abort(self, request_id: str) -> None:
"""Abort a request. """Abort a request.
......
...@@ -40,7 +40,7 @@ from vllm.model_executor.guided_decoding import ( ...@@ -40,7 +40,7 @@ from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor) get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: ...@@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
@dataclass @dataclass
...@@ -112,7 +112,7 @@ class SchedulerContext: ...@@ -112,7 +112,7 @@ class SchedulerContext:
def __init__(self, multi_step_stream_outputs: bool = False): def __init__(self, multi_step_stream_outputs: bool = False):
self.output_queue: Deque[OutputData] = deque() self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput, self.request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = [] PoolingRequestOutput]] = []
self.seq_group_metadata_list: Optional[ self.seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None
...@@ -1314,7 +1314,7 @@ class LLMEngine: ...@@ -1314,7 +1314,7 @@ class LLMEngine:
else: else:
seq.append_token_id(sample.output_token, sample.logprobs) seq.append_token_id(sample.output_token, sample.logprobs)
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png .. figure:: https://i.imgur.com/sv2HssD.png
......
...@@ -35,7 +35,7 @@ from vllm.inputs.preprocess import InputPreprocessor ...@@ -35,7 +35,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
...@@ -495,7 +495,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -495,7 +495,7 @@ class MQLLMEngineClient(EngineClient):
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
... ...
@overload @overload
...@@ -507,7 +507,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -507,7 +507,7 @@ class MQLLMEngineClient(EngineClient):
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
... ...
@deprecate_kwargs( @deprecate_kwargs(
...@@ -524,7 +524,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -524,7 +524,7 @@ class MQLLMEngineClient(EngineClient):
priority: int = 0, priority: int = 0,
*, *,
inputs: Optional[PromptType] = None # DEPRECATED inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the Generate outputs for a request. This method is a coroutine. It adds the
...@@ -540,7 +540,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -540,7 +540,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
Yields: Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine The output `PoolingRequestOutput` objects from the LLMEngine
for the request. for the request.
""" """
if inputs is not None: if inputs is not None:
...@@ -549,7 +549,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -549,7 +549,7 @@ class MQLLMEngineClient(EngineClient):
and request_id is not None) and request_id is not None)
return cast( return cast(
AsyncGenerator[EmbeddingRequestOutput, None], AsyncGenerator[PoolingRequestOutput, None],
self._process_request(prompt, self._process_request(prompt,
pooling_params, pooling_params,
request_id, request_id,
...@@ -567,7 +567,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -567,7 +567,7 @@ class MQLLMEngineClient(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]: PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses.""" """Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out. # If already dead, error out.
......
...@@ -11,8 +11,7 @@ from vllm.inputs.preprocess import InputPreprocessor ...@@ -11,8 +11,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
RequestOutput)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
...@@ -209,7 +208,7 @@ class EngineClient(ABC): ...@@ -209,7 +208,7 @@ class EngineClient(ABC):
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.""" """Generate outputs for a request from an embedding model."""
... ...
......
...@@ -26,7 +26,7 @@ from vllm.logger import init_logger ...@@ -26,7 +26,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import ( from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions) GuidedDecodingRequest, LLMGuidedOptions)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
...@@ -679,7 +679,7 @@ class LLM: ...@@ -679,7 +679,7 @@ class LLM:
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[PoolingRequestOutput]:
... ...
@overload # LEGACY: multi (prompt + optional token ids) @overload # LEGACY: multi (prompt + optional token ids)
...@@ -691,7 +691,7 @@ class LLM: ...@@ -691,7 +691,7 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[PoolingRequestOutput]:
... ...
@overload # LEGACY: single (token ids + optional prompt) @overload # LEGACY: single (token ids + optional prompt)
...@@ -704,7 +704,7 @@ class LLM: ...@@ -704,7 +704,7 @@ class LLM:
prompt_token_ids: List[int], prompt_token_ids: List[int],
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[PoolingRequestOutput]:
... ...
@overload # LEGACY: multi (token ids + optional prompt) @overload # LEGACY: multi (token ids + optional prompt)
...@@ -717,7 +717,7 @@ class LLM: ...@@ -717,7 +717,7 @@ class LLM:
prompt_token_ids: List[List[int]], prompt_token_ids: List[List[int]],
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[PoolingRequestOutput]:
... ...
@overload # LEGACY: single or multi token ids [pos-only] @overload # LEGACY: single or multi token ids [pos-only]
...@@ -728,7 +728,7 @@ class LLM: ...@@ -728,7 +728,7 @@ class LLM:
prompt_token_ids: Union[List[int], List[List[int]]], prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[PoolingRequestOutput]:
... ...
@overload @overload
...@@ -741,7 +741,7 @@ class LLM: ...@@ -741,7 +741,7 @@ class LLM:
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[PoolingRequestOutput]:
... ...
@deprecate_kwargs( @deprecate_kwargs(
...@@ -759,7 +759,7 @@ class LLM: ...@@ -759,7 +759,7 @@ class LLM:
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[PoolingRequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
This class automatically batches the given prompts, considering This class automatically batches the given prompts, considering
...@@ -778,7 +778,7 @@ class LLM: ...@@ -778,7 +778,7 @@ class LLM:
generation, if any. generation, if any.
Returns: Returns:
A list of ``EmbeddingRequestOutput`` objects containing the A list of ``PoolingRequestOutput`` objects containing the
generated embeddings in the same order as the input prompts. generated embeddings in the same order as the input prompts.
Note: Note:
...@@ -821,7 +821,7 @@ class LLM: ...@@ -821,7 +821,7 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput) PoolingRequestOutput)
def score( def score(
self, self,
...@@ -832,7 +832,7 @@ class LLM: ...@@ -832,7 +832,7 @@ class LLM:
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[PoolingRequestOutput]:
"""Generates similarity scores for all pairs <text,text_pair>. """Generates similarity scores for all pairs <text,text_pair>.
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
...@@ -854,7 +854,7 @@ class LLM: ...@@ -854,7 +854,7 @@ class LLM:
generation, if any. generation, if any.
Returns: Returns:
A list of ``EmbeddingRequestOutput`` objects containing the A list of ``PoolingRequestOutput`` objects containing the
generated scores in the same order as the input prompts. generated scores in the same order as the input prompts.
""" """
task = self.llm_engine.model_config.task task = self.llm_engine.model_config.task
...@@ -943,7 +943,7 @@ class LLM: ...@@ -943,7 +943,7 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput) PoolingRequestOutput)
def start_profile(self) -> None: def start_profile(self) -> None:
self.llm_engine.start_profile() self.llm_engine.start_profile()
...@@ -1085,7 +1085,7 @@ class LLM: ...@@ -1085,7 +1085,7 @@ class LLM:
def _run_engine( def _run_engine(
self, *, use_tqdm: bool self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
# Initialize tqdm. # Initialize tqdm.
if use_tqdm: if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests() num_requests = self.llm_engine.get_num_unfinished_requests()
...@@ -1098,7 +1098,7 @@ class LLM: ...@@ -1098,7 +1098,7 @@ class LLM:
) )
# Run the engine. # Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
total_in_toks = 0 total_in_toks = 0
total_out_toks = 0 total_out_toks = 0
while self.llm_engine.has_unfinished_requests(): while self.llm_engine.has_unfinished_requests():
......
...@@ -18,14 +18,14 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, ...@@ -18,14 +18,14 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
ErrorResponse, UsageInfo) ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
def _get_embedding( def _get_embedding(
output: EmbeddingOutput, output: PoolingOutput,
encoding_format: Literal["float", "base64"], encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]: ) -> Union[List[float], str]:
if encoding_format == "float": if encoding_format == "float":
...@@ -40,7 +40,7 @@ def _get_embedding( ...@@ -40,7 +40,7 @@ def _get_embedding(
def request_output_to_embedding_response( def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str, final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str, created_time: int, model_name: str,
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse: encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = [] data: List[EmbeddingResponseData] = []
...@@ -169,7 +169,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -169,7 +169,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e)) return self.create_error_response(str(e))
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
try: try:
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params()
...@@ -207,7 +207,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -207,7 +207,7 @@ class OpenAIServingEmbedding(OpenAIServing):
num_prompts = len(engine_prompts) num_prompts = len(engine_prompts)
# Non-streaming response # Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]] final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts final_res_batch = [None] * num_prompts
try: try:
async for i, res in result_generator: async for i, res in result_generator:
...@@ -215,7 +215,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -215,7 +215,7 @@ class OpenAIServingEmbedding(OpenAIServing):
assert all(final_res is not None for final_res in final_res_batch) assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[EmbeddingRequestOutput], final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch) final_res_batch)
response = request_output_to_embedding_response( response = request_output_to_embedding_response(
......
...@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest, ...@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators, random_uuid from vllm.utils import make_async, merge_async_iterators, random_uuid
...@@ -21,7 +21,7 @@ logger = init_logger(__name__) ...@@ -21,7 +21,7 @@ logger = init_logger(__name__)
def request_output_to_score_response( def request_output_to_score_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str, final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str) -> ScoreResponse: created_time: int, model_name: str) -> ScoreResponse:
data: List[ScoreResponseData] = [] data: List[ScoreResponseData] = []
score = None score = None
...@@ -133,7 +133,7 @@ class OpenAIServingScores(OpenAIServing): ...@@ -133,7 +133,7 @@ class OpenAIServingScores(OpenAIServing):
return self.create_error_response(str(e)) return self.create_error_response(str(e))
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
input_pairs = make_pairs(request.text_1, request.text_2) input_pairs = make_pairs(request.text_1, request.text_2)
...@@ -194,7 +194,7 @@ class OpenAIServingScores(OpenAIServing): ...@@ -194,7 +194,7 @@ class OpenAIServingScores(OpenAIServing):
num_prompts = len(engine_prompts) num_prompts = len(engine_prompts)
# Non-streaming response # Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]] final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts final_res_batch = [None] * num_prompts
try: try:
...@@ -203,7 +203,7 @@ class OpenAIServingScores(OpenAIServing): ...@@ -203,7 +203,7 @@ class OpenAIServingScores(OpenAIServing):
assert all(final_res is not None for final_res in final_res_batch) assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[EmbeddingRequestOutput], final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch) final_res_batch)
response = request_output_to_score_response( response = request_output_to_score_response(
......
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
SupportsPP, has_inner_state, supports_lora, SupportsPP, has_inner_state, supports_lora,
supports_multimodal, supports_pp) supports_multimodal, supports_pp)
from .interfaces_base import (VllmModelForEmbedding, from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
VllmModelForTextGeneration, is_embedding_model, is_pooling_model, is_text_generation_model)
is_text_generation_model)
from .registry import ModelRegistry from .registry import ModelRegistry
__all__ = [ __all__ = [
"ModelRegistry", "ModelRegistry",
"VllmModelForEmbedding", "VllmModelForPooling",
"is_embedding_model", "is_pooling_model",
"VllmModelForTextGeneration", "VllmModelForTextGeneration",
"is_text_generation_model", "is_text_generation_model",
"HasInnerState", "HasInnerState",
...@@ -20,4 +19,4 @@ __all__ = [ ...@@ -20,4 +19,4 @@ __all__ = [
"supports_multimodal", "supports_multimodal",
"SupportsPP", "SupportsPP",
"supports_pp", "supports_pp",
] ]
\ No newline at end of file
...@@ -4,7 +4,7 @@ from typing import Any, TypeVar ...@@ -4,7 +4,7 @@ from typing import Any, TypeVar
import torch import torch
import torch.nn as nn import torch.nn as nn
from .interfaces_base import VllmModelForEmbedding, is_embedding_model from .interfaces_base import VllmModelForPooling, is_pooling_model
_T = TypeVar("_T", bound=type[nn.Module]) _T = TypeVar("_T", bound=type[nn.Module])
...@@ -12,7 +12,7 @@ _T = TypeVar("_T", bound=type[nn.Module]) ...@@ -12,7 +12,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
def as_embedding_model(cls: _T) -> _T: def as_embedding_model(cls: _T) -> _T:
"""Subclass an existing vLLM model to support embeddings.""" """Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models # Avoid modifying existing embedding models
if is_embedding_model(cls): if is_pooling_model(cls):
return cls return cls
# Lazy import # Lazy import
...@@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T: ...@@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T:
from .utils import AutoWeightsLoader, WeightsMapper from .utils import AutoWeightsLoader, WeightsMapper
class ModelForEmbedding(cls, VllmModelForEmbedding): class ModelForEmbedding(cls, VllmModelForPooling):
def __init__( def __init__(
self, self,
......
...@@ -7,7 +7,7 @@ from typing_extensions import TypeIs, TypeVar ...@@ -7,7 +7,7 @@ from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import supports_kw from vllm.utils import supports_kw
from .interfaces_base import is_embedding_model from .interfaces_base import is_pooling_model
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
...@@ -389,4 +389,4 @@ def _supports_cross_encoding( ...@@ -389,4 +389,4 @@ def _supports_cross_encoding(
def supports_cross_encoding( def supports_cross_encoding(
model: Union[Type[object], object], model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
return is_embedding_model(model) and _supports_cross_encoding(model) return is_pooling_model(model) and _supports_cross_encoding(model)
...@@ -141,7 +141,7 @@ def is_text_generation_model( ...@@ -141,7 +141,7 @@ def is_text_generation_model(
@runtime_checkable @runtime_checkable
class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]): class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]):
def pooler( def pooler(
self, self,
...@@ -153,23 +153,22 @@ class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]): ...@@ -153,23 +153,22 @@ class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
@overload @overload
def is_embedding_model( def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]:
model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
... ...
@overload @overload
def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]: def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
... ...
def is_embedding_model( def is_pooling_model(
model: Union[Type[object], object], model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]: ) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
if not is_vllm_model(model): if not is_vllm_model(model):
return False return False
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, VllmModelForEmbedding) return isinstance(model, VllmModelForPooling)
return isinstance(model, VllmModelForEmbedding) return isinstance(model, VllmModelForPooling)
...@@ -24,7 +24,7 @@ from .adapters import as_embedding_model ...@@ -24,7 +24,7 @@ from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free, from .interfaces import (has_inner_state, is_attention_free,
supports_cross_encoding, supports_multimodal, supports_cross_encoding, supports_multimodal,
supports_pp) supports_pp)
from .interfaces_base import is_embedding_model, is_text_generation_model from .interfaces_base import is_pooling_model, is_text_generation_model
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -211,7 +211,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { ...@@ -211,7 +211,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
class _ModelInfo: class _ModelInfo:
architecture: str architecture: str
is_text_generation_model: bool is_text_generation_model: bool
is_embedding_model: bool is_pooling_model: bool
supports_cross_encoding: bool supports_cross_encoding: bool
supports_multimodal: bool supports_multimodal: bool
supports_pp: bool supports_pp: bool
...@@ -220,19 +220,19 @@ class _ModelInfo: ...@@ -220,19 +220,19 @@ class _ModelInfo:
@staticmethod @staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
is_embedding_model_ = is_embedding_model(model) is_pooling_model_ = is_pooling_model(model)
if not is_embedding_model_: if not is_pooling_model_:
try: try:
as_embedding_model(model) as_embedding_model(model)
except Exception: except Exception:
pass pass
else: else:
is_embedding_model_ = True is_pooling_model_ = True
return _ModelInfo( return _ModelInfo(
architecture=model.__name__, architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model), is_text_generation_model=is_text_generation_model(model),
is_embedding_model=is_embedding_model_, is_pooling_model=is_pooling_model_,
supports_cross_encoding=supports_cross_encoding(model), supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model), supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model), supports_pp=supports_pp(model),
...@@ -441,12 +441,12 @@ class _ModelRegistry: ...@@ -441,12 +441,12 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_text_generation_model return model_cls.is_text_generation_model
def is_embedding_model( def is_pooling_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, List[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_embedding_model return model_cls.is_pooling_model
def is_cross_encoder_model( def is_cross_encoder_model(
self, self,
......
...@@ -53,8 +53,8 @@ class CompletionOutput: ...@@ -53,8 +53,8 @@ class CompletionOutput:
@dataclass @dataclass
class EmbeddingOutput: class PoolingOutput:
"""The output data of one completion output of a request. """The output data of one pooling output of a request.
Args: Args:
embedding: The embedding vector, which is a list of floats. The embedding: The embedding vector, which is a list of floats. The
...@@ -63,7 +63,7 @@ class EmbeddingOutput: ...@@ -63,7 +63,7 @@ class EmbeddingOutput:
embedding: List[float] embedding: List[float]
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"EmbeddingOutput(" return (f"PoolingOutput("
f"embedding={len(self.embedding)})") f"embedding={len(self.embedding)})")
...@@ -316,18 +316,18 @@ class RequestOutput: ...@@ -316,18 +316,18 @@ class RequestOutput:
f"multi_modal_placeholders={self.multi_modal_placeholders})") f"multi_modal_placeholders={self.multi_modal_placeholders})")
class EmbeddingRequestOutput: class PoolingRequestOutput:
""" """
The output data of an embedding request to the LLM. The output data of a pooling request to the LLM.
Args: Args:
request_id (str): A unique identifier for the embedding request. request_id (str): A unique identifier for the pooling request.
outputs (EmbeddingOutput): The embedding results for the given input. outputs (PoolingOutput): The pooling results for the given input.
prompt_token_ids (List[int]): A list of token IDs used in the prompt. prompt_token_ids (List[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the embedding is completed. finished (bool): A flag indicating whether the pooling is completed.
""" """
def __init__(self, request_id: str, outputs: "EmbeddingOutput", def __init__(self, request_id: str, outputs: "PoolingOutput",
prompt_token_ids: List[int], finished: bool): prompt_token_ids: List[int], finished: bool):
self.request_id = request_id self.request_id = request_id
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
...@@ -336,11 +336,11 @@ class EmbeddingRequestOutput: ...@@ -336,11 +336,11 @@ class EmbeddingRequestOutput:
@classmethod @classmethod
def from_seq_group(cls, def from_seq_group(cls,
seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput": seq_group: 'SequenceGroup') -> "PoolingRequestOutput":
if seq_group.embeddings is None: if seq_group.embeddings is None:
raise ValueError( raise ValueError(
"Embeddings are missing in seq_group for EmbeddingRequest.") "Embeddings are missing in seq_group for EmbeddingRequest.")
output = EmbeddingOutput(seq_group.embeddings) output = PoolingOutput(seq_group.embeddings)
prompt_token_ids = seq_group.prompt_token_ids prompt_token_ids = seq_group.prompt_token_ids
finished = seq_group.is_finished() finished = seq_group.is_finished()
...@@ -348,15 +348,15 @@ class EmbeddingRequestOutput: ...@@ -348,15 +348,15 @@ class EmbeddingRequestOutput:
def __repr__(self): def __repr__(self):
""" """
Returns a string representation of an EmbeddingRequestOutput instance. Returns a string representation of a PoolingRequestOutput instance.
The representation includes the request_id and the number of outputs, The representation includes the request_id and the number of outputs,
providing a quick overview of the embedding request's results. providing a quick overview of the pooling request's results.
Returns: Returns:
str: A string representation of the EmbeddingRequestOutput instance. str: A string representation of the PoolingRequestOutput instance.
""" """
return (f"EmbeddingRequestOutput(request_id='{self.request_id}', " return (f"PoolingRequestOutput(request_id='{self.request_id}', "
f"outputs={repr(self.outputs)}, " f"outputs={repr(self.outputs)}, "
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
f"finished={self.finished})") f"finished={self.finished})")
...@@ -415,7 +415,30 @@ class RequestOutputFactory: ...@@ -415,7 +415,30 @@ class RequestOutputFactory:
# Determine the type based on a condition, for example: # Determine the type based on a condition, for example:
if hasattr(seq_group, if hasattr(seq_group,
'embeddings') and seq_group.embeddings is not None: 'embeddings') and seq_group.embeddings is not None:
return EmbeddingRequestOutput.from_seq_group(seq_group) return PoolingRequestOutput.from_seq_group(seq_group)
else: else:
return RequestOutput.from_seq_group(seq_group, use_cache, return RequestOutput.from_seq_group(seq_group, use_cache,
seq_id_to_seq_group) seq_id_to_seq_group)
def __getattr__(name: str):
    """Resolve deprecated module attributes to their renamed classes.

    Handles the two legacy names ``EmbeddingOutput`` and
    ``EmbeddingRequestOutput``: a ``DeprecationWarning`` is emitted and the
    corresponding ``PoolingOutput`` / ``PoolingRequestOutput`` class is
    returned.

    Args:
        name: The attribute name being looked up on this module.

    Returns:
        The class that replaces the deprecated name.

    Raises:
        AttributeError: If ``name`` is not one of the deprecated aliases.
    """
    import warnings

    # Anything other than the two legacy aliases is a genuine miss.
    if name not in ("EmbeddingOutput", "EmbeddingRequestOutput"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    if name == "EmbeddingOutput":
        notice = ("EmbeddingOutput has been renamed to PoolingOutput. "
                  "The original name will be removed in an upcoming version.")
        replacement = PoolingOutput
    else:
        notice = ("EmbeddingRequestOutput has been renamed to "
                  "PoolingRequestOutput. "
                  "The original name will be removed in an upcoming version.")
        replacement = PoolingRequestOutput

    # stacklevel=2 attributes the warning to the code that accessed the
    # deprecated name rather than to this hook.
    warnings.warn(DeprecationWarning(notice), stacklevel=2)
    return replacement
...@@ -9,7 +9,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType ...@@ -9,7 +9,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -133,7 +133,7 @@ class AsyncLLM(EngineClient): ...@@ -133,7 +133,7 @@ class AsyncLLM(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
"""Add new request to the AsyncLLM.""" """Add new request to the AsyncLLM."""
if self.detokenizer.is_request_active(request_id): if self.detokenizer.is_request_active(request_id):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment