Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
65b2f405
Unverified
Commit
65b2f405
authored
Mar 10, 2026
by
Nick Hill
Committed by
GitHub
Mar 10, 2026
Browse files
[Core] Simplify core kv-cache blocks initialization logic (#36521)
Signed-off-by:
Nick Hill
<
nickhill123@gmail.com
>
parent
2a68464c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
28 additions
and
37 deletions
+28
-37
tests/models/test_initialization.py
tests/models/test_initialization.py
+8
-2
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+11
-18
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+9
-13
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+0
-4
No files found.
tests/models/test_initialization.py
View file @
65b2f405
...
@@ -88,9 +88,15 @@ def can_initialize(
...
@@ -88,9 +88,15 @@ def can_initialize(
[
10
*
GiB_bytes
],
[
10
*
GiB_bytes
],
)
)
scheduler_kv_cache_config
=
generate_scheduler_kv_cache_config
(
kv_cache_configs
)
scheduler_kv_cache_config
=
generate_scheduler_kv_cache_config
(
kv_cache_configs
)
vllm_config
.
cache_config
.
num_gpu_blocks
=
scheduler_kv_cache_config
.
num_blocks
kv_cache_groups
=
scheduler_kv_cache_config
.
kv_cache_groups
if
kv_cache_groups
:
vllm_config
.
cache_config
.
block_size
=
min
(
g
.
kv_cache_spec
.
block_size
for
g
in
kv_cache_groups
)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
vllm_config
.
validate_block_size
()
return
1
,
0
,
scheduler_kv_cache_config
return
scheduler_kv_cache_config
if
model_arch
==
"MiniMaxVL01ForConditionalGeneration"
:
if
model_arch
==
"MiniMaxVL01ForConditionalGeneration"
:
pytest
.
skip
(
pytest
.
skip
(
...
...
vllm/v1/engine/core.py
View file @
65b2f405
...
@@ -117,18 +117,7 @@ class EngineCore:
...
@@ -117,18 +117,7 @@ class EngineCore:
self
.
_eep_scale_up_before_kv_init
()
self
.
_eep_scale_up_before_kv_init
()
# Setup KV Caches and update CacheConfig after profiling.
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks
,
num_cpu_blocks
,
kv_cache_config
=
self
.
_initialize_kv_caches
(
kv_cache_config
=
self
.
_initialize_kv_caches
(
vllm_config
)
vllm_config
)
if
kv_cache_config
.
kv_cache_groups
:
vllm_config
.
cache_config
.
block_size
=
min
(
g
.
kv_cache_spec
.
block_size
for
g
in
kv_cache_config
.
kv_cache_groups
)
vllm_config
.
validate_block_size
()
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
vllm_config
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
num_gpu_blocks
,
num_cpu_blocks
))
self
.
structured_output_manager
=
StructuredOutputManager
(
vllm_config
)
self
.
structured_output_manager
=
StructuredOutputManager
(
vllm_config
)
# Setup scheduler.
# Setup scheduler.
...
@@ -233,9 +222,7 @@ class EngineCore:
...
@@ -233,9 +222,7 @@ class EngineCore:
enable_envs_cache
()
enable_envs_cache
()
@
instrument
(
span_name
=
"Prepare model"
)
@
instrument
(
span_name
=
"Prepare model"
)
def
_initialize_kv_caches
(
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheConfig
:
self
,
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
,
KVCacheConfig
]:
start
=
time
.
time
()
start
=
time
.
time
()
# Get all kv cache needed by the model
# Get all kv cache needed by the model
...
@@ -276,8 +263,14 @@ class EngineCore:
...
@@ -276,8 +263,14 @@ class EngineCore:
self
.
collective_rpc
(
"update_max_model_len"
,
args
=
(
max_model_len_after
,))
self
.
collective_rpc
(
"update_max_model_len"
,
args
=
(
max_model_len_after
,))
scheduler_kv_cache_config
=
generate_scheduler_kv_cache_config
(
kv_cache_configs
)
scheduler_kv_cache_config
=
generate_scheduler_kv_cache_config
(
kv_cache_configs
)
num_gpu_blocks
=
scheduler_kv_cache_config
.
num_blocks
vllm_config
.
cache_config
.
num_gpu_blocks
=
scheduler_kv_cache_config
.
num_blocks
num_cpu_blocks
=
0
kv_cache_groups
=
scheduler_kv_cache_config
.
kv_cache_groups
if
kv_cache_groups
:
vllm_config
.
cache_config
.
block_size
=
min
(
g
.
kv_cache_spec
.
block_size
for
g
in
kv_cache_groups
)
vllm_config
.
validate_block_size
()
# Initialize kv cache and warmup the execution
# Initialize kv cache and warmup the execution
self
.
model_executor
.
initialize_from_config
(
kv_cache_configs
)
self
.
model_executor
.
initialize_from_config
(
kv_cache_configs
)
...
@@ -288,7 +281,7 @@ class EngineCore:
...
@@ -288,7 +281,7 @@ class EngineCore:
elapsed
,
elapsed
,
scope
=
"local"
,
scope
=
"local"
,
)
)
return
num_gpu_blocks
,
num_cpu_blocks
,
scheduler_kv_cache_config
return
scheduler_kv_cache_config
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
return
self
.
model_executor
.
supported_tasks
return
self
.
model_executor
.
supported_tasks
...
...
vllm/v1/worker/gpu_worker.py
View file @
65b2f405
...
@@ -203,7 +203,9 @@ class Worker(WorkerBase):
...
@@ -203,7 +203,9 @@ class Worker(WorkerBase):
self
.
model_runner
.
init_fp8_kv_scales
()
self
.
model_runner
.
init_fp8_kv_scales
()
def
_maybe_get_memory_pool_context
(
self
,
tag
:
str
)
->
AbstractContextManager
:
def
_maybe_get_memory_pool_context
(
self
,
tag
:
str
)
->
AbstractContextManager
:
if
self
.
vllm_config
.
model_config
.
enable_sleep_mode
:
if
not
self
.
vllm_config
.
model_config
.
enable_sleep_mode
:
return
nullcontext
()
from
vllm.device_allocator.cumem
import
CuMemAllocator
from
vllm.device_allocator.cumem
import
CuMemAllocator
allocator
=
CuMemAllocator
.
get_instance
()
allocator
=
CuMemAllocator
.
get_instance
()
...
@@ -212,12 +214,6 @@ class Worker(WorkerBase):
...
@@ -212,12 +214,6 @@ class Worker(WorkerBase):
"Sleep mode can only be used for one instance per process."
"Sleep mode can only be used for one instance per process."
)
)
return
allocator
.
use_memory_pool
(
tag
=
tag
)
return
allocator
.
use_memory_pool
(
tag
=
tag
)
else
:
return
nullcontext
()
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
@
instrument
(
span_name
=
"Init device"
)
@
instrument
(
span_name
=
"Init device"
)
def
init_device
(
self
):
def
init_device
(
self
):
...
...
vllm/v1/worker/worker_base.py
View file @
65b2f405
...
@@ -104,10 +104,6 @@ class WorkerBase:
...
@@ -104,10 +104,6 @@ class WorkerBase:
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""Initialize the KV cache with the given size in blocks."""
raise
NotImplementedError
def
reset_mm_cache
(
self
)
->
None
:
def
reset_mm_cache
(
self
)
->
None
:
reset_fn
=
getattr
(
self
.
model_runner
,
"reset_mm_cache"
,
None
)
reset_fn
=
getattr
(
self
.
model_runner
,
"reset_mm_cache"
,
None
)
if
callable
(
reset_fn
):
if
callable
(
reset_fn
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment