Unverified Commit d2f058e7 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Rename embedding classes to pooling (#10801)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f877a7d1
import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
"""A stream of RequestOutputs or PoolingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
STOP_ITERATION = Exception() # Sentinel
......@@ -16,7 +16,7 @@ class AsyncStream:
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
def put(self, item: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
if not self._finished:
self._queue.put_nowait(item)
......@@ -32,7 +32,7 @@ class AsyncStream:
async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
finished = False
try:
while True:
......
......@@ -16,12 +16,12 @@ from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
@dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
"""
Used by the CPUEmbeddingModelRunner.
Used by the CPUPoolingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class CPUEmbeddingModelRunner(
class CPUPoolingModelRunner(
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
ModelInputForCPUWithPoolingMetadata)
......
......@@ -14,9 +14,9 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_embedding_model_runner import CPUEmbeddingModelRunner
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)
......@@ -164,7 +164,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
else {"return_hidden_states": True}
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
if self.model_config.task == "embedding":
ModelRunnerClass = CPUEmbeddingModelRunner
ModelRunnerClass = CPUPoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
......
......@@ -21,12 +21,12 @@ logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
"""
Used by the EmbeddingModelRunner.
Used by the PoolingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class EmbeddingModelRunner(
class PoolingModelRunner(
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
ModelInputForGPUWithPoolingMetadata)
......@@ -52,7 +52,7 @@ class EmbeddingModelRunner(
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"EmbeddingModelRunner does not support multi-step execution.")
"PoolingModelRunner does not support multi-step execution.")
if self.lora_config:
assert model_input.lora_requests is not None
......
......@@ -22,9 +22,9 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
......@@ -75,7 +75,7 @@ class Worker(LocalOrDistributedWorkerBase):
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_config.task == "embedding":
ModelRunnerClass = EmbeddingModelRunner
ModelRunnerClass = PoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment