Unverified Commit bbf55c48 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[VLM] Refactor `MultiModalConfig` initialization and profiling (#7530)

parent 1ef13cf9
...@@ -86,8 +86,12 @@ def server_function(port): ...@@ -86,8 +86,12 @@ def server_function(port):
ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel) ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
with patch("vllm.entrypoints.chat_utils._mm_token_str", with patch(
lambda *_, **__: "_"): "vllm.entrypoints.chat_utils._mm_token_str",
lambda *_, **__: "_"), patch(
"vllm.model_executor.models.ModelRegistry.is_multimodal_model"
) as mock:
mock.return_value = True
sys.argv = ["placeholder.py"] + \ sys.argv = ["placeholder.py"] + \
(f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 " (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
"--dtype bfloat16 --enforce-eager --api-key token-abc123 " "--dtype bfloat16 --enforce-eager --api-key token-abc123 "
......
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
import pytest import pytest
from transformers import CLIPImageProcessor, LlavaNextImageProcessor from transformers import CLIPImageProcessor, LlavaNextImageProcessor
from vllm.config import ModelConfig, MultiModalConfig from vllm.config import ModelConfig
from vllm.multimodal import MultiModalRegistry from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
...@@ -30,10 +30,10 @@ def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor): ...@@ -30,10 +30,10 @@ def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
seed=0, seed=0,
dtype=dtype, dtype=dtype,
revision=None, revision=None,
limit_mm_per_prompt={"image": 1},
) )
mm_config = MultiModalConfig(limit_per_prompt={"image": 1})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config) mm_registry.init_mm_limits_per_prompt(model_config)
for asset in image_assets: for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor) image = rescale_image_size(asset.pil_image, size_factor)
...@@ -73,10 +73,10 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype, ...@@ -73,10 +73,10 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype,
seed=0, seed=0,
dtype=dtype, dtype=dtype,
revision=None, revision=None,
limit_mm_per_prompt={"image": 1},
) )
mm_config = MultiModalConfig(limit_per_prompt={"image": 1})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config) mm_registry.init_mm_limits_per_prompt(model_config)
for asset in image_assets: for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor) image = rescale_image_size(asset.pil_image, size_factor)
...@@ -115,10 +115,10 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): ...@@ -115,10 +115,10 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
seed=0, seed=0,
dtype="half", dtype="half",
revision=None, revision=None,
limit_mm_per_prompt={"image": limit},
) )
mm_config = MultiModalConfig(limit_per_prompt={"image": limit})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config) mm_registry.init_mm_limits_per_prompt(model_config)
image = image_assets[0].pil_image image = image_assets[0].pil_image
if num_images == 0: if num_images == 0:
...@@ -145,10 +145,10 @@ def test_image_mapper_multi(image_assets, mm_registry, num_images): ...@@ -145,10 +145,10 @@ def test_image_mapper_multi(image_assets, mm_registry, num_images):
seed=0, seed=0,
dtype="half", dtype="half",
revision=None, revision=None,
limit_mm_per_prompt={"image": num_images},
) )
mm_config = MultiModalConfig(limit_per_prompt={"image": num_images})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config) mm_registry.init_mm_limits_per_prompt(model_config)
image = image_assets[0].pil_image image = image_assets[0].pil_image
mm_inputs = {"image": [image] * num_images} mm_inputs = {"image": [image] * num_images}
......
...@@ -109,6 +109,8 @@ class ModelConfig: ...@@ -109,6 +109,8 @@ class ModelConfig:
matches the model name exposed via the APIs. If multiple model matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified, names provided, the first name will be used. If not specified,
the model name will be the same as `model`. the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data instances per modality
per prompt. Only applicable for multimodal models.
""" """
def __init__( def __init__(
...@@ -134,7 +136,7 @@ class ModelConfig: ...@@ -134,7 +136,7 @@ class ModelConfig:
disable_sliding_window: bool = False, disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None, served_model_name: Optional[Union[str, List[str]]] = None,
multimodal_config: Optional["MultiModalConfig"] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
) -> None: ) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -211,14 +213,29 @@ class ModelConfig: ...@@ -211,14 +213,29 @@ class ModelConfig:
sliding_window_len=self.get_hf_config_sliding_window()) sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model, self.served_model_name = get_served_model_name(model,
served_model_name) served_model_name)
self.multimodal_config = multimodal_config self.multimodal_config = self._init_multimodal_config(
limit_mm_per_prompt)
if not self.skip_tokenizer_init: if not self.skip_tokenizer_init:
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
self._verify_embedding_mode() self._verify_embedding_mode()
self._verify_quantization() self._verify_quantization()
self._verify_cuda_graph() self._verify_cuda_graph()
def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
) -> Optional["MultiModalConfig"]:
architectures = getattr(self.hf_config, "architectures", [])
if any(
ModelRegistry.is_multimodal_model(arch)
for arch in architectures):
return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
else:
if limit_mm_per_prompt:
raise ValueError(
"limit_mm_per_prompt is only supported for multimodal "
"models.")
return None
def _verify_tokenizer_mode(self) -> None: def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower() tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]: if tokenizer_mode not in ["auto", "slow"]:
...@@ -467,6 +484,18 @@ class ModelConfig: ...@@ -467,6 +484,18 @@ class ModelConfig:
if t != "attention" if t != "attention"
]) ])
def get_multimodal_config(self) -> "MultiModalConfig":
"""
Get the multimodal configuration of the model.
Raises:
ValueError: If the model is not multimodal.
"""
if self.multimodal_config is None:
raise ValueError("The model is not multimodal.")
return self.multimodal_config
@property @property
def is_encoder_decoder_model(self) -> bool: def is_encoder_decoder_model(self) -> bool:
"""Extract the HF encoder/decoder model flag.""" """Extract the HF encoder/decoder model flag."""
...@@ -1450,7 +1479,7 @@ class PromptAdapterConfig: ...@@ -1450,7 +1479,7 @@ class PromptAdapterConfig:
class MultiModalConfig: class MultiModalConfig:
"""Controls the behavior of multimodal models.""" """Controls the behavior of multimodal models."""
limit_per_prompt: Mapping[str, int] limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
""" """
The maximum number of multi-modal input instances allowed per prompt The maximum number of multi-modal input instances allowed per prompt
for each :class:`~vllm.multimodal.MultiModalPlugin`. for each :class:`~vllm.multimodal.MultiModalPlugin`.
...@@ -1710,7 +1739,6 @@ class EngineConfig: ...@@ -1710,7 +1739,6 @@ class EngineConfig:
device_config: DeviceConfig device_config: DeviceConfig
load_config: LoadConfig load_config: LoadConfig
lora_config: Optional[LoRAConfig] lora_config: Optional[LoRAConfig]
multimodal_config: Optional[MultiModalConfig]
speculative_config: Optional[SpeculativeConfig] speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig] decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig] observability_config: Optional[ObservabilityConfig]
......
...@@ -7,7 +7,7 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, ...@@ -7,7 +7,7 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig) SpeculativeConfig, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
...@@ -765,9 +765,6 @@ class EngineArgs: ...@@ -765,9 +765,6 @@ class EngineArgs:
"CPU offload space must be non-negative" "CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}") f", but got {self.cpu_offload_gb}")
multimodal_config = MultiModalConfig(
limit_per_prompt=self.limit_mm_per_prompt or {})
device_config = DeviceConfig(device=self.device) device_config = DeviceConfig(device=self.device)
model_config = ModelConfig( model_config = ModelConfig(
model=self.model, model=self.model,
...@@ -791,7 +788,8 @@ class EngineArgs: ...@@ -791,7 +788,8 @@ class EngineArgs:
disable_sliding_window=self.disable_sliding_window, disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init, skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
multimodal_config=multimodal_config) limit_mm_per_prompt=self.limit_mm_per_prompt,
)
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size, block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization, gpu_memory_utilization=self.gpu_memory_utilization,
...@@ -970,7 +968,6 @@ class EngineArgs: ...@@ -970,7 +968,6 @@ class EngineArgs:
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
device_config=device_config, device_config=device_config,
lora_config=lora_config, lora_config=lora_config,
multimodal_config=multimodal_config,
speculative_config=speculative_config, speculative_config=speculative_config,
load_config=load_config, load_config=load_config,
decoding_config=decoding_config, decoding_config=decoding_config,
......
...@@ -10,7 +10,7 @@ from typing_extensions import assert_never ...@@ -10,7 +10,7 @@ from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
...@@ -100,8 +100,6 @@ class LLMEngine: ...@@ -100,8 +100,6 @@ class LLMEngine:
scheduler_config: The configuration related to the request scheduler. scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device. device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA. lora_config (Optional): The configuration related to serving multi-LoRA.
multimodal_config (Optional): The configuration related to multimodal
models.
speculative_config (Optional): The configuration related to speculative speculative_config (Optional): The configuration related to speculative
decoding. decoding.
executor_class: The model executor class for managing distributed executor_class: The model executor class for managing distributed
...@@ -172,7 +170,6 @@ class LLMEngine: ...@@ -172,7 +170,6 @@ class LLMEngine:
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig], decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig], observability_config: Optional[ObservabilityConfig],
...@@ -235,7 +232,6 @@ class LLMEngine: ...@@ -235,7 +232,6 @@ class LLMEngine:
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
...@@ -278,7 +274,6 @@ class LLMEngine: ...@@ -278,7 +274,6 @@ class LLMEngine:
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
device_config=device_config, device_config=device_config,
lora_config=lora_config, lora_config=lora_config,
multimodal_config=multimodal_config,
speculative_config=speculative_config, speculative_config=speculative_config,
load_config=load_config, load_config=load_config,
prompt_adapter_config=prompt_adapter_config, prompt_adapter_config=prompt_adapter_config,
......
...@@ -141,7 +141,6 @@ class CPUExecutor(ExecutorBase): ...@@ -141,7 +141,6 @@ class CPUExecutor(ExecutorBase):
rank=rank, rank=rank,
distributed_init_method=self.distributed_init_method, distributed_init_method=self.distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config, prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=rank == 0, is_driver_worker=rank == 0,
......
...@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod ...@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -29,7 +29,6 @@ class ExecutorBase(ABC): ...@@ -29,7 +29,6 @@ class ExecutorBase(ABC):
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig], prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig], observability_config: Optional[ObservabilityConfig],
...@@ -41,7 +40,6 @@ class ExecutorBase(ABC): ...@@ -41,7 +40,6 @@ class ExecutorBase(ABC):
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.multimodal_config = multimodal_config
self.speculative_config = speculative_config self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config self.observability_config = observability_config
......
...@@ -55,7 +55,6 @@ class GPUExecutor(ExecutorBase): ...@@ -55,7 +55,6 @@ class GPUExecutor(ExecutorBase):
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
speculative_config=self.speculative_config, speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config, prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config) is_driver_worker=(not self.parallel_config)
......
...@@ -49,7 +49,6 @@ class OpenVINOExecutor(ExecutorBase): ...@@ -49,7 +49,6 @@ class OpenVINOExecutor(ExecutorBase):
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True, is_driver_worker=True,
) )
......
...@@ -7,9 +7,8 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, ...@@ -7,9 +7,8 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, ParallelConfig, PromptAdapterConfig,
PromptAdapterConfig, SchedulerConfig, SchedulerConfig, SpeculativeConfig)
SpeculativeConfig)
from vllm.executor.distributed_gpu_executor import ( # yapf: disable from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync) DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.ray_utils import RayWorkerWrapper, ray
...@@ -46,7 +45,6 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -46,7 +45,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
prompt_adapter_config: Optional[PromptAdapterConfig], prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
) -> None: ) -> None:
...@@ -61,7 +59,6 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -61,7 +59,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config self.prompt_adapter_config = prompt_adapter_config
placement_group = self.parallel_config.placement_group placement_group = self.parallel_config.placement_group
...@@ -203,7 +200,6 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -203,7 +200,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
is_driver_worker=rank == 0, is_driver_worker=rank == 0,
)) ))
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
......
...@@ -52,7 +52,6 @@ class TPUExecutor(ExecutorBase): ...@@ -52,7 +52,6 @@ class TPUExecutor(ExecutorBase):
local_rank=local_rank, local_rank=local_rank,
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
multimodal_config=self.multimodal_config,
is_driver_worker=rank == 0, is_driver_worker=rank == 0,
) )
......
...@@ -3,9 +3,8 @@ from typing import List, Optional ...@@ -3,9 +3,8 @@ from typing import List, Optional
import torch import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, ParallelConfig, PromptAdapterConfig,
PromptAdapterConfig, SchedulerConfig, SchedulerConfig, SpeculativeConfig)
SpeculativeConfig)
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -29,7 +28,6 @@ class XPUExecutor(GPUExecutor): ...@@ -29,7 +28,6 @@ class XPUExecutor(GPUExecutor):
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
prompt_adapter_config: Optional[PromptAdapterConfig], prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
) -> None: ) -> None:
...@@ -46,7 +44,6 @@ class XPUExecutor(GPUExecutor): ...@@ -46,7 +44,6 @@ class XPUExecutor(GPUExecutor):
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config self.prompt_adapter_config = prompt_adapter_config
self.speculative_config = None self.speculative_config = None
......
...@@ -13,7 +13,7 @@ from vllm.logger import init_logger ...@@ -13,7 +13,7 @@ from vllm.logger import init_logger
from .data import LLMInputs from .data import LLMInputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, MultiModalConfig from vllm.config import ModelConfig
from vllm.multimodal import MultiModalDataDict, MultiModalRegistry from vllm.multimodal import MultiModalDataDict, MultiModalRegistry
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
...@@ -32,20 +32,6 @@ class InputContext: ...@@ -32,20 +32,6 @@ class InputContext:
model_config: "ModelConfig" model_config: "ModelConfig"
"""The configuration of the model.""" """The configuration of the model."""
def get_multimodal_config(self) -> "MultiModalConfig":
"""
Get the multimodal configuration of the model.
Raises:
ValueError: If the model is not multimodal.
"""
multimodal_config = self.model_config.multimodal_config
if multimodal_config is None:
raise ValueError("No multimodal config found")
return multimodal_config
def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C: def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
""" """
Get the HuggingFace configuration Get the HuggingFace configuration
......
...@@ -3,8 +3,7 @@ from typing import Optional ...@@ -3,8 +3,7 @@ from typing import Optional
from torch import nn from torch import nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, ParallelConfig, SchedulerConfig)
SchedulerConfig)
from vllm.model_executor.model_loader.loader import (BaseModelLoader, from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader) get_model_loader)
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
...@@ -15,13 +14,11 @@ def get_model(*, model_config: ModelConfig, load_config: LoadConfig, ...@@ -15,13 +14,11 @@ def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig, device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config) loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config, return loader.load_model(model_config=model_config,
device_config=device_config, device_config=device_config,
lora_config=lora_config, lora_config=lora_config,
multimodal_config=multimodal_config,
parallel_config=parallel_config, parallel_config=parallel_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
cache_config=cache_config) cache_config=cache_config)
......
...@@ -132,9 +132,7 @@ def _get_model_initialization_kwargs( ...@@ -132,9 +132,7 @@ def _get_model_initialization_kwargs(
"please open an issue on github.") "please open an issue on github.")
if supports_multimodal(model_class): if supports_multimodal(model_class):
if multimodal_config is None: assert multimodal_config is not None
raise ValueError("Provide multi-modal related configurations "
"through LLM entrypoint or engine arguments.")
extra_kwargs["multimodal_config"] = multimodal_config extra_kwargs["multimodal_config"] = multimodal_config
...@@ -164,7 +162,6 @@ def _initialize_model( ...@@ -164,7 +162,6 @@ def _initialize_model(
model_config: ModelConfig, model_config: ModelConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig, cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module: scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
...@@ -173,10 +170,10 @@ def _initialize_model( ...@@ -173,10 +170,10 @@ def _initialize_model(
return build_model( return build_model(
model_class, model_class,
model_config.hf_config, model_config.hf_config,
cache_config=cache_config,
quant_config=_get_quantization_config(model_config, load_config), quant_config=_get_quantization_config(model_config, load_config),
lora_config=lora_config, lora_config=lora_config,
multimodal_config=multimodal_config, multimodal_config=model_config.multimodal_config,
cache_config=cache_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
) )
...@@ -191,7 +188,6 @@ class BaseModelLoader(ABC): ...@@ -191,7 +188,6 @@ class BaseModelLoader(ABC):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
...@@ -336,7 +332,6 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -336,7 +332,6 @@ class DefaultModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
...@@ -344,8 +339,8 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -344,8 +339,8 @@ class DefaultModelLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config, lora_config, cache_config,
cache_config, scheduler_config) scheduler_config)
model.load_weights( model.load_weights(
self._get_weights_iterator(model_config.model, self._get_weights_iterator(model_config.model,
model_config.revision, model_config.revision,
...@@ -379,15 +374,14 @@ class DummyModelLoader(BaseModelLoader): ...@@ -379,15 +374,14 @@ class DummyModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config, lora_config, cache_config,
cache_config, scheduler_config) scheduler_config)
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
initialize_dummy_weights(model) initialize_dummy_weights(model)
...@@ -420,7 +414,6 @@ class TensorizerLoader(BaseModelLoader): ...@@ -420,7 +414,6 @@ class TensorizerLoader(BaseModelLoader):
model_config: ModelConfig, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig, cache_config: CacheConfig,
) -> nn.Module: ) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU. """Load a serialized model with tensorizer to the CPU.
...@@ -433,8 +426,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -433,8 +426,7 @@ class TensorizerLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config, lora_config, cache_config)
cache_config)
model.load_weights(self._get_weights_iterator()) model.load_weights(self._get_weights_iterator())
return model.eval() return model.eval()
...@@ -444,7 +436,6 @@ class TensorizerLoader(BaseModelLoader): ...@@ -444,7 +436,6 @@ class TensorizerLoader(BaseModelLoader):
model_config: ModelConfig, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig, cache_config: CacheConfig,
) -> nn.Module: ) -> nn.Module:
"""Load a serialized model with tensorizer. """Load a serialized model with tensorizer.
...@@ -458,7 +449,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -458,7 +449,7 @@ class TensorizerLoader(BaseModelLoader):
quant_config = _get_quantization_config( quant_config = _get_quantization_config(
model_config, self.load_config) model_config, self.load_config)
extra_kwargs = _get_model_initialization_kwargs( extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, multimodal_config) model_class, lora_config, model_config.multimodal_config)
extra_kwargs["quant_config"] = quant_config extra_kwargs["quant_config"] = quant_config
extra_kwargs["cache_config"] = cache_config extra_kwargs["cache_config"] = cache_config
...@@ -473,7 +464,6 @@ class TensorizerLoader(BaseModelLoader): ...@@ -473,7 +464,6 @@ class TensorizerLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
...@@ -487,11 +477,9 @@ class TensorizerLoader(BaseModelLoader): ...@@ -487,11 +477,9 @@ class TensorizerLoader(BaseModelLoader):
if is_vllm_tensorized(self.tensorizer_config): if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config, return self._load_model_serialized(model_config, device_config,
lora_config, multimodal_config, lora_config, cache_config)
cache_config)
return self._load_model_serialized_cpu(model_config, device_config, return self._load_model_serialized_cpu(model_config, device_config,
lora_config, multimodal_config, lora_config, cache_config)
cache_config)
@staticmethod @staticmethod
def save_model( def save_model(
...@@ -577,7 +565,6 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -577,7 +565,6 @@ class ShardedStateLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
...@@ -591,8 +578,7 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -591,8 +578,7 @@ class ShardedStateLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config, lora_config, cache_config)
cache_config)
rank = get_tensor_model_parallel_rank() rank = get_tensor_model_parallel_rank()
pattern = os.path.join( pattern = os.path.join(
local_model_path, local_model_path,
...@@ -955,15 +941,13 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -955,15 +941,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config, lora_config, cache_config)
cache_config)
self._load_weights(model_config, model) self._load_weights(model_config, model)
...@@ -1032,7 +1016,6 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -1032,7 +1016,6 @@ class GGUFModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
...@@ -1047,8 +1030,7 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -1047,8 +1030,7 @@ class GGUFModelLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config, lora_config, cache_config)
cache_config)
model.load_weights( model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map)) self._get_weights_iterator(local_model_path, gguf_weights_map))
return model return model
......
...@@ -9,17 +9,12 @@ from vllm.utils import is_hip ...@@ -9,17 +9,12 @@ from vllm.utils import is_hip
logger = init_logger(__name__) logger = init_logger(__name__)
# Architecture -> (module, class).
_GENERATION_MODELS = { _GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"CohereForCausalLM": ("commandr", "CohereForCausalLM"), "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
...@@ -28,7 +23,6 @@ _GENERATION_MODELS = { ...@@ -28,7 +23,6 @@ _GENERATION_MODELS = {
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
...@@ -37,13 +31,8 @@ _GENERATION_MODELS = { ...@@ -37,13 +31,8 @@ _GENERATION_MODELS = {
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration":
("llava_next", "LlavaNextForConditionalGeneration"),
# For decapoda-research/llama-* # For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
...@@ -53,17 +42,13 @@ _GENERATION_MODELS = { ...@@ -53,17 +42,13 @@ _GENERATION_MODELS = {
"MptForCausalLM": ("mpt", "MPTForCausalLM"), "MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
...@@ -83,6 +68,22 @@ _EMBEDDING_MODELS = { ...@@ -83,6 +68,22 @@ _EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
} }
_MULTIMODAL_MODELS = {
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration":
("llava_next", "LlavaNextForConditionalGeneration"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
}
_CONDITIONAL_GENERATION_MODELS = { _CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"), "BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
...@@ -91,7 +92,8 @@ _CONDITIONAL_GENERATION_MODELS = { ...@@ -91,7 +92,8 @@ _CONDITIONAL_GENERATION_MODELS = {
_MODELS = { _MODELS = {
**_GENERATION_MODELS, **_GENERATION_MODELS,
**_EMBEDDING_MODELS, **_EMBEDDING_MODELS,
**_CONDITIONAL_GENERATION_MODELS **_MULTIMODAL_MODELS,
**_CONDITIONAL_GENERATION_MODELS,
} }
# Architecture -> type. # Architecture -> type.
...@@ -182,6 +184,15 @@ class ModelRegistry: ...@@ -182,6 +184,15 @@ class ModelRegistry:
def is_embedding_model(model_arch: str) -> bool: def is_embedding_model(model_arch: str) -> bool:
return model_arch in _EMBEDDING_MODELS return model_arch in _EMBEDDING_MODELS
@staticmethod
def is_multimodal_model(model_arch: str) -> bool:
# TODO: find a way to avoid initializing CUDA prematurely to
# use `supports_multimodal` to determine if a model is multimodal
# model_cls = ModelRegistry._try_load_model_cls(model_arch)
# from vllm.model_executor.models.interfaces import supports_multimodal
return model_arch in _MULTIMODAL_MODELS
__all__ = [ __all__ = [
"ModelRegistry", "ModelRegistry",
......
...@@ -2,7 +2,7 @@ import functools ...@@ -2,7 +2,7 @@ import functools
from collections import UserDict from collections import UserDict
from typing import Dict, Mapping, Optional, Sequence from typing import Dict, Mapping, Optional, Sequence
from vllm.config import ModelConfig, MultiModalConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from .audio import AudioPlugin from .audio import AudioPlugin
...@@ -181,7 +181,6 @@ class MultiModalRegistry: ...@@ -181,7 +181,6 @@ class MultiModalRegistry:
def init_mm_limits_per_prompt( def init_mm_limits_per_prompt(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
multimodal_config: Optional[MultiModalConfig],
) -> None: ) -> None:
""" """
Initialize the maximum number of multi-modal input instances for each Initialize the maximum number of multi-modal input instances for each
...@@ -192,6 +191,7 @@ class MultiModalRegistry: ...@@ -192,6 +191,7 @@ class MultiModalRegistry:
"`mm_limits` has already been set for model=%s, and will " "`mm_limits` has already been set for model=%s, and will "
"be overwritten by the new values.", model_config.model) "be overwritten by the new values.", model_config.model)
multimodal_config = model_config.multimodal_config
if multimodal_config is None: if multimodal_config is None:
limits_per_plugin = self._disabled_limits_per_plugin limits_per_plugin = self._disabled_limits_per_plugin
else: else:
......
...@@ -23,8 +23,8 @@ except ImportError: ...@@ -23,8 +23,8 @@ except ImportError:
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs from vllm.multimodal import MultiModalInputs
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
...@@ -66,7 +66,6 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -66,7 +66,6 @@ class TP1DraftModelRunner(ModelRunner):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
multimodal_config: Optional[MultiModalConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None, observability_config: Optional[ObservabilityConfig] = None,
...@@ -86,7 +85,6 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -86,7 +85,6 @@ class TP1DraftModelRunner(ModelRunner):
lora_config=lora_config, lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config, prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states, return_hidden_states=return_hidden_states,
observability_config=observability_config, observability_config=observability_config,
......
from typing import List, Optional from typing import List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner) ModelRunner)
...@@ -31,7 +31,6 @@ class TargetModelRunner(ModelRunner): ...@@ -31,7 +31,6 @@ class TargetModelRunner(ModelRunner):
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None): observability_config: Optional[ObservabilityConfig] = None):
# An internal boolean member variable to indicate if token log # An internal boolean member variable to indicate if token log
...@@ -47,7 +46,6 @@ class TargetModelRunner(ModelRunner): ...@@ -47,7 +46,6 @@ class TargetModelRunner(ModelRunner):
lora_config=lora_config, lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config, prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states, return_hidden_states=return_hidden_states,
observability_config=observability_config, observability_config=observability_config,
......
...@@ -6,8 +6,8 @@ from torch import nn ...@@ -6,8 +6,8 @@ from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, ParallelConfig, PromptAdapterConfig,
PromptAdapterConfig, SchedulerConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -79,7 +79,6 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -79,7 +79,6 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
...@@ -94,7 +93,6 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -94,7 +93,6 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
...@@ -125,7 +123,6 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -125,7 +123,6 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.model = get_model(model_config=self.model_config, self.model = get_model(model_config=self.model_config,
load_config=self.load_config, load_config=self.load_config,
device_config=self.device_config, device_config=self.device_config,
multimodal_config=self.multimodal_config,
lora_config=self.lora_config, lora_config=self.lora_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment