Unverified commit d2f058e7, authored by Cyrus Leung and committed by GitHub
Browse files

[Misc] Rename embedding classes to pooling (#10801)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f877a7d1
import asyncio import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
class AsyncStream: class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request """A stream of RequestOutputs or PoolingRequestOutputs for a request
that can be iterated over asynchronously via an async generator.""" that can be iterated over asynchronously via an async generator."""
STOP_ITERATION = Exception() # Sentinel STOP_ITERATION = Exception() # Sentinel
...@@ -16,7 +16,7 @@ class AsyncStream: ...@@ -16,7 +16,7 @@ class AsyncStream:
self._queue: asyncio.Queue = asyncio.Queue() self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False self._finished = False
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, def put(self, item: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None: Exception]) -> None:
if not self._finished: if not self._finished:
self._queue.put_nowait(item) self._queue.put_nowait(item)
...@@ -32,7 +32,7 @@ class AsyncStream: ...@@ -32,7 +32,7 @@ class AsyncStream:
async def generator( async def generator(
self self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
finished = False finished = False
try: try:
while True: while True:
......
...@@ -16,12 +16,12 @@ from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU, ...@@ -16,12 +16,12 @@ from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU): class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
""" """
Used by the CPUEmbeddingModelRunner. Used by the CPUPoolingModelRunner.
""" """
pooling_metadata: Optional["PoolingMetadata"] = None pooling_metadata: Optional["PoolingMetadata"] = None
class CPUEmbeddingModelRunner( class CPUPoolingModelRunner(
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]): CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = ( _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
ModelInputForCPUWithPoolingMetadata) ModelInputForCPUWithPoolingMetadata)
......
...@@ -14,9 +14,9 @@ from vllm.logger import init_logger ...@@ -14,9 +14,9 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_embedding_model_runner import CPUEmbeddingModelRunner
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerBase, LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput) WorkerInput)
...@@ -164,7 +164,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -164,7 +164,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
else {"return_hidden_states": True} else {"return_hidden_states": True}
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
if self.model_config.task == "embedding": if self.model_config.task == "embedding":
ModelRunnerClass = CPUEmbeddingModelRunner ModelRunnerClass = CPUPoolingModelRunner
elif self.model_config.is_encoder_decoder: elif self.model_config.is_encoder_decoder:
ModelRunnerClass = CPUEncoderDecoderModelRunner ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunnerBase = ModelRunnerClass( self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
......
...@@ -21,12 +21,12 @@ logger = init_logger(__name__) ...@@ -21,12 +21,12 @@ logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
""" """
Used by the EmbeddingModelRunner. Used by the PoolingModelRunner.
""" """
pooling_metadata: Optional["PoolingMetadata"] = None pooling_metadata: Optional["PoolingMetadata"] = None
class EmbeddingModelRunner( class PoolingModelRunner(
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
ModelInputForGPUWithPoolingMetadata) ModelInputForGPUWithPoolingMetadata)
...@@ -52,7 +52,7 @@ class EmbeddingModelRunner( ...@@ -52,7 +52,7 @@ class EmbeddingModelRunner(
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
if num_steps > 1: if num_steps > 1:
raise ValueError( raise ValueError(
"EmbeddingModelRunner does not support multi-step execution.") "PoolingModelRunner does not support multi-step execution.")
if self.lora_config: if self.lora_config:
assert model_input.lora_requests is not None assert model_input.lora_requests is not None
......
...@@ -22,9 +22,9 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -22,9 +22,9 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta) SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput) WorkerInput)
...@@ -75,7 +75,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -75,7 +75,7 @@ class Worker(LocalOrDistributedWorkerBase):
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_config.task == "embedding": if model_config.task == "embedding":
ModelRunnerClass = EmbeddingModelRunner ModelRunnerClass = PoolingModelRunner
elif self.model_config.is_encoder_decoder: elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass( self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment