Commit bbf55c48 (unverified), authored by Roger Wang, committed by GitHub
Browse files

[VLM] Refactor `MultiModalConfig` initialization and profiling (#7530)

parent 1ef13cf9
...@@ -7,8 +7,8 @@ import torch.distributed ...@@ -7,8 +7,8 @@ import torch.distributed
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, ParallelConfig, PromptAdapterConfig,
PromptAdapterConfig, SchedulerConfig) SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -132,7 +132,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -132,7 +132,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
...@@ -148,7 +147,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -148,7 +147,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
...@@ -173,7 +171,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -173,7 +171,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
cache_config, cache_config,
load_config=self.load_config, load_config=self.load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
prompt_adapter_config=self.prompt_adapter_config, prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
......
...@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type ...@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalInputs from vllm.multimodal import MultiModalInputs
...@@ -44,7 +44,6 @@ class EmbeddingModelRunner( ...@@ -44,7 +44,6 @@ class EmbeddingModelRunner(
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
observability_config: Optional[ObservabilityConfig] = None, observability_config: Optional[ObservabilityConfig] = None,
): ):
super().__init__(model_config, super().__init__(model_config,
...@@ -57,7 +56,6 @@ class EmbeddingModelRunner( ...@@ -57,7 +56,6 @@ class EmbeddingModelRunner(
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config, prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
observability_config=observability_config) observability_config=observability_config)
@torch.inference_mode() @torch.inference_mode()
......
...@@ -10,8 +10,8 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, ...@@ -10,8 +10,8 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend, get_global_forced_attn_backend,
global_force_attn_backend) global_force_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
...@@ -82,7 +82,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -82,7 +82,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
observability_config: Optional[ObservabilityConfig] = None, observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
...@@ -90,7 +89,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -90,7 +89,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
''' '''
EncoderDecoderModelRunner constructor. EncoderDecoderModelRunner constructor.
`lora_config`, `multimodal_config`, and prompt_adapter_config are `lora_config` and `prompt_adapter_config` are
unused (since these features are not yet supported for encoder/decoder unused (since these features are not yet supported for encoder/decoder
models) but these arguments are present here for compatibility with models) but these arguments are present here for compatibility with
the base-class constructor. the base-class constructor.
...@@ -273,14 +272,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -273,14 +272,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# number of tokens equal to max_num_batched_tokens. # number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = [] seqs: List[SequenceGroupMetadata] = []
model_config = self.model_config max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
mm_config = self.multimodal_config self.model_config)
input_registry = self.input_registry
mm_registry = self.mm_registry
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
if max_mm_tokens > 0: if max_mm_tokens > 0:
raise NotImplementedError( raise NotImplementedError(
"Multi-modal encoder-decoder models are not supported yet") "Multi-modal encoder-decoder models are not supported yet")
...@@ -291,8 +284,10 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -291,8 +284,10 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len batch_size += seq_len
seq_data, _ = input_registry \ seq_data, _ = self.input_registry \
.dummy_data_for_profiling(model_config, seq_len, mm_registry) .dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, ( assert len(seq_data.prompt_token_ids) >= seq_len, (
......
...@@ -27,8 +27,8 @@ except ImportError: ...@@ -27,8 +27,8 @@ except ImportError:
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
...@@ -804,7 +804,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -804,7 +804,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None, observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
...@@ -819,7 +818,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -819,7 +818,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.load_config = load_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.return_hidden_states = return_hidden_states self.return_hidden_states = return_hidden_states
self.observability_config = observability_config self.observability_config = observability_config
...@@ -866,6 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -866,6 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry \ self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config) .create_input_mapper(model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)
# Lazy initialization # Lazy initialization
self.model: nn.Module # Set after load_model self.model: nn.Module # Set after load_model
...@@ -893,7 +892,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -893,7 +892,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
device_config=self.device_config, device_config=self.device_config,
load_config=self.load_config, load_config=self.load_config,
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
cache_config=self.cache_config) cache_config=self.cache_config)
...@@ -1056,14 +1054,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1056,14 +1054,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# To exercise the worst scenario for GPU memory consumption, # To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number # the number of seqs (batch_size) is chosen to maximize the number
# of images processed. # of images processed.
model_config = self.model_config
mm_config = self.multimodal_config
input_registry = self.input_registry max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
mm_registry = self.mm_registry self.model_config)
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
if max_mm_tokens > 0: if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs, max_num_seqs = min(max_num_seqs,
...@@ -1082,8 +1075,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1082,8 +1075,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len batch_size += seq_len
seq_data, dummy_multi_modal_data = input_registry \ seq_data, dummy_multi_modal_data = self.input_registry \
.dummy_data_for_profiling(model_config, seq_len, mm_registry) .dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
......
...@@ -11,7 +11,7 @@ import torch_xla.runtime as xr ...@@ -11,7 +11,7 @@ import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
MultiModalConfig, ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -89,7 +89,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -89,7 +89,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
device_config: DeviceConfig, device_config: DeviceConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig, load_config: LoadConfig,
multimodal_config: Optional[MultiModalConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
): ):
self.model_config = model_config self.model_config = model_config
...@@ -98,7 +97,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -98,7 +97,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config self.cache_config = cache_config
self.load_config = load_config self.load_config = load_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size self.block_size = self.cache_config.block_size
...@@ -142,7 +140,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -142,7 +140,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
cache_config=self.cache_config, cache_config=self.cache_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
multimodal_config=self.multimodal_config,
lora_config=None, lora_config=None,
) )
model = model.eval() model = model.eval()
......
...@@ -7,7 +7,7 @@ import torch_xla.runtime as xr ...@@ -7,7 +7,7 @@ import torch_xla.runtime as xr
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
MultiModalConfig, ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -31,7 +31,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -31,7 +31,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
device_config: DeviceConfig, device_config: DeviceConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig, load_config: LoadConfig,
multimodal_config: Optional[MultiModalConfig],
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
...@@ -44,7 +43,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -44,7 +43,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config self.cache_config = cache_config
self.load_config = load_config self.load_config = load_config
self.multimodal_config = multimodal_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
...@@ -64,7 +62,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -64,7 +62,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
device_config, device_config,
cache_config, cache_config,
load_config, load_config,
multimodal_config,
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
def init_device(self) -> None: def init_device(self) -> None:
......
...@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario( ...@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
raise NotImplementedError( raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
if enc_dec_mr.multimodal_config is not None: if enc_dec_mr.model_config.multimodal_config is not None:
raise NotImplementedError( raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM']) STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
......
...@@ -7,8 +7,8 @@ import torch ...@@ -7,8 +7,8 @@ import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
...@@ -46,7 +46,6 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -46,7 +46,6 @@ class Worker(LocalOrDistributedWorkerBase):
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None, speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
...@@ -73,7 +72,6 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -73,7 +72,6 @@ class Worker(LocalOrDistributedWorkerBase):
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
init_cached_hf_modules() init_cached_hf_modules()
self.multimodal_config = multimodal_config
self.observability_config = observability_config self.observability_config = observability_config
# Return hidden states from target model if the draft model is an # Return hidden states from target model if the draft model is an
...@@ -103,7 +101,6 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -103,7 +101,6 @@ class Worker(LocalOrDistributedWorkerBase):
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config, prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
observability_config=observability_config, observability_config=observability_config,
**speculative_args, **speculative_args,
) )
......
...@@ -125,6 +125,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -125,6 +125,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry \ self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config) .create_input_mapper(model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)
# Lazy initialization. # Lazy initialization.
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
...@@ -166,14 +167,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -166,14 +167,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# To exercise the worst scenario for GPU memory consumption, # To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number # the number of seqs (batch_size) is chosen to maximize the number
# of images processed. # of images processed.
model_config = self.model_config max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
mm_config = self.multimodal_config self.model_config)
input_registry = self.input_registry
mm_registry = self.mm_registry
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
if max_mm_tokens > 0: if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs, max_num_seqs = min(max_num_seqs,
...@@ -190,8 +185,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -190,8 +185,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
seq_data, dummy_multi_modal_data = input_registry \ seq_data, dummy_multi_modal_data = self.input_registry \
.dummy_data_for_profiling(model_config, seq_len, mm_registry) .dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment