Commit d9e98f42 (unverified), authored by xwjiang2010, committed by GitHub
Browse files

[vlm] Remove vision language config. (#6089)


Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
parent 3c6325f0
...@@ -7,8 +7,8 @@ import torch ...@@ -7,8 +7,8 @@ import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SpeculativeConfig, VisionLanguageConfig) SchedulerConfig, SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
...@@ -43,7 +43,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -43,7 +43,7 @@ class Worker(LocalOrDistributedWorkerBase):
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None, speculative_config: Optional[SpeculativeConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
...@@ -66,10 +66,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -66,10 +66,7 @@ class Worker(LocalOrDistributedWorkerBase):
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
init_cached_hf_modules() init_cached_hf_modules()
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
if self.vision_language_config:
assert not self.lora_config, (
"To be tested: vision language model with LoRA settings.")
# Return hidden states from target model if the draft model is an # Return hidden states from target model if the draft model is an
# mlp_speculator # mlp_speculator
...@@ -94,7 +91,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -94,7 +91,7 @@ class Worker(LocalOrDistributedWorkerBase):
lora_config=self.lora_config, lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config, multimodal_config=multimodal_config,
**speculative_args, **speculative_args,
) )
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
......
...@@ -7,12 +7,13 @@ import torch.nn as nn ...@@ -7,12 +7,13 @@ import torch.nn as nn
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
VisionLanguageConfig) SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_vision
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs) MultiModalInputs)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -85,7 +86,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -85,7 +86,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
*args, *args,
...@@ -97,7 +98,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -97,7 +98,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config self.load_config = load_config
self.cache_config = cache_config self.cache_config = cache_config
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.sliding_window = model_config.get_sliding_window() self.sliding_window = model_config.get_sliding_window()
...@@ -134,7 +135,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -134,7 +135,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
device_config=self.device_config, device_config=self.device_config,
load_config=self.load_config, load_config=self.load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
cache_config=self.cache_config, cache_config=self.cache_config,
...@@ -165,12 +166,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -165,12 +166,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# the number of seqs (batch_size) is chosen to maximize the number # the number of seqs (batch_size) is chosen to maximize the number
# of images processed. # of images processed.
model_config = self.model_config model_config = self.model_config
vlm_config = self.vision_language_config
if vlm_config: if supports_vision(self.model):
max_num_seqs = min( # TODO: properly inject these numbers from MultiModalRegistry.
max_num_seqs, # Right now, just use an overly conservative number.
int(max_num_batched_tokens / vlm_config.image_feature_size)) max_num_seqs = max(
1,
min(
max_num_seqs,
int(max_num_batched_tokens /
MULTIMODAL_REGISTRY.get_num_input_tokens())))
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
......
...@@ -9,8 +9,8 @@ import torch ...@@ -9,8 +9,8 @@ import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, MultiModalConfig, ParallelConfig,
SpeculativeConfig, VisionLanguageConfig) SchedulerConfig, SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -45,7 +45,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): ...@@ -45,7 +45,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None, speculative_config: Optional[SpeculativeConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
...@@ -66,10 +66,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): ...@@ -66,10 +66,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
if self.vision_language_config:
assert not self.lora_config, (
"To be tested: vision language model with LoRA settings.")
self.model_runner = XPUModelRunner( # type: ignore self.model_runner = XPUModelRunner( # type: ignore
model_config, model_config,
...@@ -81,7 +78,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): ...@@ -81,7 +78,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
lora_config=self.lora_config, lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config, multimodal_config=multimodal_config,
) )
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment