"docs/vscode:/vscode.git/clone" did not exist on "df3dcdf49dccfa4914d825fa08b74de8ae050e1e"
Unverified Commit 13ac9cab authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Avoid direct access of global `mm_registry` in `compute_encoder_budget` (#15621)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 66aa4c0b
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request from vllm.v1.request import Request
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -67,6 +67,7 @@ class EncoderCacheManager: ...@@ -67,6 +67,7 @@ class EncoderCacheManager:
def compute_encoder_budget( def compute_encoder_budget(
model_config: "ModelConfig", model_config: "ModelConfig",
scheduler_config: "SchedulerConfig", scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry,
) -> tuple[int, int]: ) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler """Compute the encoder cache budget based on the model and scheduler
configurations. configurations.
...@@ -74,6 +75,7 @@ def compute_encoder_budget( ...@@ -74,6 +75,7 @@ def compute_encoder_budget(
Args: Args:
model_config: Model configuration. model_config: Model configuration.
scheduler_config: Scheduler configuration. scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns: Returns:
- Compute budget for encoder execution, in unit of number of tokens - Compute budget for encoder execution, in unit of number of tokens
...@@ -89,7 +91,11 @@ def compute_encoder_budget( ...@@ -89,7 +91,11 @@ def compute_encoder_budget(
( (
encoder_compute_budget, encoder_compute_budget,
encoder_cache_size, encoder_cache_size,
) = _compute_encoder_budget_multimodal(model_config, scheduler_config) ) = _compute_encoder_budget_multimodal(
model_config,
scheduler_config,
mm_registry,
)
return encoder_compute_budget, encoder_cache_size return encoder_compute_budget, encoder_cache_size
...@@ -97,6 +103,7 @@ def compute_encoder_budget( ...@@ -97,6 +103,7 @@ def compute_encoder_budget(
def _compute_encoder_budget_multimodal( def _compute_encoder_budget_multimodal(
model_config: "ModelConfig", model_config: "ModelConfig",
scheduler_config: "SchedulerConfig", scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry,
) -> tuple[int, int]: ) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler """Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model. configurations for a multimodal model.
...@@ -104,6 +111,7 @@ def _compute_encoder_budget_multimodal( ...@@ -104,6 +111,7 @@ def _compute_encoder_budget_multimodal(
Args: Args:
model_config: Model configuration. model_config: Model configuration.
scheduler_config: Scheduler configuration. scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns: Returns:
- Compute budget for encoder execution, in unit of number of tokens - Compute budget for encoder execution, in unit of number of tokens
...@@ -112,8 +120,8 @@ def _compute_encoder_budget_multimodal( ...@@ -112,8 +120,8 @@ def _compute_encoder_budget_multimodal(
in the input sequence. in the input sequence.
""" """
max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501 max_tokens_by_modality_dict = mm_registry \
model_config) .get_max_tokens_per_item_by_nonzero_modality(model_config)
if not max_tokens_by_modality_dict: if not max_tokens_by_modality_dict:
logger.warning( logger.warning(
......
...@@ -10,6 +10,7 @@ from typing import Optional, Union ...@@ -10,6 +10,7 @@ from typing import Optional, Union
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget) compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
...@@ -38,6 +39,7 @@ class Scheduler(SchedulerInterface): ...@@ -38,6 +39,7 @@ class Scheduler(SchedulerInterface):
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
log_stats: bool, log_stats: bool,
structured_output_manager: StructuredOutputManager, structured_output_manager: StructuredOutputManager,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
...@@ -93,6 +95,7 @@ class Scheduler(SchedulerInterface): ...@@ -93,6 +95,7 @@ class Scheduler(SchedulerInterface):
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config, model_config=model_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
mm_registry=mm_registry,
) )
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
......
...@@ -137,6 +137,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -137,6 +137,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config, model_config=model_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
mm_registry=self.mm_registry,
) )
self.max_num_encoder_input_tokens = encoder_compute_budget self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size self.encoder_cache_size = encoder_cache_size
...@@ -1439,9 +1440,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1439,9 +1440,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE: Currently model is profiled with a single non-text # NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when # modality with the max possible input tokens even when
# it supports multiple. # it supports multiple.
max_tokens_by_modality_dict = ( max_tokens_by_modality_dict = self.mm_registry \
MULTIMODAL_REGISTRY. .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
get_max_tokens_per_item_by_nonzero_modality(self.model_config))
dummy_data_modality, max_tokens_per_mm_item = max( dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1]) max_tokens_by_modality_dict.items(), key=lambda item: item[1])
......
...@@ -109,6 +109,7 @@ class TPUModelRunner: ...@@ -109,6 +109,7 @@ class TPUModelRunner:
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config, model_config=model_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
mm_registry=self.mm_registry,
) )
self.max_num_encoder_input_tokens = encoder_compute_budget self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size self.encoder_cache_size = encoder_cache_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment