Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e9438e0
Unverified
Commit
1e9438e0
authored
Jul 14, 2025
by
wangxiyuan
Committed by
GitHub
Jul 14, 2025
Browse files
[MISC] Move bind_kv_cache to worker module (#20900)
Signed-off-by:
wangxiyuan
<
wangxiyuan1007@gmail.com
>
parent
697ef765
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
57 additions
and
55 deletions
+57
-55
tests/v1/test_utils.py
tests/v1/test_utils.py
+1
-1
vllm/v1/utils.py
vllm/v1/utils.py
+0
-48
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+1
-2
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+2
-1
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+51
-1
No files found.
tests/v1/test_utils.py
View file @
1e9438e0
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
torch
import
torch
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.
worker.
utils
import
bind_kv_cache
def
test_bind_kv_cache
():
def
test_bind_kv_cache
():
...
...
vllm/v1/utils.py
View file @
1e9438e0
...
@@ -4,7 +4,6 @@ import argparse
...
@@ -4,7 +4,6 @@ import argparse
import
multiprocessing
import
multiprocessing
import
time
import
time
import
weakref
import
weakref
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
multiprocessing
import
connection
from
multiprocessing
import
connection
from
multiprocessing.process
import
BaseProcess
from
multiprocessing.process
import
BaseProcess
...
@@ -14,14 +13,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
...
@@ -14,14 +13,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
import
torch
import
torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
(
get_open_port
,
get_open_zmq_ipc_path
,
get_tcp_uri
,
from
vllm.utils
import
(
get_open_port
,
get_open_zmq_ipc_path
,
get_tcp_uri
,
kill_process_tree
)
kill_process_tree
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.layer
import
Attention
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.engine.utils
import
(
CoreEngineActorManager
,
from
vllm.v1.engine.utils
import
(
CoreEngineActorManager
,
CoreEngineProcManager
)
CoreEngineProcManager
)
...
@@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
...
@@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
kill_process_tree
(
pid
)
kill_process_tree
(
pid
)
def
bind_kv_cache
(
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
forward_context
:
dict
[
str
,
"Attention"
],
runner_kv_caches
:
list
[
torch
.
Tensor
],
)
->
None
:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert
len
(
runner_kv_caches
)
==
0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name
=
defaultdict
(
list
)
for
layer_name
in
kv_caches
:
index2name
[
extract_layer_index
(
layer_name
)].
append
(
layer_name
)
for
layer_index
in
sorted
(
index2name
.
keys
()):
layer_names
=
index2name
[
layer_index
]
if
len
(
layer_names
)
>
1
:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise
NotImplementedError
layer_name
=
layer_names
[
0
]
runner_kv_caches
.
append
(
kv_caches
[
layer_name
])
# Bind kv_caches to forward context
for
layer_name
,
kv_cache
in
kv_caches
.
items
():
# NOTE: Use list because of v0 PP virtual engine.
forward_context
[
layer_name
].
kv_cache
=
[
kv_cache
]
def
copy_slice
(
from_tensor
:
torch
.
Tensor
,
to_tensor
:
torch
.
Tensor
,
def
copy_slice
(
from_tensor
:
torch
.
Tensor
,
to_tensor
:
torch
.
Tensor
,
length
:
int
)
->
torch
.
Tensor
:
length
:
int
)
->
torch
.
Tensor
:
"""
"""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
1e9438e0
...
@@ -62,13 +62,13 @@ from vllm.v1.spec_decode.eagle import EagleProposer
...
@@ -62,13 +62,13 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
..sample.logits_processor
import
LogitsProcessorManager
from
..sample.logits_processor
import
LogitsProcessorManager
from
.utils
import
(
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
from
.utils
import
(
bind_kv_cache
,
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
1e9438e0
...
@@ -42,11 +42,10 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
...
@@ -42,11 +42,10 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
LogprobsTensors
,
ModelRunnerOutput
)
LogprobsTensors
,
ModelRunnerOutput
)
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.tpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.tpu_input_batch
import
CachedRequestState
,
InputBatch
from
.utils
import
(
initialize_kv_cache_for_kv_sharing
,
from
.utils
import
(
bind_kv_cache
,
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
)
sanity_check_mm_encoder_outputs
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
...
vllm/v1/worker/tpu_worker.py
View file @
1e9438e0
...
@@ -25,8 +25,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
...
@@ -25,8 +25,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
KVCacheSpec
)
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.utils
import
bind_kv_cache
,
report_usage_stats
from
vllm.v1.utils
import
report_usage_stats
from
vllm.v1.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.v1.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.v1.worker.utils
import
bind_kv_cache
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/worker/utils.py
View file @
1e9438e0
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Optional
import
torch
import
torch
from
vllm.model_executor.models.interfaces
import
MultiModalEmbeddings
from
vllm.model_executor.models.interfaces
import
MultiModalEmbeddings
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.v1.kv_cache_interface
import
KVCacheGroupSpec
from
vllm.v1.kv_cache_interface
import
KVCacheGroupSpec
if
TYPE_CHECKING
:
from
vllm.attention.layer
import
Attention
def
sanity_check_mm_encoder_outputs
(
def
sanity_check_mm_encoder_outputs
(
mm_embeddings
:
MultiModalEmbeddings
,
mm_embeddings
:
MultiModalEmbeddings
,
...
@@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
...
@@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
kv_caches
[
layer_name
]
=
kv_caches
[
target_layer_name
]
kv_caches
[
layer_name
]
=
kv_caches
[
target_layer_name
]
group_idx
=
layer_to_kv_cache_group_idx
[
target_layer_name
]
group_idx
=
layer_to_kv_cache_group_idx
[
target_layer_name
]
kv_cache_groups
[
group_idx
].
layer_names
.
append
(
layer_name
)
kv_cache_groups
[
group_idx
].
layer_names
.
append
(
layer_name
)
def
bind_kv_cache
(
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
forward_context
:
dict
[
str
,
"Attention"
],
runner_kv_caches
:
list
[
torch
.
Tensor
],
)
->
None
:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert
len
(
runner_kv_caches
)
==
0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name
=
defaultdict
(
list
)
for
layer_name
in
kv_caches
:
index2name
[
extract_layer_index
(
layer_name
)].
append
(
layer_name
)
for
layer_index
in
sorted
(
index2name
.
keys
()):
layer_names
=
index2name
[
layer_index
]
if
len
(
layer_names
)
>
1
:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise
NotImplementedError
layer_name
=
layer_names
[
0
]
runner_kv_caches
.
append
(
kv_caches
[
layer_name
])
# Bind kv_caches to forward context
for
layer_name
,
kv_cache
in
kv_caches
.
items
():
# NOTE: Use list because of v0 PP virtual engine.
forward_context
[
layer_name
].
kv_cache
=
[
kv_cache
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment