Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
70755e81
Unverified
Commit
70755e81
authored
Jan 15, 2025
by
Roger Wang
Committed by
GitHub
Jan 15, 2025
Browse files
[V1][Core] Autotune encoder cache budget (#11895)
Signed-off-by:
Roger Wang
<
ywang@roblox.com
>
parent
edce722e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
167 additions
and
50 deletions
+167
-50
vllm/config.py
vllm/config.py
+10
-5
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+24
-5
vllm/v1/core/encoder_cache_manager.py
vllm/v1/core/encoder_cache_manager.py
+77
-1
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+18
-8
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+6
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+32
-28
No files found.
vllm/config.py
View file @
70755e81
...
...
@@ -1387,13 +1387,15 @@ class SchedulerConfig:
is_multimodal_model
:
bool
=
False
# FIXME(woosuk & ywang96): Below are placeholder values. We need to
# calculate the actual values from the configurations.
# Multimodal encoder run compute budget, only used in V1
max_num_encoder_input_tokens
=
16384
# NOTE: The following multimodal encoder budget will be initialized to
# max_num_batched_tokens and overridden in case max multimodal embedding
# size is larger.
# TODO (ywang96): Make these configurable.
# Multimodal encoder compute budget, only used in V1
max_num_encoder_input_tokens
:
int
=
field
(
default
=
None
)
# type: ignore
# Multimodal encoder cache size, only used in V1
encoder_cache_size
=
16384
encoder_cache_size
:
int
=
field
(
default
=
None
)
# type: ignore
# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
...
...
@@ -1467,6 +1469,9 @@ class SchedulerConfig:
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
,
)
self
.
max_num_encoder_input_tokens
=
self
.
max_num_batched_tokens
self
.
encoder_cache_size
=
self
.
max_num_batched_tokens
if
self
.
enable_chunked_prefill
:
logger
.
info
(
"Chunked prefill is enabled with max_num_batched_tokens=%d."
,
...
...
vllm/multimodal/registry.py
View file @
70755e81
...
...
@@ -252,11 +252,8 @@ class MultiModalRegistry:
model_config
:
"ModelConfig"
,
)
->
Mapping
[
str
,
int
]:
"""
Get the maximum number of tokens per data item from each modality
for profiling the memory usage of a model.
Note:
This is currently directly used only in V1.
Get the maximum number of tokens per data item from each modality based
on underlying model configuration.
"""
if
self
.
has_processor
(
model_config
):
tokenizer
=
cached_get_tokenizer
(
...
...
@@ -272,6 +269,28 @@ class MultiModalRegistry:
for
key
,
plugin
in
self
.
_plugins
.
items
()
}
def get_max_tokens_per_item_by_nonzero_modality(
    self,
    model_config: "ModelConfig",
) -> Mapping[str, int]:
    """
    Get the maximum number of tokens per data item from each modality based
    on underlying model configuration, excluding modalities that user
    explicitly disabled via `limit_mm_per_prompt`.

    Args:
        model_config: Model configuration; used as the key into
            `self._limits_by_model` and forwarded to
            `get_max_tokens_per_item_by_modality`.

    Returns:
        A mapping from modality key to its maximum number of tokens per
        item, restricted to modalities whose per-prompt limit is > 0.

    Note:
        This is currently directly used only in V1 for profiling the memory
        usage of a model.
    """
    limits_per_plugin = self._limits_by_model[model_config]
    max_tokens_by_modality = self.get_max_tokens_per_item_by_modality(
        model_config)

    # A limit of 0 means the modality is disabled, so it is left out.
    return {
        key: max_tokens_per_mm_item
        for key, max_tokens_per_mm_item in max_tokens_by_modality.items()
        if limits_per_plugin[key] > 0
    }
def
get_max_tokens_by_modality
(
self
,
model_config
:
"ModelConfig"
,
...
...
vllm/v1/core/encoder_cache_manager.py
View file @
70755e81
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Set
,
Tuple
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.v1.request
import
Request
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
SchedulerConfig
logger
=
init_logger
(
__name__
)
class
EncoderCacheManager
:
...
...
@@ -46,3 +53,72 @@ class EncoderCacheManager:
freed
=
self
.
freed
self
.
freed
=
[]
return
freed
def compute_encoder_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
            in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
            in the input sequence.
    """
    # Models that are not multimodal get a zero budget for both values.
    if not model_config.is_multimodal_model:
        return 0, 0

    # TODO: handle encoder-decoder models once we support them.
    return _compute_encoder_budget_multimodal(model_config, scheduler_config)
def _compute_encoder_budget_multimodal(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations for a multimodal model.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
            in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
            in the input sequence.
    """
    max_tokens_by_modality_dict = (
        MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
            model_config))

    if not max_tokens_by_modality_dict:
        logger.warning(
            "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
            "not be initialized.")
        return 0, 0

    # Only the largest per-item token count matters here; which modality it
    # belongs to is not used.
    max_tokens_per_mm_item = max(max_tokens_by_modality_dict.values())

    # Raise each configured budget to at least the largest single item.
    encoder_compute_budget = max(
        scheduler_config.max_num_encoder_input_tokens,
        max_tokens_per_mm_item)
    encoder_cache_size = max(
        scheduler_config.encoder_cache_size,
        max_tokens_per_mm_item)

    return encoder_compute_budget, encoder_cache_size
vllm/v1/core/scheduler.py
View file @
70755e81
...
...
@@ -3,10 +3,11 @@ from dataclasses import dataclass
from
typing
import
(
TYPE_CHECKING
,
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.encoder_cache_manager
import
EncoderCacheManager
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
compute_encoder_budget
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
from
vllm.v1.engine
import
EngineCoreOutput
,
EngineCoreOutputs
from
vllm.v1.metrics.stats
import
SchedulerStats
...
...
@@ -25,6 +26,7 @@ class Scheduler:
def
__init__
(
self
,
scheduler_config
:
SchedulerConfig
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
)
->
None
:
...
...
@@ -69,16 +71,24 @@ class Scheduler:
self
.
running_reqs_data
:
Dict
[
str
,
RunningRequestData
]
=
{}
# Encoder-related.
# Calculate encoder cache size if applicable
# NOTE: For now we use the same budget for both compute and space.
# This can be changed when we make encoder cache for embedding caching
# across requests.
encoder_compute_budget
,
encoder_cache_size
=
compute_encoder_budget
(
model_config
=
model_config
,
scheduler_config
=
scheduler_config
,
)
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
self
.
max_num_encoder_input_tokens
=
self
.
scheduler_config
.
max_num_encoder_input_tokens
#noqa: E501
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
self
.
max_num_encoder_input_tokens
=
encoder_compute_budget
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
# for these models.
self
.
encoder_cache_manager
=
EncoderCacheManager
(
cache_size
=
self
.
scheduler_config
.
encoder_cache_size
)
cache_size
=
encoder_cache_size
)
def
schedule
(
self
)
->
"SchedulerOutput"
:
# NOTE(woosuk) on the scheduling algorithm:
...
...
vllm/v1/engine/core.py
View file @
70755e81
...
...
@@ -54,9 +54,12 @@ class EngineCore:
vllm_config
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
# Setup scheduler.
self
.
scheduler
=
Scheduler
(
vllm_config
.
scheduler_config
,
vllm_config
.
cache_config
,
vllm_config
.
lora_config
)
self
.
scheduler
=
Scheduler
(
scheduler_config
=
vllm_config
.
scheduler_config
,
model_config
=
vllm_config
.
model_config
,
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
)
self
.
mm_input_mapper_server
=
MMInputMapperServer
(
vllm_config
.
model_config
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
70755e81
...
...
@@ -20,6 +20,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
FlashAttentionMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperClient
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
...
@@ -88,8 +89,12 @@ class GPUModelRunner:
self
.
mm_input_mapper_profiling
=
MMInputMapperClient
(
self
.
model_config
)
self
.
mm_input_mapper_profiling
.
use_cache
=
False
self
.
max_num_encoder_input_tokens
=
self
.
scheduler_config
.
max_num_encoder_input_tokens
# noqa: E501
self
.
encoder_cache_size
=
self
.
scheduler_config
.
encoder_cache_size
encoder_compute_budget
,
encoder_cache_size
=
compute_encoder_budget
(
model_config
=
model_config
,
scheduler_config
=
scheduler_config
,
)
self
.
max_num_encoder_input_tokens
=
encoder_compute_budget
self
.
encoder_cache_size
=
encoder_cache_size
# Lazy initialization
# self.model: nn.Module # Set after load_model
...
...
@@ -721,44 +726,30 @@ class GPUModelRunner:
]
# Profile with multimodal encoder & encoder cache.
if
self
.
is_multimodal_model
:
# Create dummy batch of multimodal inputs.
dummy_request_data
=
self
.
input_registry
.
dummy_data_for_profiling
(
model_config
=
self
.
model_config
,
seq_len
=
self
.
max_num_tokens
,
mm_registry
=
self
.
mm_registry
,
)
dummy_mm_data
=
dummy_request_data
.
multi_modal_data
# TODO: handle encoder-decoder models once we support them.
if
(
self
.
is_multimodal_model
and
self
.
max_num_encoder_input_tokens
>
0
and
self
.
encoder_cache_size
>
0
):
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
max_tokens_by_modality_dict
=
self
.
mm_registry
.
get_max_tokens_per_item_by_modality
(
# noqa: E501
max_tokens_by_modality_dict
=
MULTIMODAL_REGISTRY
.
get_max_tokens_per_item_by_
nonzero_
modality
(
# noqa: E501
self
.
model_config
)
dummy_data_modality
,
max_tokens_per_mm_item
=
max
(
max_tokens_by_modality_dict
.
items
(),
key
=
lambda
item
:
item
[
1
])
# Check how many items of this modality can be supported by
# the encoder cache budget.
encoder_cache_budget
=
min
(
self
.
max_num_encoder_input_tokens
,
self
.
encoder_cache_size
)
max_num_mm_items_encoder_budget
=
encoder_cache_budget
//
\
max_tokens_per_mm_item
# TODO: Allow users to set encoder_cache_budget in case this
# happens.
assert
max_num_mm_items_encoder_budget
>
0
,
(
f
"Encoder cache budget=
{
encoder_cache_budget
}
is too small to "
f
"support the maximum possible size of multimodal embeddings"
f
"=
{
max_tokens_per_mm_item
}
."
)
# the encoder budget.
encoder_budget
=
min
(
self
.
max_num_encoder_input_tokens
,
self
.
encoder_cache_size
)
max_num_mm_items_encoder_budget
=
cdiv
(
encoder_budget
,
max_tokens_per_mm_item
)
# Check how many items of this modality can be supported by
# the decoder budget.
max_mm_items_per_req
=
max
(
self
.
mm_registry
.
get_mm_limits_per_prompt
(
self
.
model_config
).
values
())
max_mm_items_per_req
=
self
.
mm_registry
.
get_mm_limits_per_prompt
(
self
.
model_config
)[
dummy_data_modality
]
# NOTE: We do not consider max_num_batched_tokens on purpose
# because the multimodal embeddings can be generated in advance
...
...
@@ -769,6 +760,19 @@ class GPUModelRunner:
max_num_mm_items
=
min
(
max_num_mm_items_encoder_budget
,
max_num_mm_items_decoder_budget
)
logger
.
info
(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size."
,
encoder_budget
,
max_num_mm_items
,
dummy_data_modality
)
# Create dummy batch of multimodal inputs.
dummy_request_data
=
self
.
input_registry
.
dummy_data_for_profiling
(
model_config
=
self
.
model_config
,
seq_len
=
self
.
max_num_tokens
,
mm_registry
=
self
.
mm_registry
,
)
dummy_mm_data
=
dummy_request_data
.
multi_modal_data
# Dummy data definition in V0 may contain multiple multimodal items
# (e.g., multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment