Unverified commit 70755e81, authored by Roger Wang and committed by GitHub
Browse files

[V1][Core] Autotune encoder cache budget (#11895)


Signed-off-by: Roger Wang <ywang@roblox.com>
parent edce722e
...@@ -1387,13 +1387,15 @@ class SchedulerConfig: ...@@ -1387,13 +1387,15 @@ class SchedulerConfig:
is_multimodal_model: bool = False is_multimodal_model: bool = False
# FIXME(woosuk & ywang96): Below are placeholder values. We need to # NOTE: The following multimodal encoder budget will be initialized to
# calculate the actual values from the configurations. # max_num_batched_tokens and overridden in case max multimodal embedding
# Multimodal encoder run compute budget, only used in V1 # size is larger.
max_num_encoder_input_tokens = 16384 # TODO (ywang96): Make these configurable.
# Multimodal encoder compute budget, only used in V1
max_num_encoder_input_tokens: int = field(default=None) # type: ignore
# Multimodal encoder cache size, only used in V1 # Multimodal encoder cache size, only used in V1
encoder_cache_size = 16384 encoder_cache_size: int = field(default=None) # type: ignore
# Whether to perform preemption by swapping or # Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows: # recomputation. If not specified, we determine the mode as follows:
...@@ -1467,6 +1469,9 @@ class SchedulerConfig: ...@@ -1467,6 +1469,9 @@ class SchedulerConfig:
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
) )
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens
if self.enable_chunked_prefill: if self.enable_chunked_prefill:
logger.info( logger.info(
"Chunked prefill is enabled with max_num_batched_tokens=%d.", "Chunked prefill is enabled with max_num_batched_tokens=%d.",
......
...@@ -252,11 +252,8 @@ class MultiModalRegistry: ...@@ -252,11 +252,8 @@ class MultiModalRegistry:
model_config: "ModelConfig", model_config: "ModelConfig",
) -> Mapping[str, int]: ) -> Mapping[str, int]:
""" """
Get the maximum number of tokens per data item from each modality Get the maximum number of tokens per data item from each modality based
for profiling the memory usage of a model. on underlying model configuration.
Note:
This is currently directly used only in V1.
""" """
if self.has_processor(model_config): if self.has_processor(model_config):
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
...@@ -272,6 +269,28 @@ class MultiModalRegistry: ...@@ -272,6 +269,28 @@ class MultiModalRegistry:
for key, plugin in self._plugins.items() for key, plugin in self._plugins.items()
} }
def get_max_tokens_per_item_by_nonzero_modality(
    self,
    model_config: "ModelConfig",
) -> Mapping[str, int]:
    """
    Return the per-item maximum token count for every modality whose
    per-prompt limit is greater than zero.

    The per-modality maxima are taken from
    `get_max_tokens_per_item_by_modality`. Any modality whose limit for
    this model config is not positive (i.e. one the user explicitly
    disabled via `limit_mm_per_prompt`) is left out of the result.

    Note:
        This is currently directly used only in V1 for profiling the
        memory usage of a model.
    """
    # Per-modality item limits recorded for this model config.
    mm_limits = self._limits_by_model[model_config]
    max_tokens_by_modality = self.get_max_tokens_per_item_by_modality(
        model_config)

    enabled_modalities = {}
    for modality, max_tokens in max_tokens_by_modality.items():
        # Keep only modalities that may still appear in a prompt.
        if mm_limits[modality] > 0:
            enabled_modalities[modality] = max_tokens
    return enabled_modalities
def get_max_tokens_by_modality( def get_max_tokens_by_modality(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
......
from typing import Dict, List, Set, Tuple from typing import TYPE_CHECKING, Dict, List, Set, Tuple
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.v1.request import Request from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig
logger = init_logger(__name__)
class EncoderCacheManager: class EncoderCacheManager:
...@@ -46,3 +53,72 @@ class EncoderCacheManager: ...@@ -46,3 +53,72 @@ class EncoderCacheManager:
freed = self.freed freed = self.freed
self.freed = [] self.freed = []
return freed return freed
def compute_encoder_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
    """Derive the encoder compute budget and encoder cache size.

    Models that are not multimodal get a zero budget for both values;
    multimodal models are delegated to
    `_compute_encoder_budget_multimodal`.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.

    Returns:
        A two-element tuple:
        - Compute budget for encoder execution, in unit of number of
          tokens in the input sequence.
        - Space budget for encoder cache size, in unit of number of
          tokens in the input sequence.
    """
    if model_config.is_multimodal_model:
        # TODO: handle encoder-decoder models once we support them.
        compute_budget, cache_size = _compute_encoder_budget_multimodal(
            model_config, scheduler_config)
        return compute_budget, cache_size

    # No encoder to budget for.
    return 0, 0
def _compute_encoder_budget_multimodal(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
    """Derive the encoder compute budget and cache size for a multimodal
    model.

    Each budget is the larger of the scheduler's configured value and the
    largest per-item token count among the modalities that have not been
    disabled. If every modality has been disabled, a warning is logged and
    both budgets are zero.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.

    Returns:
        A two-element tuple:
        - Compute budget for encoder execution, in unit of number of
          tokens in the input sequence.
        - Space budget for encoder cache size, in unit of number of
          tokens in the input sequence.
    """
    max_tokens_by_modality = (
        MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
            model_config))

    if not max_tokens_by_modality:
        logger.warning(
            "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
            "not be initialized.")
        return 0, 0

    # Only the largest single item matters here; which modality it belongs
    # to is irrelevant for sizing the budgets.
    largest_item_tokens = max(max_tokens_by_modality.values())

    compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
                         largest_item_tokens)
    cache_size = max(scheduler_config.encoder_cache_size,
                     largest_item_tokens)
    return compute_budget, cache_size
...@@ -3,10 +3,11 @@ from dataclasses import dataclass ...@@ -3,10 +3,11 @@ from dataclasses import dataclass
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union) Tuple, Union)
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
...@@ -25,6 +26,7 @@ class Scheduler: ...@@ -25,6 +26,7 @@ class Scheduler:
def __init__( def __init__(
self, self,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
model_config: ModelConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
) -> None: ) -> None:
...@@ -69,16 +71,24 @@ class Scheduler: ...@@ -69,16 +71,24 @@ class Scheduler:
self.running_reqs_data: Dict[str, RunningRequestData] = {} self.running_reqs_data: Dict[str, RunningRequestData] = {}
# Encoder-related. # Encoder-related.
# Calculate encoder cache size if applicable
# NOTE: For now we use the same budget for both compute and space.
# This can be changed when we make encoder cache for embedding caching
# across requests.
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also # projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT). # has the Transformer architecture (e.g., ViT).
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens #noqa: E501 self.max_num_encoder_input_tokens = encoder_compute_budget
# NOTE(woosuk): For the models without encoder (e.g., text-only models), # NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of # the encoder cache will not be initialized because cache size is 0
# the cache size. This is because the memory space for the encoder cache # for these models.
# is preallocated in the profiling run.
self.encoder_cache_manager = EncoderCacheManager( self.encoder_cache_manager = EncoderCacheManager(
cache_size=self.scheduler_config.encoder_cache_size) cache_size=encoder_cache_size)
def schedule(self) -> "SchedulerOutput": def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm: # NOTE(woosuk) on the scheduling algorithm:
......
...@@ -54,9 +54,12 @@ class EngineCore: ...@@ -54,9 +54,12 @@ class EngineCore:
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
# Setup scheduler. # Setup scheduler.
self.scheduler = Scheduler(vllm_config.scheduler_config, self.scheduler = Scheduler(
vllm_config.cache_config, scheduler_config=vllm_config.scheduler_config,
vllm_config.lora_config) model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
)
self.mm_input_mapper_server = MMInputMapperServer( self.mm_input_mapper_server = MMInputMapperServer(
vllm_config.model_config) vllm_config.model_config)
......
...@@ -20,6 +20,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, ...@@ -20,6 +20,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
is_pin_memory_available) is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata) FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
...@@ -88,8 +89,12 @@ class GPUModelRunner: ...@@ -88,8 +89,12 @@ class GPUModelRunner:
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False self.mm_input_mapper_profiling.use_cache = False
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
self.encoder_cache_size = self.scheduler_config.encoder_cache_size model_config=model_config,
scheduler_config=scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# Lazy initialization # Lazy initialization
# self.model: nn.Module # Set after load_model # self.model: nn.Module # Set after load_model
...@@ -721,44 +726,30 @@ class GPUModelRunner: ...@@ -721,44 +726,30 @@ class GPUModelRunner:
] ]
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
if self.is_multimodal_model: # TODO: handle encoder-decoder models once we support them.
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
# Create dummy batch of multimodal inputs. and self.encoder_cache_size > 0):
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
# NOTE: Currently model is profiled with a single non-text # NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when # modality with the max possible input tokens even when
# it supports multiple. # it supports multiple.
max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501 max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
self.model_config) self.model_config)
dummy_data_modality, max_tokens_per_mm_item = max( dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1]) max_tokens_by_modality_dict.items(), key=lambda item: item[1])
# Check how many items of this modality can be supported by # Check how many items of this modality can be supported by
# the encoder cache budget. # the encoder budget.
encoder_cache_budget = min(self.max_num_encoder_input_tokens, encoder_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size) self.encoder_cache_size)
max_num_mm_items_encoder_budget = encoder_cache_budget // \
max_tokens_per_mm_item
# TODO: Allow users to set encoder_cache_budget in case this max_num_mm_items_encoder_budget = cdiv(encoder_budget,
# happens. max_tokens_per_mm_item)
assert max_num_mm_items_encoder_budget > 0, (
f"Encoder cache budget={encoder_cache_budget} is too small to "
f"support the maximum possible size of multimodal embeddings"
f"={max_tokens_per_mm_item}.")
# Check how many items of this modality can be supported by # Check how many items of this modality can be supported by
# the decoder budget. # the decoder budget.
max_mm_items_per_req = max( max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
self.mm_registry.get_mm_limits_per_prompt( self.model_config)[dummy_data_modality]
self.model_config).values())
# NOTE: We do not consider max_num_batched_tokens on purpose # NOTE: We do not consider max_num_batched_tokens on purpose
# because the multimodal embeddings can be generated in advance # because the multimodal embeddings can be generated in advance
...@@ -769,6 +760,19 @@ class GPUModelRunner: ...@@ -769,6 +760,19 @@ class GPUModelRunner:
max_num_mm_items = min(max_num_mm_items_encoder_budget, max_num_mm_items = min(max_num_mm_items_encoder_budget,
max_num_mm_items_decoder_budget) max_num_mm_items_decoder_budget)
logger.info(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size.",
encoder_budget, max_num_mm_items, dummy_data_modality)
# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
# Dummy data definition in V0 may contain multiple multimodal items # Dummy data definition in V0 may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we # (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1 # always replicate first item by max_num_mm_items times since in V1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment