Unverified commit 7eb4a51c, authored by Cyrus Leung, committed by GitHub
Browse files

[Core] Support serving encoder/decoder models (#7258)

parent 0fa14907
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable) Union, overload, runtime_checkable)
from typing_extensions import TypeGuard from typing_extensions import TypeIs
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -37,18 +37,18 @@ class _SupportsVisionType(Protocol): ...@@ -37,18 +37,18 @@ class _SupportsVisionType(Protocol):
@overload @overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]: def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]:
... ...
@overload @overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]: def supports_vision(model: object) -> TypeIs[SupportsVision]:
... ...
def supports_vision( def supports_vision(
model: Union[Type[object], object], model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]: ) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsVisionType) return isinstance(model, _SupportsVisionType)
...@@ -94,18 +94,18 @@ class _SupportsLoRAType(Protocol): ...@@ -94,18 +94,18 @@ class _SupportsLoRAType(Protocol):
@overload @overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]: def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
... ...
@overload @overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]: def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
... ...
def supports_lora( def supports_lora(
model: Union[Type[object], object], model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: ) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
result = _supports_lora(model) result = _supports_lora(model)
if not result: if not result:
...@@ -137,7 +137,7 @@ def supports_lora( ...@@ -137,7 +137,7 @@ def supports_lora(
def _supports_lora( def _supports_lora(
model: Union[Type[object], object], model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: ) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsLoRAType) return isinstance(model, _SupportsLoRAType)
...@@ -172,18 +172,18 @@ class _HasInnerStateType(Protocol): ...@@ -172,18 +172,18 @@ class _HasInnerStateType(Protocol):
@overload @overload
def has_inner_state(model: object) -> TypeGuard[HasInnerState]: def has_inner_state(model: object) -> TypeIs[HasInnerState]:
... ...
@overload @overload
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]: def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
... ...
def has_inner_state( def has_inner_state(
model: Union[Type[object], object] model: Union[Type[object], object]
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]: ) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _HasInnerStateType) return isinstance(model, _HasInnerStateType)
......
...@@ -10,6 +10,7 @@ from vllm.inputs.registry import InputContext ...@@ -10,6 +10,7 @@ from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of
from .base import MultiModalInputs, MultiModalPlugin from .base import MultiModalInputs, MultiModalPlugin
...@@ -113,7 +114,8 @@ class ImagePlugin(MultiModalPlugin): ...@@ -113,7 +114,8 @@ class ImagePlugin(MultiModalPlugin):
def _default_input_mapper(self, ctx: InputContext, def _default_input_mapper(self, ctx: InputContext,
data: object) -> MultiModalInputs: data: object) -> MultiModalInputs:
model_config = ctx.model_config model_config = ctx.model_config
if isinstance(data, (Image.Image, list)):
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config) image_processor = self._get_hf_image_processor(model_config)
if image_processor is None: if image_processor is None:
raise RuntimeError("No HuggingFace processor is available " raise RuntimeError("No HuggingFace processor is available "
...@@ -127,7 +129,7 @@ class ImagePlugin(MultiModalPlugin): ...@@ -127,7 +129,7 @@ class ImagePlugin(MultiModalPlugin):
raise raise
return MultiModalInputs(batch_data) return MultiModalInputs(batch_data)
elif isinstance(data, torch.Tensor): elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet") raise NotImplementedError("Embeddings input is not supported yet")
raise TypeError(f"Invalid image type: {type(data)}") raise TypeError(f"Invalid image type: {type(data)}")
......
...@@ -11,7 +11,7 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, ...@@ -11,7 +11,7 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
import torch import torch
from vllm.inputs import is_valid_encoder_decoder_llm_inputs from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
......
...@@ -17,8 +17,8 @@ from collections import defaultdict ...@@ -17,8 +17,8 @@ from collections import defaultdict
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
from platform import uname from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
Union, overload) Type, TypeVar, Union, overload)
from uuid import uuid4 from uuid import uuid4
import numpy as np import numpy as np
...@@ -26,12 +26,10 @@ import numpy.typing as npt ...@@ -26,12 +26,10 @@ import numpy.typing as npt
import psutil import psutil
import torch import torch
import torch.types import torch.types
from typing_extensions import ParamSpec from typing_extensions import ParamSpec, TypeIs, assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
SingletonPromptInputs)
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -812,6 +810,24 @@ def get_dtype_size(dtype: torch.dtype) -> int: ...@@ -812,6 +810,24 @@ def get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size() return torch.tensor([], dtype=dtype).element_size()
# `collections` helpers
def is_list_of(
    value: object,
    typ: Type[T],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]:
    """Return whether ``value`` is a ``list`` whose items are ``typ`` instances.

    Args:
        value: Object to inspect; anything that is not a ``list`` yields
            ``False``.
        typ: Type handed to ``isinstance`` for the element check.
        check: ``"first"`` inspects only ``value[0]`` (an empty list is
            accepted); ``"all"`` inspects every element.

    Returns:
        ``True`` when ``value`` is a list that passes the selected check,
        otherwise ``False``.
    """
    if not isinstance(value, list):
        return False

    if check == "first":
        # Only the head element is examined, so a mixed list whose first
        # item matches still passes in this mode.
        return len(value) == 0 or isinstance(value[0], typ)
    elif check == "all":
        return all(isinstance(v, typ) for v in value)

    # Not reachable for the two declared literal values of `check`.
    assert_never(check)
def merge_dicts(dict1: Dict[K, List[T]], def merge_dicts(dict1: Dict[K, List[T]],
dict2: Dict[K, List[T]]) -> Dict[K, List[T]]: dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
"""Merge 2 dicts that have key -> List of items. """Merge 2 dicts that have key -> List of items.
...@@ -959,6 +975,7 @@ def enable_trace_function_call_for_thread() -> None: ...@@ -959,6 +975,7 @@ def enable_trace_function_call_for_thread() -> None:
enable_trace_function_call(log_path) enable_trace_function_call(log_path)
# `functools` helpers
def identity(value: T) -> T: def identity(value: T) -> T:
return value return value
...@@ -1080,50 +1097,3 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, ...@@ -1080,50 +1097,3 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
"""Utility function to run async task in a lock""" """Utility function to run async task in a lock"""
async with lock: async with lock:
return await task(*args, **kwargs) return await task(*args, **kwargs)
def is_encoder_decoder_model_config(model_config) -> bool:
    """Extract the HF encoder/decoder model flag from the ModelConfig instance.

    Reads ``is_encoder_decoder`` from ``model_config.hf_config``, treating a
    missing attribute as ``False``.

    Args:
        model_config: Object exposing an ``hf_config`` attribute, or ``None``.

    Returns:
        ``False`` if ``model_config`` is ``None``; otherwise the truthiness of
        the flag, always as a real ``bool``.
    """
    if model_config is None:
        return False

    # Coerce so that a non-bool flag value (e.g. ``None``) does not leak out
    # of a function annotated as returning ``bool``.
    return bool(
        getattr(model_config.hf_config, "is_encoder_decoder", False))
def is_embedding_model_config(model_config) -> bool:
    """Extract the embedding model flag from the ModelConfig instance.

    Args:
        model_config: Object exposing an ``embedding_mode`` attribute, or
            ``None``.

    Returns:
        ``False`` if ``model_config`` is ``None``; otherwise the truthiness of
        ``model_config.embedding_mode``, always as a real ``bool``.
    """
    if model_config is None:
        return False

    # Coerce so the declared ``bool`` return type holds even when the
    # attribute stores some other truthy/falsy value.
    return bool(model_config.embedding_mode)
def build_explicit_enc_dec_prompt(
    encoder_prompt: SingletonPromptInputs,
    decoder_prompt: SingletonPromptInputs,
) -> ExplicitEncoderDecoderPrompt:
    """Bundle an encoder prompt and a decoder prompt into one explicit prompt.

    Args:
        encoder_prompt: Stored under the ``encoder_prompt`` field.
        decoder_prompt: Stored under the ``decoder_prompt`` field.

    Returns:
        An ``ExplicitEncoderDecoderPrompt`` built from the two arguments.
    """
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
    )
def zip_enc_dec_prompt_lists(
    enc_prompt_list: List[SingletonPromptInputs],
    dec_prompt_list: List[SingletonPromptInputs],
) -> List[ExplicitEncoderDecoderPrompt]:
    """Pair encoder and decoder prompts position by position.

    Args:
        enc_prompt_list: Encoder prompts.
        dec_prompt_list: Decoder prompts, matched to the encoder prompts by
            index.

    Returns:
        One explicit encoder/decoder prompt per pair. Pairing uses ``zip``,
        so when the lists differ in length the extra items of the longer
        list are dropped without error.
    """
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
        for encoder_prompt, decoder_prompt in zip(enc_prompt_list,
                                                  dec_prompt_list)
    ]
def to_enc_dec_tuple_list(
    enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
) -> List[Tuple[PromptInputs, PromptInputs]]:
    """Unpack explicit encoder/decoder prompts into ``(encoder, decoder)`` tuples.

    Args:
        enc_dec_prompts: Mappings that each provide the keys
            ``'encoder_prompt'`` and ``'decoder_prompt'``.

    Returns:
        A list with one ``(encoder_prompt, decoder_prompt)`` tuple per input
        item, in the same order.

    Raises:
        KeyError: If an item lacks one of the two keys.
    """
    return [
        (enc_dec_prompt['encoder_prompt'], enc_dec_prompt['decoder_prompt'])
        for enc_dec_prompt in enc_dec_prompts
    ]
...@@ -19,8 +19,6 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig ...@@ -19,8 +19,6 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import (is_embedding_model_config,
is_encoder_decoder_model_config)
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
...@@ -113,10 +111,10 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -113,10 +111,10 @@ class Worker(LocalOrDistributedWorkerBase):
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
def _is_encoder_decoder_model(self): def _is_encoder_decoder_model(self):
return is_encoder_decoder_model_config(self.model_config) return self.model_config.is_encoder_decoder_model
def _is_embedding_model(self): def _is_embedding_model(self):
return is_embedding_model_config(self.model_config) return self.model_config.is_embedding_model
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment