Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
32ce3cf7
Unverified
Commit
32ce3cf7
authored
May 29, 2025
by
Nicolò Lucchesi
Committed by
GitHub
May 29, 2025
Browse files
[V1] Allocate kv_cache with stride order for V1 (#18775)
Signed-off-by:
nicklucche
<
nlucches@redhat.com
>
parent
d58f9c7f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
16 deletions
+81
-16
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+58
-13
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+23
-3
No files found.
tests/v1/worker/test_gpu_model_runner.py
View file @
32ce3cf7
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
random
import
pytest
import
pytest
from
vllm.attention
import
Attention
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VllmConfig
)
SchedulerConfig
,
VllmConfig
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
...
@@ -13,27 +16,30 @@ from vllm.v1.sample.metadata import SamplingMetadata
...
@@ -13,27 +16,30 @@ from vllm.v1.sample.metadata import SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
BLOCK_SIZE
=
16
NUM_BLOCKS
=
10
def
initialize_kv_cache
(
runner
:
GPUModelRunner
):
def
initialize_kv_cache
(
runner
:
GPUModelRunner
):
"""
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
"""
"""
attn_spec
=
FullAttentionSpec
(
block_size
=
BLOCK_SIZE
,
num_kv_heads
=
runner
.
model_config
.
get_num_kv_heads
(
runner
.
parallel_config
),
head_size
=
runner
.
model_config
.
get_head_size
(),
dtype
=
runner
.
kv_cache_dtype
,
use_mla
=
False
,
)
tensor_size
=
attn_spec
.
page_size_bytes
*
NUM_BLOCKS
kv_cache_config
=
KVCacheConfig
(
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
10
,
num_blocks
=
NUM_BLOCKS
,
tensors
=
{
tensors
=
{
"layer.0"
:
KVCacheTensor
(
size
=
1024
),
"layer.0"
:
KVCacheTensor
(
size
=
tensor_size
),
},
},
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
(
KVCacheGroupSpec
(
layer_names
=
[
"layer.0"
],
kv_cache_spec
=
attn_spec
)
layer_names
=
[
"layer.0"
],
kv_cache_spec
=
FullAttentionSpec
(
block_size
=
16
,
num_kv_heads
=
runner
.
model_config
.
get_num_kv_heads
(
runner
.
parallel_config
),
head_size
=
runner
.
model_config
.
get_head_size
(),
dtype
=
runner
.
kv_cache_dtype
,
use_mla
=
False
,
))
])
])
runner
.
kv_cache_config
=
kv_cache_config
runner
.
kv_cache_config
=
kv_cache_config
runner
.
input_batch
=
InputBatch
(
runner
.
input_batch
=
InputBatch
(
...
@@ -65,7 +71,7 @@ def model_runner():
...
@@ -65,7 +71,7 @@ def model_runner():
seed
=
42
,
seed
=
42
,
)
)
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
block_size
=
16
,
block_size
=
BLOCK_SIZE
,
gpu_memory_utilization
=
0.9
,
gpu_memory_utilization
=
0.9
,
swap_space
=
0
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
cache_dtype
=
"auto"
,
...
@@ -77,6 +83,10 @@ def model_runner():
...
@@ -77,6 +83,10 @@ def model_runner():
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
parallel_config
=
parallel_config
,
parallel_config
=
parallel_config
,
)
)
num_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
head_size
=
model_config
.
get_head_size
()
vllm_config
.
compilation_config
.
static_forward_context
[
"layer.0"
]
=
Attention
(
num_heads
,
head_size
,
0.1
)
device
=
"cuda"
device
=
"cuda"
runner
=
GPUModelRunner
(
vllm_config
,
device
)
runner
=
GPUModelRunner
(
vllm_config
,
device
)
...
@@ -321,3 +331,38 @@ def test_update_states_request_unscheduled(model_runner):
...
@@ -321,3 +331,38 @@ def test_update_states_request_unscheduled(model_runner):
assert
_is_req_added
(
model_runner
,
req_ids
[
1
])
assert
_is_req_added
(
model_runner
,
req_ids
[
1
])
assert
not
_is_req_scheduled
(
model_runner
,
req_ids
[
1
])
assert
not
_is_req_scheduled
(
model_runner
,
req_ids
[
1
])
def test_kv_cache_stride_order(monkeypatch, model_runner):
    """Check KV cache allocation under a backend-enforced stride order.

    Every attention backend currently registered on ``model_runner`` is
    patched so that ``get_kv_cache_stride_order`` returns a random
    permutation of the five KV cache dimensions. The KV cache is then
    re-created and the test verifies that:

    * the logical shape seen by callers is unchanged, and
    * the tensors are contiguous only when the permutation happens to be
      the identity (i.e. the default stride order).

    Args:
        monkeypatch: pytest fixture used to patch the backend classes.
        model_runner: fixture providing the runner under test, with its
            ``kv_cache_config`` already set.
    """
    n_heads = model_runner.model_config.get_num_kv_heads(
        model_runner.parallel_config)
    expected_kv_cache_shape = [
        2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
        model_runner.model_config.get_head_size()
    ]
    # TODO mla test
    # Keep both orders as tuples: a list never compares equal to a tuple,
    # which would make the identity-permutation branch below unreachable
    # and the test fail spuriously whenever the sample is the identity.
    default_stride = tuple(range(len(expected_kv_cache_shape)))
    # Permutation that gets you back to expected kv shape
    rnd_stride = tuple(random.sample(default_stride, len(default_stride)))

    def rnd_stride_order():
        return rnd_stride

    # Patch the attention backend class and re-trigger the KV cache creation.
    for attn_backend in model_runner.attn_backends:
        monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
                            rnd_stride_order)

    # Reset the per-runner backend state before re-initializing.
    model_runner.attn_backends = []
    model_runner.attn_metadata_builders = []
    model_runner.initialize_kv_cache(model_runner.kv_cache_config)

    # Shape is unchanged, but layout may differ
    kv_cache_shape = model_runner.kv_caches[0].shape
    assert list(kv_cache_shape) == expected_kv_cache_shape
    if default_stride == rnd_stride:
        assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
    else:
        # NOTE(review): assumes no permuted dimension has size 1; a
        # permutation that only moves size-1 dims would still yield a
        # contiguous tensor -- confirm against the fixture's model config.
        assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
vllm/v1/worker/gpu_model_runner.py
View file @
32ce3cf7
...
@@ -2033,9 +2033,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2033,9 +2033,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_blocks
,
kv_cache_spec
.
block_size
,
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
dtype
=
kv_cache_spec
.
dtype
kv_caches
[
layer_name
]
=
torch
.
zeros
(
kv_cache_shape
,
try
:
dtype
=
dtype
,
kv_cache_stride_order
=
self
.
attn_backends
[
device
=
self
.
device
)
i
].
get_kv_cache_stride_order
()
assert
len
(
kv_cache_stride_order
)
==
len
(
kv_cache_shape
)
except
(
AttributeError
,
NotImplementedError
):
kv_cache_stride_order
=
tuple
(
range
(
len
(
kv_cache_shape
)))
# The allocation respects the backend-defined stride order
# to ensure the semantics remain consistent for each
# backend. We first obtain the generic KV cache shape and
# then permute it according to the stride order, which could
# result in a non-contiguous tensor.
kv_cache_shape
=
tuple
(
kv_cache_shape
[
i
]
for
i
in
kv_cache_stride_order
)
# Maintain original KV shape view.
inv_order
=
[
kv_cache_stride_order
.
index
(
i
)
for
i
in
range
(
len
(
kv_cache_stride_order
))
]
kv_caches
[
layer_name
]
=
torch
.
zeros
(
kv_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
).
permute
(
*
inv_order
)
else
:
else
:
# TODO: add new branches when introducing more types of
# TODO: add new branches when introducing more types of
# KV cache specs.
# KV cache specs.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment