Unverified Commit 4c5f6321 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Simplify max tokens in multimodal registry (#27500)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent b8535403
...@@ -355,7 +355,11 @@ class MultiModalProfiler(Generic[_I]): ...@@ -355,7 +355,11 @@ class MultiModalProfiler(Generic[_I]):
mm_counts=mm_counts, mm_counts=mm_counts,
) )
if max_tokens_per_item is not None: if max_tokens_per_item is not None:
return max_tokens_per_item return {
modality: max_tokens
for modality, max_tokens in max_tokens_per_item.items()
if mm_counts.get(modality, 0) > 0
}
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only) return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
...@@ -375,5 +379,4 @@ class MultiModalProfiler(Generic[_I]): ...@@ -375,5 +379,4 @@ class MultiModalProfiler(Generic[_I]):
This is important to take into account when profiling and This is important to take into account when profiling and
initializing the encoder cache size. initializing the encoder cache size.
""" """
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False) return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
...@@ -152,6 +152,7 @@ class MultiModalRegistry: ...@@ -152,6 +152,7 @@ class MultiModalRegistry:
model_config: "ModelConfig", model_config: "ModelConfig",
*, *,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
profiler_limits: Mapping[str, int] | None = None,
) -> Mapping[str, int]: ) -> Mapping[str, int]:
""" """
Get the maximum number of tokens per data item from each modality based Get the maximum number of tokens per data item from each modality based
...@@ -164,40 +165,15 @@ class MultiModalRegistry: ...@@ -164,40 +165,15 @@ class MultiModalRegistry:
profiler: MultiModalProfiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) profiler_limits = (
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
)
return profiler.get_mm_max_contiguous_tokens( return profiler.get_mm_max_contiguous_tokens(
seq_len, seq_len,
{modality: 1 for modality, limit in mm_limits.items() if limit > 0}, {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
)
def get_max_tokens_per_item_by_nonzero_modality(
self,
model_config: "ModelConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
on underlying model configuration, excluding modalities that user
explicitly disabled via `limit_mm_per_prompt`.
Note:
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
) )
return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in max_tokens_per_item.items()
if mm_limits[key] > 0
}
def get_mm_limits_per_prompt( def get_mm_limits_per_prompt(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
...@@ -369,7 +345,7 @@ class MultiModalRegistry: ...@@ -369,7 +345,7 @@ class MultiModalRegistry:
""" """
if not model_config.is_encoder_decoder: if not model_config.is_encoder_decoder:
return 0 return 0
max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config) max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
if not max_tokens: if not max_tokens:
# TODO - this function assumes encoder-decoder models are # TODO - this function assumes encoder-decoder models are
# multimodal. This will need to change when adding support for more # multimodal. This will need to change when adding support for more
......
...@@ -264,8 +264,8 @@ def compute_encoder_budget( ...@@ -264,8 +264,8 @@ def compute_encoder_budget(
from the input sequence. from the input sequence.
""" """
if mm_registry.supports_multimodal_inputs(model_config): if mm_registry.supports_multimodal_inputs(model_config):
max_tokens_by_modality = ( max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config) model_config
) )
return compute_mm_encoder_budget( return compute_mm_encoder_budget(
......
...@@ -42,10 +42,10 @@ class MultiModalBudget: ...@@ -42,10 +42,10 @@ class MultiModalBudget:
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache) self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_by_modality = ( max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
mm_registry.get_max_tokens_per_item_by_nonzero_modality( model_config,
model_config, cache=cache cache=cache,
) profiler_limits=self.mm_limits,
) )
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment