Unverified commit 4c5f6321, authored by Cyrus Leung and committed by GitHub
Browse files

[Misc] Simplify max tokens in multimodal registry (#27500)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent b8535403
......@@ -355,7 +355,11 @@ class MultiModalProfiler(Generic[_I]):
mm_counts=mm_counts,
)
if max_tokens_per_item is not None:
return max_tokens_per_item
return {
modality: max_tokens
for modality, max_tokens in max_tokens_per_item.items()
if mm_counts.get(modality, 0) > 0
}
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
......@@ -375,5 +379,4 @@ class MultiModalProfiler(Generic[_I]):
This is important to take into account when profiling and
initializing the encoder cache size.
"""
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
......@@ -152,6 +152,7 @@ class MultiModalRegistry:
model_config: "ModelConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
profiler_limits: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
......@@ -164,40 +165,15 @@ class MultiModalRegistry:
profiler: MultiModalProfiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
profiler_limits = (
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
)
return profiler.get_mm_max_contiguous_tokens(
seq_len,
{modality: 1 for modality, limit in mm_limits.items() if limit > 0},
)
def get_max_tokens_per_item_by_nonzero_modality(
self,
model_config: "ModelConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
on underlying model configuration, excluding modalities that user
explicitly disabled via `limit_mm_per_prompt`.
Note:
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
)
return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in max_tokens_per_item.items()
if mm_limits[key] > 0
}
def get_mm_limits_per_prompt(
self,
model_config: "ModelConfig",
......@@ -369,7 +345,7 @@ class MultiModalRegistry:
"""
if not model_config.is_encoder_decoder:
return 0
max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config)
max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
if not max_tokens:
# TODO - this function assumes encoder-decoder models are
# multimodal. This will need to change when adding support for more
......
......@@ -264,8 +264,8 @@ def compute_encoder_budget(
from the input sequence.
"""
if mm_registry.supports_multimodal_inputs(model_config):
max_tokens_by_modality = (
mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config)
max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
model_config
)
return compute_mm_encoder_budget(
......
......@@ -42,10 +42,10 @@ class MultiModalBudget:
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_by_modality = (
mm_registry.get_max_tokens_per_item_by_nonzero_modality(
model_config, cache=cache
)
max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
profiler_limits=self.mm_limits,
)
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment