Commit 0c6e40bb (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Simplify code for MM budget (#23310)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 2e2000f3
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -188,35 +188,47 @@ def compute_encoder_budget( ...@@ -188,35 +188,47 @@ def compute_encoder_budget(
- Space budget for encoder cache size, in unit of number of tokens - Space budget for encoder cache size, in unit of number of tokens
in the input sequence. in the input sequence.
""" """
if mm_registry.supports_multimodal_inputs(model_config):
max_tokens_by_modality = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config)
if not mm_registry.supports_multimodal_inputs(model_config): return compute_mm_encoder_budget(
return 0, 0 scheduler_config,
max_tokens_by_modality,
)
# TODO: handle encoder-decoder models once we support them. return compute_text_encoder_budget(scheduler_config)
(
encoder_compute_budget,
encoder_cache_size,
) = _compute_encoder_budget_multimodal(
model_config,
scheduler_config,
mm_registry,
)
return encoder_compute_budget, encoder_cache_size
def compute_text_encoder_budget(
scheduler_config: "SchedulerConfig") -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a text-only model.
def _compute_encoder_budget_multimodal( Args:
model_config: "ModelConfig", scheduler_config: Scheduler configuration.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
"""
# Currently text-only encoder-decoder models are not supported
return 0, 0
def compute_mm_encoder_budget(
scheduler_config: "SchedulerConfig", scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry, max_tokens_by_modality: Mapping[str, int],
) -> tuple[int, int]: ) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler """Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model. configurations for a multimodal model.
Args: Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration. scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost. max_tokens_by_modality: The maximum number of tokens for each
non-text modality.
Returns: Returns:
- Compute budget for encoder execution, in unit of number of tokens - Compute budget for encoder execution, in unit of number of tokens
...@@ -225,18 +237,14 @@ def _compute_encoder_budget_multimodal( ...@@ -225,18 +237,14 @@ def _compute_encoder_budget_multimodal(
in the input sequence. in the input sequence.
""" """
max_tokens_by_modality_dict = mm_registry \ if not max_tokens_by_modality:
.get_max_tokens_per_item_by_nonzero_modality(model_config)
if not max_tokens_by_modality_dict:
logger.warning( logger.warning(
"All non-text modalities supported by the model have been " "All non-text modalities supported by the model have been "
"explicitly disabled via limit_mm_per_prompt. Encoder cache will " "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
"not be initialized.") "not be initialized.")
return 0, 0 return 0, 0
_, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), max_tokens_per_mm_item = max(max_tokens_by_modality.values())
key=lambda item: item[1])
if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item
> scheduler_config.max_num_batched_tokens): > scheduler_config.max_num_batched_tokens):
......
...@@ -341,10 +341,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -341,10 +341,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model_config, self.model_config,
self.scheduler_config, self.scheduler_config,
self.mm_registry, self.mm_registry,
max_model_len=self.max_model_len, ) if self.supports_mm_inputs else None)
max_num_reqs=self.max_num_reqs,
) if self.supports_mm_inputs \
else None)
self.reorder_batch_threshold: Optional[int] = None self.reorder_batch_threshold: Optional[int] = None
...@@ -669,7 +666,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -669,7 +666,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_budget = self.mm_budget mm_budget = self.mm_budget
assert mm_budget is not None assert mm_budget is not None
dummy_modality, _ = mm_budget.get_modality_with_max_tokens() dummy_modality = mm_budget.get_modality_with_max_tokens()
return self._get_mm_dummy_batch(dummy_modality, num_seqs) return self._get_mm_dummy_batch(dummy_modality, num_seqs)
...@@ -2595,14 +2592,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2595,14 +2592,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# NOTE: Currently model is profiled with a single non-text # NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when # modality with the max possible input tokens even when
# it supports multiple. # it supports multiple.
( dummy_modality = mm_budget.get_modality_with_max_tokens()
dummy_modality, max_mm_items_per_batch = mm_budget \
max_tokens, .max_items_per_batch_by_modality[dummy_modality]
) = mm_budget.get_modality_with_max_tokens()
(
max_mm_items_per_prompt,
max_mm_items_per_batch,
) = mm_budget.get_max_items(dummy_modality, max_tokens)
logger.info( logger.info(
"Encoder cache will be initialized with a budget of " "Encoder cache will be initialized with a budget of "
......
...@@ -292,8 +292,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -292,8 +292,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model_config, self.model_config,
self.scheduler_config, self.scheduler_config,
self.mm_registry, self.mm_registry,
max_model_len=self.max_model_len,
max_num_reqs=self.max_num_reqs,
) if self.supports_mm_inputs else None) ) if self.supports_mm_inputs else None)
if not self.use_spmd: if not self.use_spmd:
...@@ -1545,14 +1543,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1545,14 +1543,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# NOTE: Currently model is profiled with a single non-text # NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when # modality with the max possible input tokens even when
# it supports multiple. # it supports multiple.
( dummy_modality = mm_budget.get_modality_with_max_tokens()
dummy_modality, max_mm_items_per_batch = mm_budget \
max_tokens, .max_items_per_batch_by_modality[dummy_modality]
) = mm_budget.get_modality_with_max_tokens()
(
max_mm_items_per_prompt,
max_mm_items_per_batch,
) = mm_budget.get_max_items(dummy_modality, max_tokens)
logger.info( logger.info(
"Encoder cache will be initialized with a budget of " "Encoder cache will be initialized with a budget of "
......
...@@ -12,7 +12,7 @@ from vllm.model_executor.models.interfaces import MultiModalEmbeddings ...@@ -12,7 +12,7 @@ from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.registry import MultiModalRegistry from vllm.multimodal.registry import MultiModalRegistry
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec from vllm.v1.kv_cache_interface import KVCacheGroupSpec
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -27,9 +27,6 @@ class MultiModalBudget: ...@@ -27,9 +27,6 @@ class MultiModalBudget:
model_config: ModelConfig, model_config: ModelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
mm_registry: MultiModalRegistry, mm_registry: MultiModalRegistry,
*,
max_model_len: int,
max_num_reqs: int,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -37,25 +34,25 @@ class MultiModalBudget: ...@@ -37,25 +34,25 @@ class MultiModalBudget:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.mm_registry = mm_registry self.mm_registry = mm_registry
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( self.max_model_len = model_config.max_model_len
model_config=model_config, self.max_num_reqs = scheduler_config.max_num_seqs
scheduler_config=scheduler_config,
mm_registry=mm_registry, self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
max_tokens_by_modality = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config)
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
scheduler_config,
max_tokens_by_modality,
) )
self.max_num_encoder_input_tokens = encoder_compute_budget self.encoder_compute_budget = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size self.encoder_cache_size = encoder_cache_size
self.max_model_len = max_model_len
self.max_num_reqs = max_num_reqs
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
max_items_per_prompt_by_modality = dict[str, int]() max_items_per_prompt_by_modality = dict[str, int]()
max_items_per_batch_by_modality = dict[str, int]() max_items_per_batch_by_modality = dict[str, int]()
max_tokens_by_modality = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config)
for modality, max_tokens in max_tokens_by_modality.items(): for modality, max_tokens in max_tokens_by_modality.items():
( (
max_items_per_prompt, max_items_per_prompt,
...@@ -69,15 +66,14 @@ class MultiModalBudget: ...@@ -69,15 +66,14 @@ class MultiModalBudget:
self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
self.max_items_per_batch_by_modality = max_items_per_batch_by_modality self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
def get_modality_with_max_tokens(self) -> tuple[str, int]: def get_modality_with_max_tokens(self) -> str:
max_tokens_by_modality = self.max_tokens_by_modality max_tokens_by_modality = self.max_tokens_by_modality
modality, max_tokens = max(max_tokens_by_modality.items(), modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
key=lambda item: item[1])
return modality, max_tokens return modality
def get_encoder_budget(self) -> int: def get_encoder_budget(self) -> int:
return min(self.max_num_encoder_input_tokens, self.encoder_cache_size) return min(self.encoder_compute_budget, self.encoder_cache_size)
def get_max_items( def get_max_items(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment