Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e9438e0
Unverified
Commit
1e9438e0
authored
Jul 14, 2025
by
wangxiyuan
Committed by
GitHub
Jul 14, 2025
Browse files
[MISC] Move bind_kv_cache to worker module (#20900)
Signed-off-by:
wangxiyuan
<
wangxiyuan1007@gmail.com
>
parent
697ef765
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
57 additions
and
55 deletions
+57
-55
tests/v1/test_utils.py
tests/v1/test_utils.py
+1
-1
vllm/v1/utils.py
vllm/v1/utils.py
+0
-48
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+1
-2
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+2
-1
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+51
-1
No files found.
tests/v1/test_utils.py
View file @
1e9438e0
...
...
@@ -3,7 +3,7 @@
import
torch
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.
worker.
utils
import
bind_kv_cache
def
test_bind_kv_cache
():
...
...
vllm/v1/utils.py
View file @
1e9438e0
...
...
@@ -4,7 +4,6 @@ import argparse
import
multiprocessing
import
time
import
weakref
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
multiprocessing
import
connection
from
multiprocessing.process
import
BaseProcess
...
...
@@ -14,14 +13,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
(
get_open_port
,
get_open_zmq_ipc_path
,
get_tcp_uri
,
kill_process_tree
)
if
TYPE_CHECKING
:
from
vllm.attention.layer
import
Attention
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.engine.utils
import
(
CoreEngineActorManager
,
CoreEngineProcManager
)
...
...
@@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
kill_process_tree
(
pid
)
def
bind_kv_cache
(
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
forward_context
:
dict
[
str
,
"Attention"
],
runner_kv_caches
:
list
[
torch
.
Tensor
],
)
->
None
:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert
len
(
runner_kv_caches
)
==
0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name
=
defaultdict
(
list
)
for
layer_name
in
kv_caches
:
index2name
[
extract_layer_index
(
layer_name
)].
append
(
layer_name
)
for
layer_index
in
sorted
(
index2name
.
keys
()):
layer_names
=
index2name
[
layer_index
]
if
len
(
layer_names
)
>
1
:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise
NotImplementedError
layer_name
=
layer_names
[
0
]
runner_kv_caches
.
append
(
kv_caches
[
layer_name
])
# Bind kv_caches to forward context
for
layer_name
,
kv_cache
in
kv_caches
.
items
():
# NOTE: Use list because of v0 PP virtual engine.
forward_context
[
layer_name
].
kv_cache
=
[
kv_cache
]
def
copy_slice
(
from_tensor
:
torch
.
Tensor
,
to_tensor
:
torch
.
Tensor
,
length
:
int
)
->
torch
.
Tensor
:
"""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
1e9438e0
...
...
@@ -62,13 +62,13 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
..sample.logits_processor
import
LogitsProcessorManager
from
.utils
import
(
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
from
.utils
import
(
bind_kv_cache
,
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
1e9438e0
...
...
@@ -42,11 +42,10 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
LogprobsTensors
,
ModelRunnerOutput
)
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.tpu_input_batch
import
CachedRequestState
,
InputBatch
from
.utils
import
(
initialize_kv_cache_for_kv_sharing
,
from
.utils
import
(
bind_kv_cache
,
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
)
if
TYPE_CHECKING
:
...
...
vllm/v1/worker/tpu_worker.py
View file @
1e9438e0
...
...
@@ -25,8 +25,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.utils
import
bind_kv_cache
,
report_usage_stats
from
vllm.v1.utils
import
report_usage_stats
from
vllm.v1.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.v1.worker.utils
import
bind_kv_cache
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/worker/utils.py
View file @
1e9438e0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Optional
import
torch
from
vllm.model_executor.models.interfaces
import
MultiModalEmbeddings
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.v1.kv_cache_interface
import
KVCacheGroupSpec
if
TYPE_CHECKING
:
from
vllm.attention.layer
import
Attention
def
sanity_check_mm_encoder_outputs
(
mm_embeddings
:
MultiModalEmbeddings
,
...
...
@@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
kv_caches
[
layer_name
]
=
kv_caches
[
target_layer_name
]
group_idx
=
layer_to_kv_cache_group_idx
[
target_layer_name
]
kv_cache_groups
[
group_idx
].
layer_names
.
append
(
layer_name
)
def
bind_kv_cache
(
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
forward_context
:
dict
[
str
,
"Attention"
],
runner_kv_caches
:
list
[
torch
.
Tensor
],
)
->
None
:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert
len
(
runner_kv_caches
)
==
0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name
=
defaultdict
(
list
)
for
layer_name
in
kv_caches
:
index2name
[
extract_layer_index
(
layer_name
)].
append
(
layer_name
)
for
layer_index
in
sorted
(
index2name
.
keys
()):
layer_names
=
index2name
[
layer_index
]
if
len
(
layer_names
)
>
1
:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise
NotImplementedError
layer_name
=
layer_names
[
0
]
runner_kv_caches
.
append
(
kv_caches
[
layer_name
])
# Bind kv_caches to forward context
for
layer_name
,
kv_cache
in
kv_caches
.
items
():
# NOTE: Use list because of v0 PP virtual engine.
forward_context
[
layer_name
].
kv_cache
=
[
kv_cache
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment