Unverified Commit 49252cf5 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[MM] Allow skipping memory profiling for multimodal models. (#22950)


Signed-off-by: default avatarRoger Wang <hey@rogerw.me>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 3e6dd400
...@@ -388,6 +388,10 @@ class ModelConfig: ...@@ -388,6 +388,10 @@ class ModelConfig:
interleave_mm_strings: bool = False interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using """Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string. Defaults to False.""" --chat-template-content-format=string. Defaults to False."""
skip_mm_profiling: bool = False
"""When enabled, skips multimodal memory profiling and only profiles with
language backbone model during engine initialization.
"""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities. """Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set For example, to set num_frames for video, set
...@@ -837,7 +841,8 @@ class ModelConfig: ...@@ -837,7 +841,8 @@ class ModelConfig:
media_io_kwargs=self.media_io_kwargs, media_io_kwargs=self.media_io_kwargs,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb, mm_processor_cache_gb=self.mm_processor_cache_gb,
interleave_mm_strings=self.interleave_mm_strings) interleave_mm_strings=self.interleave_mm_strings,
skip_mm_profiling=self.skip_mm_profiling)
return None return None
...@@ -2511,6 +2516,16 @@ class MultiModalConfig: ...@@ -2511,6 +2516,16 @@ class MultiModalConfig:
Enable fully interleaved support for multimodal prompts. Enable fully interleaved support for multimodal prompts.
""" """
skip_mm_profiling: bool = False
"""
When enabled, skips multimodal memory profiling and only profiles with
language backbone model during engine initialization.
This reduces engine startup time but shifts the responsibility to users for
estimating the peak memory usage of the activation of multimodal encoder and
embedding cache.
"""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
......
...@@ -350,6 +350,7 @@ class EngineArgs: ...@@ -350,6 +350,7 @@ class EngineArgs:
MultiModalConfig.mm_processor_kwargs MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
# LoRA fields # LoRA fields
enable_lora: bool = False enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled enable_lora_bias: bool = LoRAConfig.bias_enabled
...@@ -716,6 +717,8 @@ class EngineArgs: ...@@ -716,6 +717,8 @@ class EngineArgs:
multimodal_group.add_argument( multimodal_group.add_argument(
"--interleave-mm-strings", "--interleave-mm-strings",
**multimodal_kwargs["interleave_mm_strings"]) **multimodal_kwargs["interleave_mm_strings"])
multimodal_group.add_argument("--skip-mm-profiling",
**multimodal_kwargs["skip_mm_profiling"])
# LoRA related configs # LoRA related configs
lora_kwargs = get_kwargs(LoRAConfig) lora_kwargs = get_kwargs(LoRAConfig)
...@@ -918,6 +921,7 @@ class EngineArgs: ...@@ -918,6 +921,7 @@ class EngineArgs:
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
interleave_mm_strings=self.interleave_mm_strings, interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs, media_io_kwargs=self.media_io_kwargs,
skip_mm_profiling=self.skip_mm_profiling,
use_async_output_proc=not self.disable_async_output_proc, use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format, config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
......
...@@ -2479,50 +2479,56 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2479,50 +2479,56 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def profile_run(self) -> None: def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
if self.supports_mm_inputs: if self.supports_mm_inputs:
mm_budget = self.mm_budget if self.model_config.multimodal_config.skip_mm_profiling:
assert mm_budget is not None
# TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
(
dummy_modality,
max_tokens,
) = mm_budget.get_modality_with_max_tokens()
(
max_mm_items_per_prompt,
max_mm_items_per_batch,
) = mm_budget.get_max_items(dummy_modality, max_tokens)
logger.info( logger.info(
"Encoder cache will be initialized with a budget of " "Skipping memory profiling for multimodal encoder and "
"%s tokens, and profiled with %s %s items of the maximum " "encoder cache.")
"feature size.", else:
encoder_budget, mm_budget = self.mm_budget
max_mm_items_per_batch, assert mm_budget is not None
dummy_modality,
) # TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
(
dummy_modality,
max_tokens,
) = mm_budget.get_modality_with_max_tokens()
(
max_mm_items_per_prompt,
max_mm_items_per_batch,
) = mm_budget.get_max_items(dummy_modality, max_tokens)
logger.info(
"Encoder cache will be initialized with a budget of "
"%s tokens, and profiled with %s %s items of the "
"maximum feature size.",
encoder_budget,
max_mm_items_per_batch,
dummy_modality,
)
# Create dummy batch of multimodal inputs. # Create dummy batch of multimodal inputs.
batched_dummy_mm_inputs = self._get_mm_dummy_batch( batched_dummy_mm_inputs = self._get_mm_dummy_batch(
dummy_modality, dummy_modality,
max_mm_items_per_batch, max_mm_items_per_batch,
) )
# Run multimodal encoder. # Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings( dummy_encoder_outputs = \
**batched_dummy_mm_inputs) self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
sanity_check_mm_encoder_outputs( sanity_check_mm_encoder_outputs(
dummy_encoder_outputs, dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch, expected_num_items=max_mm_items_per_batch,
) )
# Cache the dummy encoder outputs. # Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict( self.encoder_cache["tmp"] = dict(
enumerate(dummy_encoder_outputs)) enumerate(dummy_encoder_outputs))
# Add `is_profile` here to pre-allocate communication buffers # Add `is_profile` here to pre-allocate communication buffers
hidden_states, last_hidden_states \ hidden_states, last_hidden_states \
......
...@@ -1529,60 +1529,66 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1529,60 +1529,66 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> None: ) -> None:
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
if self.supports_mm_inputs: if self.supports_mm_inputs:
mm_budget = self.mm_budget if self.model_config.multimodal_config.skip_mm_profiling:
assert mm_budget is not None
# TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
(
dummy_modality,
max_tokens,
) = mm_budget.get_modality_with_max_tokens()
(
max_mm_items_per_prompt,
max_mm_items_per_batch,
) = mm_budget.get_max_items(dummy_modality, max_tokens)
logger.info( logger.info(
"Encoder cache will be initialized with a budget of " "Skipping memory profiling for multimodal encoder and "
"%s tokens, and profiled with %s %s items of the maximum " "encoder cache.")
"feature size.", else:
encoder_budget, mm_budget = self.mm_budget
max_mm_items_per_batch, assert mm_budget is not None
dummy_modality,
) # TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# Create dummy batch of multimodal inputs. # NOTE: Currently model is profiled with a single non-text
batched_dummy_mm_inputs = self._get_mm_dummy_batch( # modality with the max possible input tokens even when
dummy_modality, # it supports multiple.
max_mm_items_per_batch, (
) dummy_modality,
max_tokens,
) = mm_budget.get_modality_with_max_tokens()
(
max_mm_items_per_prompt,
max_mm_items_per_batch,
) = mm_budget.get_max_items(dummy_modality, max_tokens)
logger.info(
"Encoder cache will be initialized with a budget of "
"%s tokens, and profiled with %s %s items of the "
"maximum feature size.",
encoder_budget,
max_mm_items_per_batch,
dummy_modality,
)
# Run multimodal encoder. # Create dummy batch of multimodal inputs.
# Isolate encoder graph from post-processing to minimize batched_dummy_mm_inputs = self._get_mm_dummy_batch(
# impact of recompilation until it's fixed. dummy_modality,
start = time.perf_counter() max_mm_items_per_batch,
xm.mark_step() )
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
xm.mark_step()
xm.wait_device_ops()
end = time.perf_counter()
logger.info(
"Multimodal Encoder profiling finished in in %.2f [secs].",
end - start)
sanity_check_mm_encoder_outputs( # Run multimodal encoder.
dummy_encoder_outputs, # Isolate encoder graph from post-processing to minimize
expected_num_items=max_mm_items_per_batch, # impact of recompilation until it's fixed.
) start = time.perf_counter()
xm.mark_step()
dummy_encoder_outputs = \
self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
xm.mark_step()
xm.wait_device_ops()
end = time.perf_counter()
logger.info(
"Multimodal Encoder profiling finished in %.2f [secs].",
end - start)
sanity_check_mm_encoder_outputs(
dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch,
)
# Cache the dummy encoder outputs. # Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict( self.encoder_cache["tmp"] = dict(
enumerate(dummy_encoder_outputs)) enumerate(dummy_encoder_outputs))
# Trigger compilation for general shape. # Trigger compilation for general shape.
self._dummy_run(num_tokens, self.num_reqs_max_model_len, self._dummy_run(num_tokens, self.num_reqs_max_model_len,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment