Unverified Commit 82551ad6 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Don't use cache during multi-modal profiling (#14336)

parent caac5c2e
...@@ -331,7 +331,9 @@ class InputRegistry: ...@@ -331,7 +331,9 @@ class InputRegistry:
if mm_registry.has_processor(model_config): if mm_registry.has_processor(model_config):
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
processor = mm_registry.create_processor(model_config, tokenizer) processor = mm_registry.create_processor(model_config,
tokenizer,
disable_cache=True)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data( dummy_data = profiler.get_dummy_data(
seq_len, is_encoder_data=is_encoder_data) seq_len, is_encoder_data=is_encoder_data)
......
...@@ -257,7 +257,9 @@ class MultiModalRegistry: ...@@ -257,7 +257,9 @@ class MultiModalRegistry:
""" """
if self.has_processor(model_config): if self.has_processor(model_config):
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
processor = self.create_processor(model_config, tokenizer) processor = self.create_processor(model_config,
tokenizer,
disable_cache=True)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config) mm_limits = self.get_mm_limits_per_prompt(model_config)
return processor.info.get_mm_max_tokens_per_item( return processor.info.get_mm_max_tokens_per_item(
...@@ -372,7 +374,9 @@ class MultiModalRegistry: ...@@ -372,7 +374,9 @@ class MultiModalRegistry:
""" """
if self.has_processor(model_config): if self.has_processor(model_config):
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
processor = self.create_processor(model_config, tokenizer) processor = self.create_processor(model_config,
tokenizer,
disable_cache=True)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits() return profiler.get_mm_limits()
...@@ -433,6 +437,8 @@ class MultiModalRegistry: ...@@ -433,6 +437,8 @@ class MultiModalRegistry:
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
*,
disable_cache: Optional[bool] = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]: ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
""" """
Create a multi-modal processor for a specific model and tokenizer. Create a multi-modal processor for a specific model and tokenizer.
...@@ -440,11 +446,13 @@ class MultiModalRegistry: ...@@ -440,11 +446,13 @@ class MultiModalRegistry:
See also: See also:
:ref:`mm-processing` :ref:`mm-processing`
""" """
if disable_cache is None:
disable_cache = model_config.disable_mm_preprocessor_cache
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls] factories = self._processor_factories[model_cls]
ctx = InputProcessingContext(model_config, tokenizer) ctx = InputProcessingContext(model_config, tokenizer)
cache = (None if model_config.disable_mm_preprocessor_cache else cache = None if disable_cache else self._processing_cache
self._processing_cache)
return factories.build_processor(ctx, cache=cache) return factories.build_processor(ctx, cache=cache)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment