Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c9280e63
Unverified
Commit
c9280e63
authored
Jun 12, 2025
by
jmswen
Committed by
GitHub
Jun 12, 2025
Browse files
[Bugfix] Respect num-gpu-blocks-override in v1 (#19503)
Signed-off-by:
Jon Swenson
<
jmswen@gmail.com
>
parent
af09b3f0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
0 deletions
+17
-0
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+16
-0
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+1
-0
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
c9280e63
...
@@ -900,3 +900,19 @@ def test_get_kv_cache_config():
...
@@ -900,3 +900,19 @@ def test_get_kv_cache_config():
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
get_kv_cache_config
(
vllm_config
,
kv_cache_specs_hybrid
,
get_kv_cache_config
(
vllm_config
,
kv_cache_specs_hybrid
,
mem_per_block_per_layer
*
2
*
32
)
mem_per_block_per_layer
*
2
*
32
)
# Test num_gpu_blocks_override
vllm_config
.
cache_config
.
num_gpu_blocks_override
=
16
kv_cache_config_override_blocks
=
get_kv_cache_config
(
vllm_config
,
kv_cache_specs_full
,
mem_per_block_per_layer
*
2
*
32
)
assert
kv_cache_config_override_blocks
==
KVCacheConfig
(
num_blocks
=
16
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
16
,
shared_by
=
[
"layer_1"
]),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
16
,
shared_by
=
[
"layer_2"
]),
],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer_1"
,
"layer_2"
],
new_kv_cache_spec
())
])
\ No newline at end of file
vllm/v1/core/kv_cache_utils.py
View file @
c9280e63
...
@@ -660,6 +660,7 @@ def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
...
@@ -660,6 +660,7 @@ def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
logger
.
info
(
logger
.
info
(
"Overriding num_gpu_blocks=%d with "
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d"
,
num_blocks
,
num_gpu_blocks_override
)
"num_gpu_blocks_override=%d"
,
num_blocks
,
num_gpu_blocks_override
)
num_blocks
=
num_gpu_blocks_override
return
num_blocks
return
num_blocks
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment