Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7d44c469
Unverified
Commit
7d44c469
authored
Jun 09, 2025
by
Siyuan Liu
Committed by
GitHub
Jun 09, 2025
Browse files
[TPU]Fix KV cache sharing tests (#19371)
parent
31f58be9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
60 deletions
+52
-60
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+52
-60
No files found.
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
7d44c469
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
unittest.mock
as
mock
import
pytest
...
...
@@ -17,24 +16,8 @@ from vllm.v1.worker.tpu_model_runner import (
TPUModelRunner
,
_get_padded_num_reqs_with_upper_limit
,
_get_padded_token_len
,
_get_req_paddings
,
_get_token_paddings
)
# Mock torch_xla module since it may not be available in the test environments
torch_xla_patcher
=
mock
.
patch
.
dict
(
"sys.modules"
,
{
"torch_xla"
:
mock
.
MagicMock
(),
"torch_xla.core.xla_model"
:
mock
.
MagicMock
(),
"torch_xla.runtime"
:
mock
.
MagicMock
(),
})
torch_xla_patcher
.
start
()
# Mock the PallasAttentionBackend
pallas_attention_backend_patcher
=
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.PallasAttentionBackend"
,
)
pallas_attention_backend_patcher
.
start
()
@
pytest
.
fixture
def
model_runner
():
# Patchers have already been started at module level.
def
get_vllm_config
():
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
10
,
max_num_batched_tokens
=
512
,
...
...
@@ -60,18 +43,19 @@ def model_runner():
cache_config
=
cache_config
,
scheduler_config
=
scheduler_config
,
)
return
vllm_config
def
get_model_runner
(
vllm_config
):
device
=
"xla:0"
# Mocking TPU device
with
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.torch"
),
\
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.xm"
),
\
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.xr"
):
return
TPUModelRunner
(
vllm_config
,
device
)
return
TPUModelRunner
(
vllm_config
,
device
)
@
pytest
.
fixture
(
autouse
=
True
,
scope
=
"session"
)
def
cleanup_patches
():
yield
torch_xla_patcher
.
stop
()
pallas_attention_backend_patcher
.
stop
(
)
@
pytest
.
fixture
def
model_runner
():
# Patchers have already been started at module level.
vllm_config
=
get_vllm_config
()
return
get_model_runner
(
vllm_config
)
def
_schedule_new_request
(
*
req_ids
:
str
)
->
SchedulerOutput
:
...
...
@@ -370,12 +354,14 @@ def test_get_req_paddings():
assert
_get_req_paddings
(
8
,
36
)
==
[
8
,
16
,
32
,
36
]
@
pytest
.
mark
.
skip
(
reason
=
"Test is broken on TPU when it's added."
)
def
test_init_kv_cache_with_kv_sharing_invalid_target_layer_ord
er
(
):
def
test_init_kv_cache_with_kv_sharing_invalid_target_layer_order
(
model_runn
er
):
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
error_msg
=
f
"
{
layer_1
}
must come before the current layer"
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
vllm_config
=
model_runner
.
vllm_config
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
),
\
set_current_vllm_config
(
vllm_config
):
fwd_context
=
{
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
...
...
@@ -399,13 +385,14 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
assert
fwd_context
is
not
None
@
pytest
.
mark
.
skip
(
reason
=
"Test is broken on TPU when it's added."
)
def
test_init_kv_cache_with_kv_sharing_target_layer_not_exist
():
def
test_init_kv_cache_with_kv_sharing_target_layer_not_exist
(
model_runner
):
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
invalid_layer
=
"model.layers.0.cross_attn.attn"
error_msg
=
f
"
{
invalid_layer
}
is not a valid Attention layer in the model"
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
vllm_config
=
model_runner
.
vllm_config
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
),
\
set_current_vllm_config
(
vllm_config
):
fwd_context
=
{
layer_0
:
Attention
(
...
...
@@ -428,12 +415,13 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
assert
fwd_context
is
not
None
@
pytest
.
mark
.
skip
(
reason
=
"Test is broken on TPU when it's added."
)
def
test_init_kv_cache_with_kv_sharing_target_same_as_current
():
def
test_init_kv_cache_with_kv_sharing_target_same_as_current
(
model_runner
):
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
error_msg
=
f
"
{
layer_1
}
cannot be the same as the current layer"
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
vllm_config
=
model_runner
.
vllm_config
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
),
\
set_current_vllm_config
(
vllm_config
):
fwd_context
=
{
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
...
...
@@ -457,11 +445,10 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
assert
fwd_context
is
not
None
@
pytest
.
mark
.
skip
(
reason
=
"Test is broken on TPU when it's added."
)
def
test_init_kv_cache_without_kv_sharing
(
model_runner
):
def
test_init_kv_cache_without_kv_sharing
():
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
vllm_config
=
model_runner
.
vllm_config
vllm_config
=
get_
vllm_config
()
with
set_current_vllm_config
(
vllm_config
):
fwd_context
=
{
layer_0
:
...
...
@@ -482,33 +469,38 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
# suppress var not used error
assert
fwd_context
is
not
None
# Set high context length to test max context length estimation
vllm_config
.
model_config
.
max_model_len
=
3
_000_000
vllm_config
.
model_config
.
max_model_len
=
1
_000_000
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
model_runner
=
get_model_runner
(
vllm_config
)
kv_cache_spec
=
model_runner
.
get_kv_cache_spec
()
assert
len
(
kv_cache_spec
)
==
2
assert
len
(
model_runner
.
shared_kv_cache_layers
)
==
0
available_memory
=
20
*
GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
num_expected_blocks
=
327680
# 20GB / 32KB / 2 (num layers)
# page size for each layer KV can be calculated as
# 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
# * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
num_expected_blocks
=
20480
# 20GB / 512KB / 2 (num layers)
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
available_memory
)
assert
kv_cache_config
.
num_blocks
==
num_expected_blocks
assert
len
(
kv_cache_config
.
tensors
)
==
2
assert
kv_cache_config
.
tensors
[
layer_
0
].
size
==
available_memory
//
2
assert
kv_cache_config
.
tensors
[
layer_
1
].
size
==
available_memory
//
2
assert
len
(
kv_cache_config
.
kv_cache_
tensors
)
==
2
assert
kv_cache_config
.
kv_cache_
tensors
[
0
].
size
==
available_memory
//
2
assert
kv_cache_config
.
kv_cache_
tensors
[
1
].
size
==
available_memory
//
2
max_context_len
=
\
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
5
*
GiB_bytes
)
# max context len with KV sharing should be 2x as large as without
assert
max_context_len
==
1310720
# max_context_len = available_memory / (page_size / block_size) / num_caches
# max_context_len = 5GB / (512KB / 128) / 2 = 655360
assert
max_context_len
==
655360
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 *
3
2kb)
# this will only allocate 2 block worth of memory (2 *
51
2kb)
kv_cache_config
.
num_blocks
=
1
for
laye
r
in
kv_cache_config
.
tensors
:
kv_cache_
config
.
tensors
[
layer
]
.
size
=
\
kv_cache_spec
[
layer
].
page_size_bytes
for
kv_cache_tenso
r
in
kv_cache_config
.
kv_cache_
tensors
:
kv_cache_
tensor
.
size
=
(
kv_cache_spec
[
kv_cache_tensor
.
shared_by
[
0
]
].
page_size_bytes
)
model_runner
.
initialize_kv_cache
(
kv_cache_config
)
...
...
@@ -524,11 +516,10 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
@
pytest
.
mark
.
skip
(
reason
=
"Test is broken on TPU when it's added."
)
def
test_init_kv_cache_with_kv_sharing_valid
(
model_runner
):
def
test_init_kv_cache_with_kv_sharing_valid
():
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
vllm_config
=
model_runner
.
vllm_config
vllm_config
=
get_
vllm_config
()
with
set_current_vllm_config
(
vllm_config
):
fwd_context
=
{
layer_0
:
...
...
@@ -552,33 +543,34 @@ def test_init_kv_cache_with_kv_sharing_valid(model_runner):
# Set high context length to test max context length estimation
vllm_config
.
model_config
.
max_model_len
=
3_000_000
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
model_runner
=
get_model_runner
(
vllm_config
)
kv_cache_spec
=
model_runner
.
get_kv_cache_spec
()
assert
len
(
kv_cache_spec
)
==
1
assert
layer_0
in
kv_cache_spec
assert
model_runner
.
shared_kv_cache_layers
[
layer_1
]
==
layer_0
available_memory
=
20
*
GiB_bytes
# page size for layer 0's kv_cache_spec is
3
2KB
# page size for layer 0's kv_cache_spec is
51
2KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks
=
65536
0
# 20GB /
3
2KB
num_expected_blocks
=
2
*
2048
0
# 20GB /
51
2KB
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
available_memory
)
assert
kv_cache_config
.
num_blocks
==
num_expected_blocks
assert
len
(
kv_cache_config
.
tensors
)
==
1
assert
len
(
kv_cache_config
.
kv_cache_
tensors
)
==
1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert
kv_cache_config
.
tensors
[
layer_
0
].
size
==
available_memory
assert
kv_cache_config
.
kv_cache_
tensors
[
0
].
size
==
available_memory
max_context_len
=
\
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
5
*
GiB_bytes
)
# max context len with KV sharing should be 2x as large as without
assert
max_context_len
==
2
*
1310720
assert
max_context_len
==
(
2
*
655360
)
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (
3
2kb)
# this will only allocate 1 block worth of memory (
51
2kb)
kv_cache_config
.
num_blocks
=
1
kv_cache_config
.
tensors
[
layer_
0
].
size
=
\
kv_cache_config
.
kv_cache_
tensors
[
0
].
size
=
\
kv_cache_spec
[
layer_0
].
page_size_bytes
model_runner
.
initialize_kv_cache
(
kv_cache_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment