Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
93a00d7d
Unverified
Commit
93a00d7d
authored
Mar 21, 2025
by
Chen Zhang
Committed by
GitHub
Mar 21, 2025
Browse files
[v1] Refactor KVCacheConfig (#14079)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
61e8c183
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
320 additions
and
112 deletions
+320
-112
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+109
-1
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+95
-35
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+22
-9
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+5
-8
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+34
-18
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+29
-17
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+1
-1
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+23
-21
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+1
-1
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+1
-1
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
93a00d7d
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -8,7 +9,10 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock
,
PrefixCachingMetrics
,
generate_block_hash_extra_keys
,
hash_block_tokens
,
hash_request_tokens
)
hash_request_tokens
,
unify_kv_cache_configs
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
)
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
...
...
@@ -314,3 +318,107 @@ def test_metrics():
assert
metrics
.
aggregated_query_total
==
0
assert
metrics
.
aggregated_query_hit
==
0
assert
not
metrics
.
query_queue
def test_unify_kv_cache_configs():
    """
    Exercise unify_kv_cache_configs on three pairs of per-worker configs:
    identical groups, identical groups listed in a different order, and
    groups whose specs genuinely differ.
    """

    def make_spec(num_kv_heads=2):
        # Only the number of KV heads varies between the specs in this test.
        return FullAttentionSpec(block_size=16,
                                 num_kv_heads=num_kv_heads,
                                 head_size=64,
                                 dtype=torch.float32,
                                 use_mla=False)

    def layer1_group():
        return KVCacheGroupSpec(["layer1"], make_spec())

    def layer2_group(num_kv_heads=4):
        return KVCacheGroupSpec(["layer2"],
                                make_spec(num_kv_heads=num_kv_heads))

    def make_config(num_blocks, groups):
        # Every config in this test owns the same two 100-byte tensors.
        return KVCacheConfig(
            num_blocks=num_blocks,
            tensors={
                "layer1": KVCacheTensor(100),
                "layer2": KVCacheTensor(100),
            },
            kv_cache_groups=groups,
        )

    # Case 1: both workers report the same groups in the same order, so
    # num_blocks is lowered to the smaller value.
    same_kv_cache_config = [
        make_config(10, [layer1_group(), layer2_group()]),
        make_config(20, [layer1_group(), layer2_group()]),
    ]
    unify_kv_cache_configs(same_kv_cache_config)
    for unified in same_kv_cache_config:
        assert unified.num_blocks == 10

    # Case 2: the second worker lists the same groups in reverse order;
    # unification is still expected to succeed.
    need_sort_kv_cache_config = [
        make_config(10, [layer1_group(), layer2_group()]),
        make_config(20, [layer2_group(), layer1_group()]),
    ]
    unify_kv_cache_configs(need_sort_kv_cache_config)
    for unified in need_sort_kv_cache_config:
        assert unified.num_blocks == 10

    # Case 3: the spec of "layer2" differs between workers (4 vs 8 KV
    # heads), which must be rejected.
    diff_kv_cache_config = [
        make_config(10, [layer1_group(), layer2_group()]),
        make_config(20, [layer1_group(), layer2_group(num_kv_heads=8)]),
    ]
    with pytest.raises(AssertionError):
        unify_kv_cache_configs(diff_kv_cache_config)
vllm/v1/core/kv_cache_utils.py
View file @
93a00d7d
...
...
@@ -7,8 +7,8 @@ from typing import Any, NamedTuple, Optional
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.v1.kv_cache_interface
import
(
KVCacheConfig
,
KVCacheSpec
,
KVCacheTensor
)
from
vllm.v1.kv_cache_interface
import
(
KVCacheConfig
,
KVCache
Group
Spec
,
KVCacheSpec
,
KVCacheTensor
)
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
...
...
@@ -449,7 +449,7 @@ def hash_request_tokens(block_size: int,
def
check_enough_kv_cache_memory
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
KVCacheSpec
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
,
available_memory
:
int
):
"""
Checks whether `available_memory` is enough for the KV cache to hold at
...
...
@@ -457,7 +457,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_spec: The kv cache spec of
each attention layer in
the model
available_memory: Memory available for KV cache in bytes.
Raises:
...
...
@@ -484,12 +484,43 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
f
"`max_model_len` when initializing the engine."
)
def
is_kv_cache_type_uniform
(
kv_cache_spec
:
KVCacheSpec
)
->
bool
:
def create_kv_cache_group_specs(
        kv_cache_spec: dict[str, KVCacheSpec],
        grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
    """
    Build one KVCacheGroupSpec per group of layer names.

    Every layer inside a group is required to have a KVCacheSpec equal to
    that of the group's first layer; that spec becomes the group's spec.

    Args:
        kv_cache_spec:
            Mapping from layer name to the KVCacheSpec of that layer.
        grouped_layer_names:
            The KV cache groups, each given as the list of layer names that
            belong to it.
    Returns:
        The KVCacheGroupSpec objects, in the same order as
        `grouped_layer_names`.
    """
    group_specs = []
    for names_in_group in grouped_layer_names:
        # The first layer's spec is the reference for the whole group.
        shared_spec = kv_cache_spec[names_in_group[0]]
        for other_name in names_in_group[1:]:
            assert kv_cache_spec[other_name] == shared_spec, (
                "All layers in the same KV cache group must share the same "
                "KVCacheSpec.")
        group_specs.append(KVCacheGroupSpec(names_in_group, shared_spec))
    return group_specs
def
is_kv_cache_type_uniform
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
])
->
bool
:
"""
Whether all layers in the given KVCacheSpec have the same type of KV cache.
Args:
kv_cache_spec: The
KVC
ache
S
pec of the model
kv_cache_spec: The
kv c
ache
s
pec of
each attention layer in
the model
Returns:
True if all layers have the same type, False otherwise.
...
...
@@ -500,18 +531,16 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
def
_get_kv_cache_config_uniform_type
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
KVCacheSpec
,
available_memory
:
int
,
num_layers
:
int
)
->
KVCacheConfig
:
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
)
->
KVCacheConfig
:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_spec: The kv cache spec of
each attention layer in
the model
available_memory: Memory available for KV cache in bytes.
num_layers: The number of layers in the model.
Returns:
The generated KVCacheConfig
...
...
@@ -521,7 +550,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
assert
len
(
page_sizes
)
==
1
page_size
=
page_sizes
.
pop
()
num_blocks
=
int
(
available_memory
//
page_size
//
num_layers
)
num_blocks
=
int
(
available_memory
//
page_size
//
len
(
kv_cache_spec
)
)
num_blocks
=
max
(
num_blocks
,
0
)
if
vllm_config
.
cache_config
.
num_gpu_blocks_override
is
not
None
:
...
...
@@ -541,6 +570,9 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
max_model_len_str
,
max_concurrency
)
per_layer_size
=
page_size
*
num_blocks
# All layers have the same KV cache spec, so we create one kv cache group
# for all layers.
grouped_layer_names
=
[
list
(
kv_cache_spec
.
keys
())]
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
...
...
@@ -548,41 +580,69 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
layer_name
:
KVCacheTensor
(
size
=
per_layer_size
)
for
layer_name
in
kv_cache_spec
},
groups
=
[[
layer_name
for
layer_name
in
kv_cache_spec
]],
kv_cache_spec
=
kv_cache_spec
)
kv_cache_groups
=
create_kv_cache_group_specs
(
kv_cache_spec
,
grouped_layer_names
),
)
return
kv_cache_config
def
get_kv_cache_config
s
(
vllm_config
:
VllmConfig
,
kv_cache_spec
s
:
list
[
KVCacheSpec
],
available_memory
:
int
)
->
list
[
KVCacheConfig
]
:
def
get_kv_cache_config
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
)
->
KVCacheConfig
:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.
Args:
vllm_config: The global VllmConfig
kv_cache_spec
s
: The kv cache spec
s
of the model
kv_cache_spec: The kv cache spec of
each attention layer in
the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
"""
# Use the max number of layers to conservatively determine
# the number of blocks.
num_layers
=
max
(
len
(
kv_cache_spec
)
for
kv_cache_spec
in
kv_cache_specs
)
kv_cache_configs
=
[]
for
kv_cache_spec
in
kv_cache_specs
:
check_enough_kv_cache_memory
(
vllm_config
,
kv_cache_spec
,
available_memory
)
check_enough_kv_cache_memory
(
vllm_config
,
kv_cache_spec
,
available_memory
)
if
is_kv_cache_type_uniform
(
kv_cache_spec
):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
kv_cache_configs
.
append
(
_get_kv_cache_config_uniform_type
(
vllm_config
,
kv_cache_spec
,
available_memory
,
num_layers
))
else
:
return
_get_kv_cache_config_uniform_type
(
vllm_config
,
kv_cache_spec
,
available_memory
)
raise
NotImplementedError
def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
    """
    Make the KV cache configurations for each worker consistent, so that all
    workers can be controlled by the same KVCacheManager.
    This function verifies that the layer groups of every worker are the same
    (same number of groups, and pairwise-equal KV cache specs once the groups
    are sorted), and changes the num_blocks of each worker to the smallest
    among all workers.

    Args:
        kv_cache_configs: The KV cache configurations for each worker. Will be
            in-place modified to make them consistent.
    Returns:
        The same list object that was passed in, after modification.
    Raises:
        AssertionError: If the workers do not have the same number of KV
            cache groups, or if the spec of any group differs between workers.
    """
    # Nothing to unify; also avoids calling min() on an empty sequence below.
    if not kv_cache_configs:
        return kv_cache_configs

    # Sort the kv cache groups by the type_id of their KV cache spec.
    # This can avoid the inconsistency caused by the order of groups.
    for kv_cache_config in kv_cache_configs:
        kv_cache_config.kv_cache_groups.sort(
            key=lambda x: x.kv_cache_spec.type_id)

    # Verify that the groups of each rank are the same.
    groups_rank_0 = kv_cache_configs[0].kv_cache_groups
    for kv_cache_config in kv_cache_configs[1:]:
        # zip() stops at the shorter input, so a rank with missing or extra
        # groups would otherwise slip through the pairwise check unnoticed.
        assert len(kv_cache_config.kv_cache_groups) == len(groups_rank_0), (
            "All workers must have the same number of KV cache groups.")
        for group_rank_0, group_rank_i in zip(
                groups_rank_0, kv_cache_config.kv_cache_groups):
            assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec

    # Change the num_blocks of each rank to the smallest among all ranks. We
    # do not need to shrink the tensor size because it is valid to only use the
    # first `num_blocks` blocks of the tensor.
    min_num_blocks = min(kv_cache_config.num_blocks
                         for kv_cache_config in kv_cache_configs)
    for kv_cache_config in kv_cache_configs:
        kv_cache_config.num_blocks = min_num_blocks

    return kv_cache_configs
vllm/v1/engine/core.py
View file @
93a00d7d
...
...
@@ -21,7 +21,8 @@ from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value
)
from
vllm.utils
import
(
get_exception_traceback
,
resolve_obj_by_qualname
,
zmq_socket_ctx
)
from
vllm.v1.core.kv_cache_utils
import
get_kv_cache_configs
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
unify_kv_cache_configs
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
...
...
@@ -120,15 +121,27 @@ class EngineCore:
# memory can be allocated for kv cache.
available_gpu_memory
=
self
.
model_executor
.
determine_available_memory
()
assert
len
(
kv_cache_specs
)
==
len
(
available_gpu_memory
)
# Get the kv cache tensor size
kv_cache_configs
=
get_kv_cache_configs
(
vllm_config
,
kv_cache_specs
,
available_gpu_memory
)
num_gpu_blocks_set
=
set
(
config
.
num_blocks
for
config
in
kv_cache_configs
)
assert
len
(
num_gpu_blocks_set
)
==
1
,
(
f
"num_gpu_blocks need to be the same across workers, "
f
"but they are different:
{
num_gpu_blocks_set
}
"
)
num_gpu_blocks
=
num_gpu_blocks_set
.
pop
()
kv_cache_configs
=
[
get_kv_cache_config
(
vllm_config
,
kv_cache_spec_one_worker
,
available_gpu_memory_one_worker
)
for
kv_cache_spec_one_worker
,
available_gpu_memory_one_worker
in
zip
(
kv_cache_specs
,
available_gpu_memory
)
]
# Since we use a shared centralized controller, we need the
# `kv_cache_config` to be consistent across all workers to make sure
# all the memory operators can be applied to all workers.
unify_kv_cache_configs
(
kv_cache_configs
)
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to get the number of blocks.
assert
all
([
cfg
.
num_blocks
==
kv_cache_configs
[
0
].
num_blocks
for
cfg
in
kv_cache_configs
])
num_gpu_blocks
=
kv_cache_configs
[
0
].
num_blocks
num_cpu_blocks
=
0
# Initialize kv cache and warmup the execution
...
...
vllm/v1/executor/abstract.py
View file @
93a00d7d
...
...
@@ -62,14 +62,11 @@ class Executor(ExecutorBase):
args
=
(
kv_cache_configs
,
))
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
def
determine_available_memory
(
self
)
->
int
:
# in bytes
def
determine_available_memory
(
self
)
->
list
[
int
]
:
# in bytes
output
=
self
.
collective_rpc
(
"determine_available_memory"
)
# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return
min
(
output
)
return
output
def
get_kv_cache_specs
(
self
)
->
list
[
KVCacheSpec
]:
def
get_kv_cache_specs
(
self
)
->
list
[
dict
[
str
,
KVCacheSpec
]
]
:
output
=
self
.
collective_rpc
(
"get_kv_cache_spec"
)
return
output
...
...
@@ -95,7 +92,7 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
class
ExecutorWithExternalLauncher
(
ExecutorWithExternalLauncherV0
,
Executor
):
def
determine_available_memory
(
self
)
->
int
:
# in bytes
def
determine_available_memory
(
self
)
->
list
[
int
]
:
# in bytes
# same as determine_num_available_blocks in v0,
# we need to get the min across all ranks.
memory
=
super
().
determine_available_memory
()
...
...
@@ -103,4 +100,4 @@ class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
cpu_group
=
get_world_group
().
cpu_group
memory_tensor
=
torch
.
tensor
([
memory
],
device
=
"cpu"
,
dtype
=
torch
.
int64
)
dist
.
all_reduce
(
memory_tensor
,
group
=
cpu_group
,
op
=
dist
.
ReduceOp
.
MIN
)
return
memory_tensor
.
item
()
return
[
memory_tensor
.
item
()
]
vllm/v1/kv_cache_interface.py
View file @
93a00d7d
...
...
@@ -11,7 +11,7 @@ logger = init_logger(__name__)
@
dataclass
class
KVCacheSpec
Base
:
class
KVCacheSpec
:
"""
A base class for specifying the KV cache format of one layer.
"""
...
...
@@ -55,7 +55,7 @@ class KVCacheSpecBase:
@
dataclass
class
FullAttentionSpec
(
KVCacheSpec
Base
):
class
FullAttentionSpec
(
KVCacheSpec
):
num_kv_heads
:
int
head_size
:
int
dtype
:
torch
.
dtype
...
...
@@ -76,9 +76,6 @@ class FullAttentionSpec(KVCacheSpecBase):
return
cdiv
(
num_tokens
,
self
.
block_size
)
*
self
.
page_size_bytes
KVCacheSpec
=
dict
[
str
,
KVCacheSpecBase
]
@
dataclass
class
KVCacheTensor
:
"""
...
...
@@ -89,6 +86,18 @@ class KVCacheTensor:
size
:
int
# The size of KV cache Tensor in bytes
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.

    A group pairs the names of its member layers with the single KVCacheSpec
    that describes all of them.
    """
    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of this manager layer, i.e. the one spec that applies
    # to every layer listed in `layer_names`
    kv_cache_spec: KVCacheSpec
@
dataclass
class
KVCacheConfig
:
"""
...
...
@@ -99,17 +108,24 @@ class KVCacheConfig:
"""layer_name -> how to initialize KV cache for that layer"""
tensors
:
dict
[
str
,
KVCacheTensor
]
"""
A list of kv-cache groups. Each group includes a set of layers with
the same kv-cache spec, and the total page_size of layers inside a group
is same across all groups (as the KVCacheManager only supports allocating
pages of the same size). For example:
1. A model only uses full attention: one group with all layers in the model.
2. (not implemented yet) A model with the same number of full attention
layers and sliding window attention layers: two groups, one for full
attention layers and one for sliding window attention layers.
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
"""
groups
:
list
[
list
[
str
]]
"""the KVCacheSpec of the model"""
kv_cache_spec
:
KVCacheSpec
The kv cache groups of the model.
The layers in the model are repeated with some pattern, e.g., a model
with 10 full attention layers and 20 sliding window attention layers can be
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
The KVCacheManager allocates different block tables for each of the 3 layers
in the pattern, and repeats each of them 10 times to generate the
block_table for the 30 layers in the model.
Therefore, we can group the layers in the model into 3 groups, each of which
contains 10 layers in the model.
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers.
2. (WIP) A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 groups, each of which represents 10 layers in the model.
"""
kv_cache_groups
:
list
[
KVCacheGroupSpec
]
vllm/v1/worker/gpu_model_runner.py
View file @
93a00d7d
...
...
@@ -1510,34 +1510,46 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if
len
(
kv_cache_config
.
groups
)
>
1
:
if
len
(
kv_cache_config
.
kv_cache_
groups
)
>
1
:
raise
NotImplementedError
(
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
layer_name
,
layer_spec
in
kv_cache_config
.
kv_cache_spec
.
items
():
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
:
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
for
layer_name
in
kv_cache_group
.
layer_names
:
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
assert
tensor_config
.
size
%
layer_spec
.
page_size_bytes
==
0
num_blocks
=
tensor_config
.
size
//
layer_spec
.
page_size_bytes
if
isinstance
(
layer_spec
,
FullAttentionSpec
):
assert
tensor_config
.
size
%
kv_cache_spec
.
page_size_bytes
==
0
num_blocks
=
tensor_config
.
size
//
kv_cache_spec
.
page_size_bytes
# `num_blocks` is the number of blocks the model runner can use.
# `kv_cache_config.num_blocks` is the number of blocks that
# KVCacheManager may allocate.
# Since different GPUs may have different number of layers and
# different memory capacities, `num_blocks` can be different on
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert
num_blocks
>=
kv_cache_config
.
num_blocks
if
isinstance
(
kv_cache_spec
,
FullAttentionSpec
):
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
layer
_spec
.
block_size
,
layer_spec
.
num_kv_heads
,
layer
_spec
.
head_size
)
dtype
=
layer
_spec
.
dtype
num_blocks
,
kv_cache
_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache
_spec
.
head_size
)
dtype
=
kv_cache
_spec
.
dtype
kv_caches
[
layer_name
]
=
torch
.
zeros
(
kv_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
)
else
:
raise
NotImplementedError
# TODO: add new branches when introducing more types of
# KV cache specs.
raise
ValueError
(
"Unknown KV cache spec type."
)
bind_kv_cache
(
kv_caches
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]
:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
...
...
@@ -1549,7 +1561,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
kv_cache_spec
:
KVCacheSpec
=
{}
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
if
isinstance
(
attn_module
,
FusedMoE
):
continue
...
...
vllm/v1/worker/gpu_worker.py
View file @
93a00d7d
...
...
@@ -185,7 +185,7 @@ class Worker(WorkerBase):
return
int
(
available_kv_cache_memory
)
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]
:
return
self
.
model_runner
.
get_kv_cache_spec
()
def
initialize_from_config
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
93a00d7d
...
...
@@ -309,7 +309,7 @@ class TPUModelRunner:
assert
self
.
model
is
not
None
return
self
.
model
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]
:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
...
...
@@ -320,7 +320,7 @@ class TPUModelRunner:
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
kv_cache_spec
:
KVCacheSpec
=
{}
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
...
...
@@ -837,22 +837,24 @@ class TPUModelRunner:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if
len
(
kv_cache_config
.
groups
)
>
1
:
if
len
(
kv_cache_config
.
kv_cache_
groups
)
>
1
:
raise
NotImplementedError
(
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
layer_name
,
layer_spec
in
kv_cache_config
.
kv_cache_spec
.
items
():
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
:
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
for
layer_name
in
kv_cache_group
.
layer_names
:
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
assert
tensor_config
.
size
%
layer
_spec
.
page_size_bytes
==
0
num_blocks
=
tensor_config
.
size
//
layer
_spec
.
page_size_bytes
if
isinstance
(
layer
_spec
,
FullAttentionSpec
):
assert
tensor_config
.
size
%
kv_cache
_spec
.
page_size_bytes
==
0
num_blocks
=
tensor_config
.
size
//
kv_cache
_spec
.
page_size_bytes
if
isinstance
(
kv_cache
_spec
,
FullAttentionSpec
):
kv_cache_shape
=
PallasAttentionBackend
.
get_kv_cache_shape
(
num_blocks
,
layer
_spec
.
block_size
,
layer_spec
.
num_kv_heads
,
layer
_spec
.
head_size
)
dtype
=
layer
_spec
.
dtype
num_blocks
,
kv_cache
_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache
_spec
.
head_size
)
dtype
=
kv_cache
_spec
.
dtype
tpu_k_cache
=
torch
.
zeros
(
kv_cache_shape
,
dtype
=
dtype
,
...
...
vllm/v1/worker/tpu_worker.py
View file @
93a00d7d
...
...
@@ -189,7 +189,7 @@ class TPUWorker:
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model_runner
.
get_model
()
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]
:
return
self
.
model_runner
.
get_kv_cache_spec
()
def
initialize_from_config
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
vllm/v1/worker/worker_base.py
View file @
93a00d7d
...
...
@@ -51,7 +51,7 @@ class WorkerBase(WorkerBaseV0):
self
.
device
:
Optional
[
torch
.
device
]
=
None
self
.
model_runner
:
Optional
[
nn
.
Module
]
=
None
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]
:
"""Get specifications for KV cache implementation."""
raise
NotImplementedError
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment