Unverified Commit bbf55c48 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[VLM] Refactor `MultiModalConfig` initialization and profiling (#7530)

parent 1ef13cf9
......@@ -7,8 +7,8 @@ import torch.distributed
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
ModelConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
......@@ -132,7 +132,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
......@@ -148,7 +147,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
......@@ -173,7 +171,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=is_driver_worker)
......
......@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalInputs
......@@ -44,7 +44,6 @@ class EmbeddingModelRunner(
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
super().__init__(model_config,
......@@ -57,7 +56,6 @@ class EmbeddingModelRunner(
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
observability_config=observability_config)
@torch.inference_mode()
......
......@@ -10,8 +10,8 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend,
global_force_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
......@@ -82,7 +82,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
......@@ -90,7 +89,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
'''
EncoderDecoderModelRunner constructor.
`lora_config`, `multimodal_config`, and prompt_adapter_config are
`lora_config` and `prompt_adapter_config` are
unused (since these features are not yet supported for encoder/decoder
models) but these arguments are present here for compatibility with
the base-class constructor.
......@@ -273,14 +272,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
model_config = self.model_config
mm_config = self.multimodal_config
input_registry = self.input_registry
mm_registry = self.mm_registry
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
raise NotImplementedError(
"Multi-modal encoder-decoder models are not supported yet")
......@@ -291,8 +284,10 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, _ = input_registry \
.dummy_data_for_profiling(model_config, seq_len, mm_registry)
seq_data, _ = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
......
......@@ -27,8 +27,8 @@ except ImportError:
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY, InputRegistry
......@@ -804,7 +804,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
......@@ -819,7 +818,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.load_config = load_config
self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.return_hidden_states = return_hidden_states
self.observability_config = observability_config
......@@ -866,6 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)
# Lazy initialization
self.model: nn.Module # Set after load_model
......@@ -893,7 +892,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
......@@ -1056,14 +1054,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
model_config = self.model_config
mm_config = self.multimodal_config
input_registry = self.input_registry
mm_registry = self.mm_registry
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs,
......@@ -1082,8 +1075,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, dummy_multi_modal_data = input_registry \
.dummy_data_for_profiling(model_config, seq_len, mm_registry)
seq_data, dummy_multi_modal_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq = SequenceGroupMetadata(
request_id=str(group_id),
......
......@@ -11,7 +11,7 @@ import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
MultiModalConfig, ParallelConfig, SchedulerConfig)
ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -89,7 +89,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
multimodal_config: Optional[MultiModalConfig] = None,
is_driver_worker: bool = False,
):
self.model_config = model_config
......@@ -98,7 +97,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size
......@@ -142,7 +140,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
parallel_config=self.parallel_config,
cache_config=self.cache_config,
scheduler_config=self.scheduler_config,
multimodal_config=self.multimodal_config,
lora_config=None,
)
model = model.eval()
......
......@@ -7,7 +7,7 @@ import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
MultiModalConfig, ParallelConfig, SchedulerConfig)
ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
......@@ -31,7 +31,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
multimodal_config: Optional[MultiModalConfig],
local_rank: int,
rank: int,
distributed_init_method: str,
......@@ -44,7 +43,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.multimodal_config = multimodal_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
......@@ -64,7 +62,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
device_config,
cache_config,
load_config,
multimodal_config,
is_driver_worker=is_driver_worker)
def init_device(self) -> None:
......
......@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
if enc_dec_mr.multimodal_config is not None:
if enc_dec_mr.model_config.multimodal_config is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
......
......@@ -7,8 +7,8 @@ import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
......@@ -46,7 +46,6 @@ class Worker(LocalOrDistributedWorkerBase):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
......@@ -73,7 +72,6 @@ class Worker(LocalOrDistributedWorkerBase):
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.multimodal_config = multimodal_config
self.observability_config = observability_config
# Return hidden states from target model if the draft model is an
......@@ -103,7 +101,6 @@ class Worker(LocalOrDistributedWorkerBase):
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
observability_config=observability_config,
**speculative_args,
)
......
......@@ -125,6 +125,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)
# Lazy initialization.
self.model: nn.Module # Set after init_Model
......@@ -166,14 +167,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
model_config = self.model_config
mm_config = self.multimodal_config
input_registry = self.input_registry
mm_registry = self.mm_registry
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs,
......@@ -190,8 +185,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
seq_data, dummy_multi_modal_data = input_registry \
.dummy_data_for_profiling(model_config, seq_len, mm_registry)
seq_data, dummy_multi_modal_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq = SequenceGroupMetadata(
request_id=str(group_id),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment