"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "288cc6c234d0509e2caed574729b2829ed9921a1"
Unverified Commit a94eee44 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix mm_limits access for merged multi-modal processor (#12252)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f2e9f2a3
...@@ -106,7 +106,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -106,7 +106,7 @@ class MultiModalProfiler(Generic[_I]):
def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]: def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
return self.processor.dummy_inputs return self.processor.dummy_inputs
def _get_mm_limits(self) -> Mapping[str, int]: def get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.processing_info.ctx.get_mm_config() mm_config = self.processing_info.ctx.get_mm_config()
mm_limit_per_prompt = mm_config.limit_per_prompt mm_limit_per_prompt = mm_config.limit_per_prompt
...@@ -146,7 +146,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -146,7 +146,7 @@ class MultiModalProfiler(Generic[_I]):
# Avoid circular import # Avoid circular import
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
mm_counts = self._get_mm_limits() mm_counts = self.get_mm_limits()
info = self.processing_info info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len) mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)
......
...@@ -17,7 +17,7 @@ from .image import ImagePlugin ...@@ -17,7 +17,7 @@ from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache) ProcessingCache)
from .profiling import BaseDummyInputsBuilder from .profiling import BaseDummyInputsBuilder, MultiModalProfiler
from .utils import cached_get_tokenizer from .utils import cached_get_tokenizer
from .video import VideoPlugin from .video import VideoPlugin
...@@ -282,13 +282,13 @@ class MultiModalRegistry: ...@@ -282,13 +282,13 @@ class MultiModalRegistry:
This is currently directly used only in V1 for profiling the memory This is currently directly used only in V1 for profiling the memory
usage of a model. usage of a model.
""" """
limits_per_plugin = self._limits_by_model[model_config] mm_limits = self.get_mm_limits_per_prompt(model_config)
return { return {
key: max_tokens_per_mm_item key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items() self.get_max_tokens_per_item_by_modality(model_config).items()
if limits_per_plugin[key] > 0 if mm_limits[key] > 0
} }
def get_max_tokens_by_modality( def get_max_tokens_by_modality(
...@@ -304,10 +304,10 @@ class MultiModalRegistry: ...@@ -304,10 +304,10 @@ class MultiModalRegistry:
Note: Note:
This should be called after :meth:`init_mm_limits_per_prompt`. This should be called after :meth:`init_mm_limits_per_prompt`.
""" """
limits_per_plugin = self._limits_by_model[model_config] mm_limits = self.get_mm_limits_per_prompt(model_config)
return { return {
key: limits_per_plugin[key] * max_tokens_per_mm_item key: mm_limits[key] * max_tokens_per_mm_item
for key, max_tokens_per_mm_item in for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items() self.get_max_tokens_per_item_by_modality(model_config).items()
} }
...@@ -371,6 +371,15 @@ class MultiModalRegistry: ...@@ -371,6 +371,15 @@ class MultiModalRegistry:
Note: Note:
This should be called after :meth:`init_mm_limits_per_prompt`. This should be called after :meth:`init_mm_limits_per_prompt`.
""" """
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
processor = self.create_processor(model_config, tokenizer)
profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()
return self._limits_by_model[model_config] return self._limits_by_model[model_config]
def register_processor( def register_processor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment