Unverified Commit d9e98f42 authored by xwjiang2010's avatar xwjiang2010 Committed by GitHub
Browse files

[vlm] Remove vision language config. (#6089)


Signed-off-by: default avatarXiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent 3c6325f0
...@@ -7,8 +7,8 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, ...@@ -7,8 +7,8 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
Tuple, Union) Tuple, Union)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SpeculativeConfig, VisionLanguageConfig) SchedulerConfig, SpeculativeConfig)
from vllm.executor.distributed_gpu_executor import ( # yapf: disable from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync) DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.ray_utils import RayWorkerWrapper, ray
...@@ -43,7 +43,7 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -43,7 +43,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
) -> None: ) -> None:
assert device_config.device_type == "xpu" assert device_config.device_type == "xpu"
...@@ -57,7 +57,7 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -57,7 +57,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
placement_group = self.parallel_config.placement_group placement_group = self.parallel_config.placement_group
...@@ -199,7 +199,7 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -199,7 +199,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, multimodal_config=self.multimodal_config,
is_driver_worker=rank == 0, is_driver_worker=rank == 0,
)) ))
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
......
...@@ -50,7 +50,7 @@ class TPUExecutor(ExecutorBase): ...@@ -50,7 +50,7 @@ class TPUExecutor(ExecutorBase):
local_rank=local_rank, local_rank=local_rank,
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
vision_language_config=self.vision_language_config, multimodal_config=self.multimodal_config,
is_driver_worker=rank == 0, is_driver_worker=rank == 0,
) )
......
...@@ -3,8 +3,8 @@ from typing import List, Optional ...@@ -3,8 +3,8 @@ from typing import List, Optional
import torch import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SpeculativeConfig, VisionLanguageConfig) SchedulerConfig, SpeculativeConfig)
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -26,7 +26,7 @@ class XPUExecutor(GPUExecutor): ...@@ -26,7 +26,7 @@ class XPUExecutor(GPUExecutor):
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
) -> None: ) -> None:
assert device_config.device_type == "xpu" assert device_config.device_type == "xpu"
...@@ -42,7 +42,7 @@ class XPUExecutor(GPUExecutor): ...@@ -42,7 +42,7 @@ class XPUExecutor(GPUExecutor):
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
self.speculative_config = None self.speculative_config = None
# Instantiate the worker and load the model to GPU. # Instantiate the worker and load the model to GPU.
......
...@@ -11,7 +11,7 @@ from vllm.logger import init_logger ...@@ -11,7 +11,7 @@ from vllm.logger import init_logger
from .data import LLMInputs from .data import LLMInputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VisionLanguageConfig from vllm.config import ModelConfig, MultiModalConfig
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
...@@ -30,7 +30,7 @@ class InputContext: ...@@ -30,7 +30,7 @@ class InputContext:
model_config: "ModelConfig" model_config: "ModelConfig"
"""The configuration of the model.""" """The configuration of the model."""
def get_multimodal_config(self) -> "VisionLanguageConfig": def get_multimodal_config(self) -> "MultiModalConfig":
""" """
Get the multimodal configuration of the model. Get the multimodal configuration of the model.
......
...@@ -3,8 +3,8 @@ from typing import Optional ...@@ -3,8 +3,8 @@ from typing import Optional
from torch import nn from torch import nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
VisionLanguageConfig) SchedulerConfig)
from vllm.model_executor.model_loader.loader import (BaseModelLoader, from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader) get_model_loader)
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
...@@ -15,13 +15,13 @@ def get_model(*, model_config: ModelConfig, load_config: LoadConfig, ...@@ -15,13 +15,13 @@ def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig, device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config) loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config, return loader.load_model(model_config=model_config,
device_config=device_config, device_config=device_config,
lora_config=lora_config, lora_config=lora_config,
vision_language_config=vision_language_config, multimodal_config=multimodal_config,
parallel_config=parallel_config, parallel_config=parallel_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
cache_config=cache_config) cache_config=cache_config)
......
...@@ -16,8 +16,8 @@ from huggingface_hub import HfApi, hf_hub_download ...@@ -16,8 +16,8 @@ from huggingface_hub import HfApi, hf_hub_download
from torch import nn from torch import nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ParallelConfig, LoRAConfig, ModelConfig, MultiModalConfig,
SchedulerConfig, VisionLanguageConfig) ParallelConfig, SchedulerConfig)
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -68,7 +68,7 @@ def _get_quantization_config( ...@@ -68,7 +68,7 @@ def _get_quantization_config(
def _get_model_initialization_kwargs( def _get_model_initialization_kwargs(
model_class: Type[nn.Module], model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vlm_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Get extra kwargs for model initialization.""" """Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {} extra_kwargs: Dict[str, Any] = {}
...@@ -84,18 +84,18 @@ def _get_model_initialization_kwargs( ...@@ -84,18 +84,18 @@ def _get_model_initialization_kwargs(
"please open an issue on github.") "please open an issue on github.")
if supports_vision(model_class): if supports_vision(model_class):
if vlm_config is None: if multimodal_config is None:
raise ValueError("Provide vision related configurations " raise ValueError("Provide vision related configurations "
"through LLM entrypoint or engine arguments.") "through LLM entrypoint or engine arguments.")
extra_kwargs["vlm_config"] = vlm_config extra_kwargs["multimodal_config"] = multimodal_config
return extra_kwargs return extra_kwargs
def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0] model_class = get_model_architecture(model_config)[0]
...@@ -105,7 +105,7 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, ...@@ -105,7 +105,7 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
**_get_model_initialization_kwargs( **_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)) model_class, lora_config, multimodal_config))
class BaseModelLoader(ABC): class BaseModelLoader(ABC):
...@@ -118,7 +118,7 @@ class BaseModelLoader(ABC): ...@@ -118,7 +118,7 @@ class BaseModelLoader(ABC):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
...@@ -258,14 +258,14 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -258,14 +258,14 @@ class DefaultModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config, lora_config, multimodal_config,
cache_config) cache_config)
model.load_weights( model.load_weights(
self._get_weights_iterator(model_config.model, self._get_weights_iterator(model_config.model,
...@@ -298,14 +298,14 @@ class DummyModelLoader(BaseModelLoader): ...@@ -298,14 +298,14 @@ class DummyModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config, lora_config, multimodal_config,
cache_config) cache_config)
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
...@@ -339,7 +339,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -339,7 +339,7 @@ class TensorizerLoader(BaseModelLoader):
model_config: ModelConfig, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig, cache_config: CacheConfig,
) -> nn.Module: ) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU. """Load a serialized model with tensorizer to the CPU.
...@@ -352,7 +352,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -352,7 +352,7 @@ class TensorizerLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config, lora_config, multimodal_config,
cache_config) cache_config)
model.load_weights(self._get_weights_iterator()) model.load_weights(self._get_weights_iterator())
...@@ -363,7 +363,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -363,7 +363,7 @@ class TensorizerLoader(BaseModelLoader):
model_config: ModelConfig, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig, cache_config: CacheConfig,
) -> nn.Module: ) -> nn.Module:
"""Load a serialized model with tensorizer. """Load a serialized model with tensorizer.
...@@ -377,7 +377,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -377,7 +377,7 @@ class TensorizerLoader(BaseModelLoader):
quant_config = _get_quantization_config( quant_config = _get_quantization_config(
model_config, self.load_config) model_config, self.load_config)
extra_kwargs = _get_model_initialization_kwargs( extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config) model_class, lora_config, multimodal_config)
extra_kwargs["quant_config"] = quant_config extra_kwargs["quant_config"] = quant_config
extra_kwargs["cache_config"] = cache_config extra_kwargs["cache_config"] = cache_config
...@@ -392,7 +392,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -392,7 +392,7 @@ class TensorizerLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
...@@ -406,12 +406,10 @@ class TensorizerLoader(BaseModelLoader): ...@@ -406,12 +406,10 @@ class TensorizerLoader(BaseModelLoader):
if is_vllm_tensorized(self.tensorizer_config): if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config, return self._load_model_serialized(model_config, device_config,
lora_config, lora_config, multimodal_config,
vision_language_config,
cache_config) cache_config)
return self._load_model_serialized_cpu(model_config, device_config, return self._load_model_serialized_cpu(model_config, device_config,
lora_config, lora_config, multimodal_config,
vision_language_config,
cache_config) cache_config)
@staticmethod @staticmethod
...@@ -494,7 +492,7 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -494,7 +492,7 @@ class ShardedStateLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
...@@ -508,7 +506,7 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -508,7 +506,7 @@ class ShardedStateLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config, lora_config, multimodal_config,
cache_config) cache_config)
rank = get_tensor_model_parallel_rank() rank = get_tensor_model_parallel_rank()
pattern = os.path.join( pattern = os.path.join(
...@@ -804,14 +802,14 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -804,14 +802,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config, lora_config, multimodal_config,
cache_config) cache_config)
self._load_weights(model_config, model) self._load_weights(model_config, model)
......
...@@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, ...@@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
from typing_extensions import TypeGuard from typing_extensions import TypeGuard
from vllm.config import LoRAConfig, VisionLanguageConfig from vllm.config import LoRAConfig, MultiModalConfig
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -22,7 +22,7 @@ class SupportsVision(Protocol): ...@@ -22,7 +22,7 @@ class SupportsVision(Protocol):
MRO of your model class. MRO of your model class.
""" """
def __init__(self, *, vlm_config: VisionLanguageConfig) -> None: def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
... ...
...@@ -32,7 +32,7 @@ class SupportsVision(Protocol): ...@@ -32,7 +32,7 @@ class SupportsVision(Protocol):
class _SupportsVisionType(Protocol): class _SupportsVisionType(Protocol):
supports_vision: Literal[True] supports_vision: Literal[True]
def __call__(self, *, vlm_config: VisionLanguageConfig) -> None: def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
... ...
......
...@@ -5,7 +5,7 @@ import torch.nn as nn ...@@ -5,7 +5,7 @@ import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig from transformers import CLIPVisionConfig, LlavaConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -108,13 +108,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -108,13 +108,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self, def __init__(self,
config: LlavaConfig, config: LlavaConfig,
vlm_config: VisionLanguageConfig, multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.vlm_config = vlm_config self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config.vision_config) self.vision_tower = CLIPVisionModel(config.vision_config)
...@@ -138,14 +138,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -138,14 +138,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.sampler = Sampler() self.sampler = Sampler()
def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]): if list(data.shape)[1:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError( raise ValueError(
f"The expected image tensor shape is batch dimension plus " "The expected image tensor shape is batch dimension plus "
f"{self.vlm_config.image_input_shape[1:]}. " "channel, height and width.")
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
return data return data
...@@ -244,7 +243,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -244,7 +243,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = merge_vision_embeddings( inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.vlm_config.image_token_id) self.config.image_token_index)
input_ids = None input_ids = None
else: else:
......
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -9,7 +9,7 @@ from transformers.models.llava_next.modeling_llava_next import ( ...@@ -9,7 +9,7 @@ from transformers.models.llava_next.modeling_llava_next import (
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -204,13 +204,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -204,13 +204,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self, def __init__(self,
config: LlavaNextConfig, config: LlavaNextConfig,
vlm_config: VisionLanguageConfig, multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.vlm_config = vlm_config self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config=config.vision_config) self.vision_tower = CLIPVisionModel(config=config.vision_config)
...@@ -244,6 +244,47 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -244,6 +244,47 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return data return data
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
def _validate_shape(data: torch.Tensor):
dim = data.dim()
height = width = self.config.vision_config.image_size
# All 4d image tensors have the same number of patches,
# so data is a 5d batch of these tensors
if dim == 5:
if list(data.shape)[2:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError(
"Expected pixel value tensor in shape of: (batch size, "
f"patch number, 3, {height}, {width}), got {data.shape}"
)
# 4d image tensors have different number of patches,
# so data is each individual tensor.
elif dim == 4:
if list(data.shape)[1:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError(
"Expected pixel value tensor in shape of: (patch "
f"number, 3, {height}, {width}), got {data.shape}")
else:
raise ValueError(
f"Invalid pixel value tensor of shape {data.shape}")
if isinstance(data, torch.Tensor):
_validate_shape(data)
else:
[_validate_shape(d) for d in data]
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]: self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -262,7 +303,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -262,7 +303,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return LlavaNextImagePixelInputs( return LlavaNextImagePixelInputs(
type="pixel_values", type="pixel_values",
data=pixel_values, data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes), image_sizes=self._validate_image_sizes(image_sizes),
) )
...@@ -454,7 +495,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -454,7 +495,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = merge_vision_embeddings( inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.vlm_config.image_token_id) self.config.image_token_index)
input_ids = None input_ids = None
else: else:
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import re import re
from functools import lru_cache from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import numpy as np import numpy as np
import torch import torch
...@@ -24,7 +24,7 @@ from PIL import Image ...@@ -24,7 +24,7 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -50,6 +50,9 @@ _KEYS_TO_MODIFY_MAPPING = { ...@@ -50,6 +50,9 @@ _KEYS_TO_MODIFY_MAPPING = {
"model.vision_embed_tokens": "vision_embed_tokens", "model.vision_embed_tokens": "vision_embed_tokens",
} }
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 32044
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
hidden_act="quick_gelu", hidden_act="quick_gelu",
hidden_size=1024, hidden_size=1024,
...@@ -95,13 +98,10 @@ class Phi3ImageEmbeddingBase(nn.Module): ...@@ -95,13 +98,10 @@ class Phi3ImageEmbeddingBase(nn.Module):
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform.""" """Phi3 Image embedding with HD transform."""
def __init__(self, def __init__(self, config: PretrainedConfig, wte=None) -> None:
vision_language_config: VisionLanguageConfig,
config: PretrainedConfig,
wte=None) -> None:
super().__init__(wte) super().__init__(wte)
self.image_token_id = vision_language_config.image_token_id self.image_token_id = _IMAGE_TOKEN_ID
# n_embed or hidden_size # n_embed or hidden_size
hidden_size = config.n_embd if hasattr( hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size config, 'n_embd') else config.hidden_size
...@@ -333,7 +333,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int): ...@@ -333,7 +333,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG, CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len, seq_len,
image_token_id=32044, image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
mm_data = dummy_image_for_clip( mm_data = dummy_image_for_clip(
...@@ -370,7 +370,6 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -370,7 +370,6 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs return llm_inputs
model_config = ctx.model_config model_config = ctx.model_config
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(PretrainedConfig) hf_config = ctx.get_hf_config(PretrainedConfig)
image_data = multi_modal_data["image"] image_data = multi_modal_data["image"]
...@@ -407,7 +406,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -407,7 +406,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
new_token_ids: List[int] = [] new_token_ids: List[int] = []
for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1): for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids: if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
new_token_ids.append(multimodal_config.image_token_id) new_token_ids.append(_IMAGE_TOKEN_ID)
# No need to further scan the list since we only replace once # No need to further scan the list since we only replace once
new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):]) new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
...@@ -424,7 +423,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -424,7 +423,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
model_config, model_config,
CLIP_VIT_LARGE_PATCH14_336_CONFIG, CLIP_VIT_LARGE_PATCH14_336_CONFIG,
llm_inputs, llm_inputs,
image_token_id=multimodal_config.image_token_id, image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -436,25 +435,53 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -436,25 +435,53 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
vlm_config: VisionLanguageConfig, multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.vlm_config = vlm_config self.multimodal_config = multimodal_config
self.model = LlamaModel(config, cache_config, quant_config) self.model = LlamaModel(config, cache_config, quant_config)
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding( self.vision_embed_tokens = Phi3HDImageEmbedding(
vlm_config, config, self.model.embed_tokens) config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != [2]:
raise ValueError(
f"The expected image sizes shape is batch dimension plus "
f"{[2]}. You supplied {data.shape}.")
return data
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
def _validate_shape(data: torch.Tensor):
if list(data.shape)[2:] != [
3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
]:
raise ValueError(
"The expected pixel value tensor shape is batch dimension "
"plus patch number, channel, height and width.")
if isinstance(data, torch.Tensor):
_validate_shape(data)
else:
[_validate_shape(d) for d in data]
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]: self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -471,9 +498,10 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -471,9 +498,10 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
raise ValueError("Incorrect type of image sizes. " raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}") f"Got type: {type(image_sizes)}")
return Phi3VImagePixelInputs(type="pixel_values", return Phi3VImagePixelInputs(
data=pixel_values, type="pixel_values",
image_sizes=image_sizes) data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes))
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -120,3 +120,10 @@ class MultiModalRegistry: ...@@ -120,3 +120,10 @@ class MultiModalRegistry:
Create an input mapper (see :meth:`map_input`) for a specific model. Create an input mapper (see :meth:`map_input`) for a specific model.
""" """
return functools.partial(self.map_input, model_config) return functools.partial(self.map_input, model_config)
def get_num_input_tokens(self):
"""
Get the number of input tokens for profiling purposes.
"""
# TODO: Provide this number on a per model basis.
return 3000
...@@ -3,8 +3,8 @@ from typing import List, Optional ...@@ -3,8 +3,8 @@ from typing import List, Optional
import torch import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
VisionLanguageConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
...@@ -47,7 +47,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -47,7 +47,7 @@ class TP1DraftModelRunner(ModelRunner):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
): ):
if return_hidden_states: if return_hidden_states:
...@@ -65,7 +65,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -65,7 +65,7 @@ class TP1DraftModelRunner(ModelRunner):
lora_config=lora_config, lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config, multimodal_config=multimodal_config,
return_hidden_states=return_hidden_states, return_hidden_states=return_hidden_states,
) )
......
...@@ -7,8 +7,8 @@ from torch import nn ...@@ -7,8 +7,8 @@ from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
VisionLanguageConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -79,7 +79,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -79,7 +79,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
*args, *args,
...@@ -93,7 +93,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -93,7 +93,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
self.load_config = load_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
...@@ -120,11 +120,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -120,11 +120,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model( self.model = get_model(model_config=self.model_config,
model_config=self.model_config,
load_config=self.load_config, load_config=self.load_config,
device_config=self.device_config, device_config=self.device_config,
vision_language_config=self.vision_language_config, multimodal_config=self.multimodal_config,
lora_config=self.lora_config, lora_config=self.lora_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
......
...@@ -6,8 +6,8 @@ import torch.distributed ...@@ -6,8 +6,8 @@ import torch.distributed
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
VisionLanguageConfig) SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -131,7 +131,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -131,7 +131,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
...@@ -145,7 +145,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -145,7 +145,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config self.lora_config = lora_config
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
...@@ -162,7 +162,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -162,7 +162,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
cache_config, cache_config,
load_config=self.load_config, load_config=self.load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
......
...@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type ...@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
VisionLanguageConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -40,7 +40,7 @@ class EmbeddingModelRunner( ...@@ -40,7 +40,7 @@ class EmbeddingModelRunner(
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
): ):
super().__init__(model_config, super().__init__(model_config,
parallel_config, parallel_config,
...@@ -51,7 +51,7 @@ class EmbeddingModelRunner( ...@@ -51,7 +51,7 @@ class EmbeddingModelRunner(
lora_config=lora_config, lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config) multimodal_config=multimodal_config)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
......
...@@ -24,8 +24,8 @@ except ImportError: ...@@ -24,8 +24,8 @@ except ImportError:
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
VisionLanguageConfig) SchedulerConfig)
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
...@@ -36,7 +36,8 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager ...@@ -36,7 +36,8 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs) MultiModalInputs)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -171,7 +172,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -171,7 +172,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
): ):
self.model_config = model_config self.model_config = model_config
...@@ -182,7 +183,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -182,7 +183,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
self.return_hidden_states = return_hidden_states self.return_hidden_states = return_hidden_states
self.device = self.device_config.device self.device = self.device_config.device
...@@ -244,7 +245,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -244,7 +245,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
device_config=self.device_config, device_config=self.device_config,
load_config=self.load_config, load_config=self.load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
cache_config=self.cache_config, cache_config=self.cache_config,
...@@ -256,6 +257,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -256,6 +257,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if self.lora_config: if self.lora_config:
assert supports_lora(self.model), "Model does not support LoRA" assert supports_lora(self.model), "Model does not support LoRA"
assert not supports_vision(
self.model
), "To be tested: vision language model with LoRA settings."
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
...@@ -804,12 +808,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -804,12 +808,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# the number of seqs (batch_size) is chosen to maximize the number # the number of seqs (batch_size) is chosen to maximize the number
# of images processed. # of images processed.
model_config = self.model_config model_config = self.model_config
vlm_config = self.vision_language_config
if vlm_config: if supports_vision(self.model):
max_num_seqs = min( max_num_seqs = max(
1,
min(
max_num_seqs, max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size)) int(max_num_batched_tokens /
MULTIMODAL_REGISTRY.get_num_input_tokens())))
batch_size = 0 batch_size = 0
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
......
...@@ -7,8 +7,8 @@ from torch import nn ...@@ -7,8 +7,8 @@ from torch import nn
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
VisionLanguageConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.openvino import get_model from vllm.model_executor.model_loader.openvino import get_model
...@@ -48,7 +48,7 @@ class OpenVINOModelRunner: ...@@ -48,7 +48,7 @@ class OpenVINOModelRunner:
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
*args, *args,
...@@ -60,7 +60,7 @@ class OpenVINOModelRunner: ...@@ -60,7 +60,7 @@ class OpenVINOModelRunner:
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
self.load_config = load_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
......
...@@ -7,8 +7,8 @@ import torch.distributed ...@@ -7,8 +7,8 @@ import torch.distributed
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
VisionLanguageConfig) SchedulerConfig)
from vllm.distributed import (broadcast_tensor_dict, from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized, ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
...@@ -148,7 +148,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): ...@@ -148,7 +148,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined, kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
...@@ -162,7 +162,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): ...@@ -162,7 +162,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config self.lora_config = lora_config
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
...@@ -180,7 +180,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): ...@@ -180,7 +180,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
cache_config, cache_config,
load_config=self.load_config, load_config=self.load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
) )
......
...@@ -8,7 +8,7 @@ import torch_xla.core.xla_model as xm ...@@ -8,7 +8,7 @@ import torch_xla.core.xla_model as xm
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig) MultiModalConfig, ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -39,7 +39,7 @@ class TPUModelRunner: ...@@ -39,7 +39,7 @@ class TPUModelRunner:
device_config: DeviceConfig, device_config: DeviceConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig, load_config: LoadConfig,
vision_language_config: Optional[VisionLanguageConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
): ):
self.model_config = model_config self.model_config = model_config
...@@ -48,7 +48,7 @@ class TPUModelRunner: ...@@ -48,7 +48,7 @@ class TPUModelRunner:
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config self.cache_config = cache_config
self.load_config = load_config self.load_config = load_config
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size self.block_size = self.cache_config.block_size
...@@ -82,7 +82,7 @@ class TPUModelRunner: ...@@ -82,7 +82,7 @@ class TPUModelRunner:
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
cache_config=self.cache_config, cache_config=self.cache_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
vision_language_config=self.vision_language_config, multimodal_config=self.multimodal_config,
lora_config=None, lora_config=None,
) )
xm.wait_device_ops() xm.wait_device_ops()
......
...@@ -8,7 +8,7 @@ import torch_xla.runtime as xr ...@@ -8,7 +8,7 @@ import torch_xla.runtime as xr
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig) MultiModalConfig, ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -31,7 +31,7 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -31,7 +31,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
device_config: DeviceConfig, device_config: DeviceConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig, load_config: LoadConfig,
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
...@@ -43,7 +43,7 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -43,7 +43,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config self.cache_config = cache_config
self.load_config = load_config self.load_config = load_config
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
...@@ -62,7 +62,7 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -62,7 +62,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
device_config, device_config,
cache_config, cache_config,
load_config, load_config,
vision_language_config, multimodal_config,
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
def init_device(self) -> None: def init_device(self) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment