Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7da296be
Unverified
Commit
7da296be
authored
Jul 01, 2025
by
Chengji Yao
Committed by
GitHub
Jul 02, 2025
Browse files
[TPU] kv cache update kernel supports dynamic grid (#20235)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
b205e846
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
42 additions
and
17 deletions
+42
-17
tests/v1/tpu/test_kv_cache_update_kernel.py
tests/v1/tpu/test_kv_cache_update_kernel.py
+6
-2
vllm/attention/ops/pallas_kv_cache_update.py
vllm/attention/ops/pallas_kv_cache_update.py
+6
-3
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+22
-12
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+8
-0
No files found.
tests/v1/tpu/test_kv_cache_update_kernel.py
View file @
7da296be
...
...
@@ -32,6 +32,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
new_kv_xla
=
new_kv_cpu
.
to
(
torch_xla
.
device
())
slice_lens
=
np
.
array
([
7
,
page_size
,
page_size
,
1
,
1
,
1
,
9
],
dtype
=
np
.
int32
)
num_kv_update_slices
=
len
(
slice_lens
)
kv_cache_start_indices
=
np
.
array
([
page_size
*
2
-
7
,
page_size
*
2
,
page_size
*
3
,
page_size
*
4
+
6
,
page_size
*
5
+
7
,
page_size
*
6
+
8
,
page_size
*
15
+
3
...
...
@@ -52,12 +53,15 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
device
=
"cpu"
,
dtype
=
torch
.
int32
)
slot_mapping_xla
=
slot_mapping_cpu
.
to
(
torch_xla
.
device
())
num_kv_update_slices_xla
=
torch
.
tensor
([
num_kv_update_slices
],
device
=
torch_xla
.
device
(),
dtype
=
torch
.
int32
)
torch_xla
.
sync
()
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
kv_cache_xla
,
True
)
new_kv_cache_xla
=
torch
.
ops
.
xla
.
kv_cache_update_op
(
new_kv_xla
,
slot_mapping_xla
,
kv_cache_xla
,
page_size
,
num_slices_per_block
)
new_kv_xla
,
slot_mapping_xla
,
kv_cache_xla
,
num_kv_update_slices_xla
,
page_size
,
num_slices_per_block
)
kv_cache_xla
.
copy_
(
new_kv_cache_xla
)
torch_xla
.
sync
()
...
...
vllm/attention/ops/pallas_kv_cache_update.py
View file @
7da296be
...
...
@@ -7,11 +7,13 @@ import jax
from
jax.experimental
import
pallas
as
pl
from
jax.experimental.pallas
import
tpu
as
pltpu
from
vllm.utils
import
cdiv
def
_kv_cache_update_kernel
(
# Prefetch
slices_ref
,
# [3, num_slices], list of (kv_cache_start,
new_kv_start,
# slice_len)
slices_ref
,
# [3,
padded_
num_slices], list of (kv_cache_start,
#
new_kv_start,
slice_len)
# Input
new_kv_hbm_ref
,
# [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref
,
# [total_num_pages * page_size, num_combined_kv_heads,
...
...
@@ -70,6 +72,7 @@ def kv_cache_update(
Array
,
# [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
kv_cache
:
jax
.
Array
,
# [total_num_pages * page_size, num_combined_kv_heads, head_dim]
num_kv_update_slices
:
jax
.
Array
,
# [1]
*
,
page_size
:
int
=
32
,
num_slices_per_block
:
int
=
8
,
...
...
@@ -107,7 +110,7 @@ def kv_cache_update(
num_scalar_prefetch
=
len
(
scalar_prefetches
),
in_specs
=
in_specs
,
out_specs
=
out_specs
,
grid
=
(
slices
.
shape
[
1
]
//
num_slices_per_block
,
),
grid
=
(
cdiv
(
num_kv_update_slices
[
0
],
num_slices_per_block
)
,
),
scratch_shapes
=
scratch_shapes
,
),
out_shape
=
out_shape
,
...
...
vllm/v1/attention/backends/pallas.py
View file @
7da296be
...
...
@@ -111,6 +111,7 @@ class PallasMetadata:
context_lens
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
num_seqs
:
torch
.
Tensor
num_kv_update_slices
:
torch
.
Tensor
num_slices_per_kv_cache_update_block
:
int
...
...
@@ -219,7 +220,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
slot_mapping
=
attn_metadata
.
slot_mapping
write_to_kv_cache
(
key
,
value
,
kv_cache
,
slot_mapping
,
attn_metadata
.
num_slices_per_kv_cache_update_block
)
attn_metadata
.
num_slices_per_kv_cache_update_block
,
attn_metadata
.
num_kv_update_slices
)
output
=
torch
.
ops
.
xla
.
ragged_paged_attention
(
query
,
...
...
@@ -252,6 +254,7 @@ def write_to_kv_cache(
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
num_slices_per_kv_cache_update_block
:
int
,
num_kv_update_slices
:
torch
.
Tensor
,
)
->
None
:
""" Write the key and values to the KV cache.
...
...
@@ -271,7 +274,7 @@ def write_to_kv_cache(
kv_cache
=
kv_cache
.
flatten
(
0
,
1
)
new_kv_cache
=
torch
.
ops
.
xla
.
kv_cache_update_op
(
kv
,
slot_mapping
,
kv_cache
,
page_size
,
kv
,
slot_mapping
,
kv_cache
,
num_kv_update_slices
,
page_size
,
num_slices_per_kv_cache_update_block
)
# NOTE: the in-place copy will be optimized away by XLA compiler.
kv_cache
.
copy_
(
new_kv_cache
)
...
...
@@ -279,10 +282,12 @@ def write_to_kv_cache(
@
requires_jax
def
kv_cache_update_op_impl
(
kv
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
page_size
:
int
,
kv_cache
:
torch
.
Tensor
,
num_kv_update_slices
:
torch
.
Tensor
,
page_size
:
int
,
num_slices_per_block
:
int
):
from
vllm.attention.ops.pallas_kv_cache_update
import
kv_cache_update
new_kv_cache
=
xb
.
call_jax
(
kv_cache_update
,
(
kv
,
slot_mapping
,
kv_cache
),
{
new_kv_cache
=
xb
.
call_jax
(
kv_cache_update
,
(
kv
,
slot_mapping
,
kv_cache
,
num_kv_update_slices
),
{
"page_size"
:
page_size
,
"num_slices_per_block"
:
num_slices_per_block
})
...
...
@@ -290,21 +295,26 @@ def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
XLA_LIB
.
define
(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
"int page_size, int num_slices_per_block) -> Tensor"
,
)
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache,"
\
"Tensor num_kv_update_slices, int page_size, int num_slices_per_block)"
\
"-> Tensor"
,
)
@
impl
(
XLA_LIB
,
"kv_cache_update_op"
,
"XLA"
)
def
kv_cache_update_op_xla
(
kv
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
page_size
:
int
,
kv_cache
:
torch
.
Tensor
,
num_kv_update_slices
:
torch
.
Tensor
,
page_size
:
int
,
num_slices_per_block
:
int
)
->
torch
.
Tensor
:
new_kv_cache
=
kv_cache_update_op_impl
(
kv
,
slot_mapping
,
kv_cache
,
page_size
,
num_slices_per_block
)
num_kv_update_slices
,
page_size
,
num_slices_per_block
)
return
new_kv_cache
@
impl
(
XLA_LIB
,
"kv_cache_update_op"
,
"CompositeExplicitAutograd"
)
def
kv_cache_update_op_non_xla
(
kv
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
page_size
:
int
,
kv_cache
:
torch
.
Tensor
,
num_kv_update_slices
:
torch
.
Tensor
,
page_size
:
int
,
num_slices_per_block
:
int
)
->
torch
.
Tensor
:
return
kv_cache
vllm/v1/worker/tpu_model_runner.py
View file @
7da296be
...
...
@@ -713,8 +713,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
device
)
block_tables
=
block_tables
.
to
(
self
.
device
)
# Calculate the slot mapping
slot_mapping_metadata
=
self
.
_get_slot_mapping_metadata
(
num_reqs
,
num_scheduled_tokens_per_req
)
num_kv_update_slices
=
slot_mapping_metadata
.
shape
[
0
]
padded_num_slices
=
_get_padded_num_kv_cache_update_slices
(
padded_total_num_scheduled_tokens
,
self
.
max_num_reqs
,
self
.
block_size
)
...
...
@@ -745,6 +747,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
num_seqs
=
torch
.
tensor
([
num_reqs
],
dtype
=
torch
.
int32
,
device
=
self
.
device
),
num_kv_update_slices
=
torch
.
tensor
([
num_kv_update_slices
],
dtype
=
torch
.
int32
,
device
=
self
.
device
),
num_slices_per_kv_cache_update_block
=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
,
)
...
...
@@ -1174,6 +1179,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
dtype
=
torch
.
int32
).
to
(
self
.
device
)
padded_num_slices
=
_get_padded_num_kv_cache_update_slices
(
num_tokens
,
self
.
max_num_reqs
,
self
.
block_size
)
num_kv_update_slices
=
torch
.
tensor
([
padded_num_slices
],
dtype
=
torch
.
int32
).
to
(
self
.
device
)
slot_mapping
=
torch
.
zeros
((
3
,
padded_num_slices
),
dtype
=
torch
.
int32
).
to
(
self
.
device
)
block_tables
=
torch
.
zeros
((
num_reqs
,
num_blocks
),
...
...
@@ -1193,6 +1200,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
context_lens
=
context_lens
,
query_start_loc
=
query_start_loc
,
num_seqs
=
num_seqs
,
num_kv_update_slices
=
num_kv_update_slices
,
num_slices_per_kv_cache_update_block
=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment