Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4716377f
Unverified
Commit
4716377f
authored
Apr 09, 2025
by
rongfu.leng
Committed by
GitHub
Apr 08, 2025
Browse files
[Feature] Estimate max-model-len use available KV cache memory (#16168)
Signed-off-by:
rongfu.leng
<
rongfu.leng@daocloud.io
>
parent
4e9cf8c1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
106 additions
and
5 deletions
+106
-5
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+45
-1
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+61
-4
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
4716377f
...
...
@@ -3,14 +3,16 @@
import
pytest
import
torch
from
vllm.config
import
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
sha256
from
vllm.utils
import
GiB_bytes
,
sha256
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from
vllm.v1.core.kv_cache_utils
import
(
NONE_HASH
,
BlockHashType
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
hash_block_tokens
,
hash_request_tokens
,
...
...
@@ -426,3 +428,45 @@ def test_unify_kv_cache_configs():
]
with
pytest
.
raises
(
AssertionError
):
unify_kv_cache_configs
(
diff_kv_cache_config
)
@pytest.mark.parametrize(
    ("model_id", "max_model_len", "want_estimated_max_len"),
    [
        ("Qwen/Qwen1.5-7B", 16385, 16384),
        ("Qwen/Qwen1.5-7B", 16383, 16383),
    ],
)
def test_estimate_max_model_len(model_id, max_model_len,
                                want_estimated_max_len):
    """Check the estimate returned for an 8 GiB KV cache budget.

    Two cases are parametrized: a configured length of 16385 is expected to
    be estimated down to 16384, while a configured length of 16383 is
    expected to come back unchanged.
    """
    # Assemble the global config; only the model and scheduler parts are
    # specified explicitly here.
    vllm_config = VllmConfig(
        model_config=ModelConfig(
            model_id,
            task="generate",
            tokenizer=model_id,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
            max_model_len=max_model_len,
        ),
        scheduler_config=SchedulerConfig(max_num_batched_tokens=32768),
    )

    # One full-attention spec per layer, 32 layers in total, each with its
    # own spec instance.
    kv_cache_spec = {
        f"layer_{layer_idx}": FullAttentionSpec(
            block_size=16,
            num_kv_heads=32,
            head_size=128,
            dtype=torch.float16,
            use_mla=False,
        )
        for layer_idx in range(32)
    }

    # Per the original author's note, a model length of 16384 needs 8 GiB
    # with this layer layout, so that is the budget handed to the estimator.
    memory_budget = 8 * GiB_bytes
    estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
                                               memory_budget)
    assert estimated_max_len == want_estimated_max_len
vllm/v1/core/kv_cache_utils.py
View file @
4716377f
...
...
@@ -8,7 +8,7 @@ from typing import Any, Callable, NamedTuple, Optional
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
sha256
from
vllm.utils
import
GiB_bytes
,
sha256
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
...
...
@@ -459,6 +459,54 @@ def hash_request_tokens(hash_function: Any, block_size: int,
return
ret
def estimate_max_model_len(vllm_config: VllmConfig,
                           kv_cache_spec: dict[str, KVCacheSpec],
                           available_memory: int) -> int:
    """
    Estimates the maximum model length that can fit in the available memory
    using binary search.

    Each probe temporarily overwrites
    `vllm_config.model_config.max_model_len`, because every layer spec
    computes its memory usage from the config object. The originally
    configured value is always restored before this function returns, so the
    caller's config is left untouched.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The estimated maximum model length that can fit in the available
        memory, searched in the range [1, configured max_model_len]. Returns
        0 if even a model length of 1 does not fit.
    """
    # Remember the configured value: it is both the upper bound of the
    # search and the value to restore once probing is done.
    configured_max_len = vllm_config.model_config.max_model_len

    def fits_in_memory(model_len: int) -> bool:
        # The layer specs read max_model_len from the config, so the probe
        # length has to be written into it for this calculation.
        vllm_config.model_config.max_model_len = model_len
        memory_needed = sum(
            layer_spec.max_memory_usage_bytes(vllm_config)
            for layer_spec in kv_cache_spec.values())
        return memory_needed <= available_memory

    try:
        # If even the smallest model length doesn't fit, return 0
        if not fits_in_memory(1):
            return 0

        # Binary search for the largest model length that still fits.
        left, right = 1, configured_max_len
        result = 1
        while left <= right:
            mid = (left + right) // 2
            if fits_in_memory(mid):
                result = mid
                left = mid + 1
            else:
                right = mid - 1
        return result
    finally:
        # Undo the probing side effect on every exit path (including the
        # early return and any exception raised by a layer spec).
        vllm_config.model_config.max_model_len = configured_max_len
def
check_enough_kv_cache_memory
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
):
...
...
@@ -486,12 +534,21 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
needed_memory
+=
layer_spec
.
max_memory_usage_bytes
(
vllm_config
)
if
needed_memory
>
available_memory
:
# Estimate the maximum model length that can fit in the available memory
estimated_max_len
=
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
available_memory
)
estimated_msg
=
""
if
estimated_max_len
>
0
:
estimated_msg
=
" Based on the available memory,"
f
" the estimated maximum model length is
{
estimated_max_len
}
."
raise
ValueError
(
f
"To serve at least one request with the models's max seq len "
f
"(
{
max_model_len
}
), (
{
needed_memory
/
1024
/
1024
/
1024
:.
2
f
}
GiB KV "
f
"(
{
max_model_len
}
), (
{
needed_memory
/
GiB_bytes
:.
2
f
}
GiB KV "
f
"cache is needed, which is larger than the available KV cache "
f
"memory (
{
available_memory
/
1024
/
1024
/
1024
:.
2
f
}
GiB). Try "
f
"increasing `gpu_memory_utilization` or decreasing "
f
"memory (
{
available_memory
/
GiB_bytes
:.
2
f
}
GiB)."
f
"
{
estimated_msg
}
"
f
" Try increasing `gpu_memory_utilization` or decreasing "
f
"`max_model_len` when initializing the engine."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment