Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
04e1642e
Unverified
Commit
04e1642e
authored
Jun 26, 2025
by
Chengji Yao
Committed by
GitHub
Jun 26, 2025
Browse files
[TPU] add kv cache update kernel (#19928)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
b69781f1
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
342 additions
and
38 deletions
+342
-38
.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+2
-0
tests/v1/tpu/test_kv_cache_update_kernel.py
tests/v1/tpu/test_kv_cache_update_kernel.py
+71
-0
tests/v1/tpu/test_pallas.py
tests/v1/tpu/test_pallas.py
+2
-1
vllm/attention/ops/pallas_kv_cache_update.py
vllm/attention/ops/pallas_kv_cache_update.py
+117
-0
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+50
-5
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+100
-32
No files found.
.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
View file @
04e1642e
...
...
@@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
run_and_track_test 16 "test_kv_cache_update_kernel.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then
...
...
tests/v1/tpu/test_kv_cache_update_kernel.py
0 → 100644
View file @
04e1642e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
pytest
import
torch
import
torch_xla
import
vllm.v1.attention.backends.pallas
# noqa: F401
from
vllm.platforms
import
current_platform
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a test for TPU only")
@pytest.mark.parametrize("page_size", [32, 33])
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
                                head_dim: int, num_slices_per_block: int):
    """Compare `torch.ops.xla.kv_cache_update_op` against a CPU reference.

    A zero-initialised flat KV cache and a random batch of new KV rows are
    built on CPU and mirrored onto the XLA device. The op is run on the XLA
    copies, the same slice copies are replayed on the CPU tensors with plain
    slicing, and the two caches are compared with `torch.allclose`.
    """
    page_num = 1000
    padded_num_tokens = 128
    # Flat cache: pages are laid out back to back along dim 0.
    kv_cache_cpu = torch.zeros(
        (page_num * page_size, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
    new_kv_cpu = torch.randn(
        (padded_num_tokens, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    new_kv_xla = new_kv_cpu.to(torch_xla.device())
    # Slice lengths cover partial pages, whole pages and single tokens.
    slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
                          dtype=np.int32)
    # Destination row of each slice in the flat cache. The first slice ends
    # exactly where the second (a full page starting at page 2) begins.
    kv_cache_start_indices = np.array([
        page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
        page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
    ],
                                      dtype=np.int32)
    # Source row of each slice in new_kv: exclusive prefix sum of the lengths,
    # i.e. the slices are packed contiguously in the new KV tensor.
    new_kv_cache_indices = np.concatenate(
        [np.array([0], dtype=np.int32),
         np.cumsum(slice_lens[:-1])])
    # One row per slice: (kv_cache_start, new_kv_start, slice_len).
    slot_mapping = np.stack(
        [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
    # Round the slice count up to a multiple of num_slices_per_block and fill
    # the extra rows with zeros (zero-length slices).
    padded_size = (slot_mapping.shape[0] + num_slices_per_block -
                   1) // num_slices_per_block * num_slices_per_block
    slot_mapping = np.pad(slot_mapping,
                          [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
                          constant_values=0)
    # The op is given the mapping as [3, num_slices].
    slot_mapping = np.transpose(slot_mapping)
    slot_mapping_cpu = torch.tensor(slot_mapping,
                                    device="cpu",
                                    dtype=torch.int32)
    slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
    torch_xla.sync()

    # Mark the cache as a donor buffer, run the op, and copy the result back
    # into the original cache tensor before syncing.
    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
        new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
        num_slices_per_block)
    kv_cache_xla.copy_(new_kv_cache_xla)
    torch_xla.sync()

    # CPU reference: replay every (unpadded) slice copy with plain slicing.
    for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
                          slice_lens):
        kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]

    assert torch.allclose(kv_cache_xla.cpu(),
                          kv_cache_cpu,
                          atol=1e-4,
                          rtol=1e-4)
tests/v1/tpu/test_pallas.py
View file @
04e1642e
...
...
@@ -47,7 +47,7 @@ def test_ragged_paged_attention():
key
=
torch
.
zeros
(
num_tokens
,
num_kv_heads
*
head_size
)
value
=
torch
.
zeros
(
num_tokens
,
num_kv_heads
*
head_size
)
kv_cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
num_kv_heads
*
2
,
head_size
)
slot_mapping
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int64
)
slot_mapping
=
torch
.
zeros
(
(
3
,
num_tokens
)
,
dtype
=
torch
.
int64
)
max_num_reqs
=
8
max_num_blocks_per_req
=
8
block_tables
=
torch
.
zeros
((
max_num_reqs
,
max_num_blocks_per_req
),
...
...
@@ -65,6 +65,7 @@ def test_ragged_paged_attention():
context_lens
=
context_lens
,
query_start_loc
=
query_start_loc
,
num_seqs
=
num_seqs
,
num_slices_per_kv_cache_update_block
=
8
,
)
with
patch
(
"torch.ops.xla.ragged_paged_attention"
...
...
vllm/attention/ops/pallas_kv_cache_update.py
0 → 100644
View file @
04e1642e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
jax
from
jax.experimental
import
pallas
as
pl
from
jax.experimental.pallas
import
tpu
as
pltpu
def _kv_cache_update_kernel(
    # Prefetch
    slices_ref,  # [3, num_slices]; row 0 = kv_cache_start, row 1 =
    # new_kv_start, row 2 = slice_len
    # Input
    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
    # head_dim]
    # Output
    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    # Scratch
    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
    # head_dim]
    sem,
):
    """Copy one block of KV slices from `new_kv` into the KV cache.

    The grid step `pl.program_id(0)` selects a group of
    `num_slices_per_block` consecutive columns of `slices_ref`. Each slice is
    staged through its own row of `scratch` in two phases:

    1. start an async copy `new_kv[new_kv_start : +len] -> scratch[i, :len]`
       for every slice in the group, then wait for all of them;
    2. start an async copy `scratch[i, :len] -> kv_cache[kv_cache_start : +len]`
       for every slice, then wait for all of them.

    All copies share the single semaphore `sem`. The output ref is unused;
    writes go through `kv_cache_hbm_ref`, which the `kv_cache_update` wrapper
    aliases to the output via `input_output_aliases`.
    """
    async_copies = []
    block_idx = pl.program_id(0)
    # The group size is taken from the leading scratch dimension.
    num_slices_per_block = scratch.shape[0]
    # Copy from new_kv_hbm_ref to scratch
    for i in range(num_slices_per_block):
        # Column of slices_ref handled by scratch row i in this grid step.
        offset_i = i + block_idx * num_slices_per_block
        new_kv_start = slices_ref[1, offset_i]
        length = slices_ref[2, offset_i]
        async_copy = pltpu.make_async_copy(
            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
            scratch.at[i, pl.ds(0, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)

    # Every staging copy must finish before scratch is read back below.
    for async_copy in async_copies:
        async_copy.wait()

    # Copy from scratch to kv_cache_hbm_ref
    async_copies.clear()
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        kv_cache_start = slices_ref[0, offset_i]
        length = slices_ref[2, offset_i]
        async_copy = pltpu.make_async_copy(
            scratch.at[i, pl.ds(0, length), ...],
            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    # Drain the write-back copies before the grid step returns.
    for async_copy in async_copies:
        async_copy.wait()
@functools.partial(
    jax.jit,
    static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
    slices: jax.Array,  # [3, slices]: (kv_cache_start, new_kv_start, slice_len)
    kv_cache: jax.Array,  # [total_num_pages * page_size,
    # num_combined_kv_heads, head_dim]
    *,
    page_size: int = 32,
    num_slices_per_block: int = 8,
):
    """Scatter slices of `new_kv` into a flat KV cache with a Pallas kernel.

    Args:
        new_kv: New key/value rows to write.
        slices: Per-slice metadata, one column per slice; the column count
            must be a multiple of `num_slices_per_block`.
        kv_cache: Flat KV cache. It is aliased to the kernel output.
        page_size: Second dimension of the per-slice scratch buffer, i.e. the
            largest slice length the scratch can hold.
        num_slices_per_block: Number of slices handled per grid step.

    Returns:
        The updated KV cache, with the same shape and dtype as `kv_cache`.

    Raises:
        AssertionError: If the slice count is not a multiple of
            `num_slices_per_block`, the trailing dimensions of `new_kv` and
            `kv_cache` differ, or `head_dim` is not a multiple of 128.
    """
    num_slices = slices.shape[1]
    assert num_slices % num_slices_per_block == 0
    _, num_combined_kv_heads, head_dim = new_kv.shape
    assert kv_cache.shape[1] == num_combined_kv_heads
    assert kv_cache.shape[2] == head_dim
    assert head_dim % 128 == 0
    # TODO: Add a dynamic check to make sure that all the slice lengths are
    # smaller than or equal to page_size

    def _any_memory_spec():
        # Fresh spec per operand, mirroring one BlockSpec per input/output.
        return pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)

    # `slices` is the only scalar-prefetch operand.
    num_scalar_prefetch = 1

    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=num_scalar_prefetch,
        # Operands after the prefetch: new_kv, kv_cache.
        in_specs=[_any_memory_spec(), _any_memory_spec()],
        out_specs=[_any_memory_spec()],
        # One grid step per group of num_slices_per_block slices.
        grid=(num_slices // num_slices_per_block, ),
        scratch_shapes=[
            # One page-sized staging row per slice in a group.
            pltpu.VMEM(
                (num_slices_per_block, page_size, num_combined_kv_heads,
                 head_dim),
                new_kv.dtype,
            ),
            pltpu.SemaphoreType.DMA,
        ],
    )

    update_call = pl.pallas_call(
        _kv_cache_update_kernel,
        grid_spec=grid_spec,
        out_shape=[
            jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)
        ],
        # kv_cache (operand index: prefetch count + 1) aliases output 0.
        input_output_aliases={num_scalar_prefetch + 1: 0},
    )
    return update_call(slices, new_kv, kv_cache)[0]
vllm/v1/attention/backends/pallas.py
View file @
04e1642e
...
...
@@ -5,8 +5,12 @@ from dataclasses import dataclass
from
typing
import
Any
,
Optional
import
torch
# Required to register custom ops.
import
torch_xla.core.xla_builder
as
xb
import
torch_xla.experimental.custom_kernel
# noqa: F401
# Required to register custom ops.
from
torch.library
import
impl
from
torch_xla._internal.jax_workarounds
import
requires_jax
from
torch_xla.experimental.custom_kernel
import
XLA_LIB
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionType
)
...
...
@@ -107,6 +111,7 @@ class PallasMetadata:
context_lens
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
num_seqs
:
torch
.
Tensor
num_slices_per_kv_cache_update_block
:
int
class
PallasAttentionBackendImpl
(
AttentionImpl
):
...
...
@@ -212,7 +217,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
# Write input keys and values to the KV cache.
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping
=
attn_metadata
.
slot_mapping
write_to_kv_cache
(
key
,
value
,
kv_cache
,
slot_mapping
)
write_to_kv_cache
(
key
,
value
,
kv_cache
,
slot_mapping
,
attn_metadata
.
num_slices_per_kv_cache_update_block
)
output
=
torch
.
ops
.
xla
.
ragged_paged_attention
(
query
,
...
...
@@ -244,6 +251,7 @@ def write_to_kv_cache(
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
num_slices_per_kv_cache_update_block
:
int
,
)
->
None
:
""" Write the key and values to the KV cache.
...
...
@@ -251,9 +259,9 @@ def write_to_kv_cache(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
num_slices_per_kv_cache_update_block: int
"""
_
,
_
,
num_combined_kv_heads
,
head_size
=
kv_cache
.
shape
_
,
page_size
,
num_combined_kv_heads
,
head_size
=
kv_cache
.
shape
head_size
=
cdiv
(
head_size
,
TPU_HEAD_SIZE_ALIGNMENT
)
*
TPU_HEAD_SIZE_ALIGNMENT
kv
=
torch
.
cat
([
key
,
value
],
axis
=-
1
).
reshape
(
-
1
,
num_combined_kv_heads
,
...
...
@@ -262,4 +270,41 @@ def write_to_kv_cache(
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
kv_cache
,
True
)
kv_cache
=
kv_cache
.
flatten
(
0
,
1
)
kv_cache
.
index_copy_
(
0
,
slot_mapping
,
kv
)
new_kv_cache
=
torch
.
ops
.
xla
.
kv_cache_update_op
(
kv
,
slot_mapping
,
kv_cache
,
page_size
,
num_slices_per_kv_cache_update_block
)
# NOTE: the in-place copy will be optimized away by XLA compiler.
kv_cache
.
copy_
(
new_kv_cache
)
@requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
                            kv_cache: torch.Tensor, page_size: int,
                            num_slices_per_block: int):
    """Run the JAX `kv_cache_update` kernel on torch tensors via `call_jax`.

    Args:
        kv: New key/value rows to write into the cache.
        slot_mapping: Slice metadata forwarded as the kernel's `slices`.
        kv_cache: Flat KV cache to update.
        page_size: Forwarded to the kernel as a keyword argument.
        num_slices_per_block: Forwarded to the kernel as a keyword argument.

    Returns:
        Whatever `xb.call_jax` returns for the kernel (the updated KV cache).
    """
    # Imported lazily inside the function rather than at module import time.
    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update

    kernel_kwargs = {
        "page_size": page_size,
        "num_slices_per_block": num_slices_per_block
    }
    return xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache),
                       kernel_kwargs)
# Declare the schema of the `kv_cache_update_op` custom op on XLA_LIB; the
# per-dispatch-key implementations are registered below with `@impl`.
XLA_LIB.define(
    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
    "int page_size, int num_slices_per_block) -> Tensor", )
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                           kv_cache: torch.Tensor, page_size: int,
                           num_slices_per_block: int) -> torch.Tensor:
    """"XLA" dispatch-key implementation of `kv_cache_update_op`.

    Forwards every argument unchanged to `kv_cache_update_op_impl` and returns
    its result.
    """
    return kv_cache_update_op_impl(kv, slot_mapping, kv_cache, page_size,
                                   num_slices_per_block)
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                               kv_cache: torch.Tensor, page_size: int,
                               num_slices_per_block: int) -> torch.Tensor:
    """"CompositeExplicitAutograd" implementation of `kv_cache_update_op`.

    Performs no update: `kv_cache` is returned as-is and all other arguments
    are ignored.
    NOTE(review): presumably a placeholder so the op resolves off the XLA
    dispatch key (e.g. during tracing) - confirm the intended use.
    """
    return kv_cache
vllm/v1/worker/tpu_model_runner.py
View file @
04e1642e
...
...
@@ -53,12 +53,11 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID
=
1_000_000_000
INVALID_TOKEN_ID
=
-
1
# Smallest output size
MIN_NUM_SEQS
=
8
# Block size used for kv cache updating kernel
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
=
8
#########################################################
...
...
@@ -526,6 +525,69 @@ class TPUModelRunner(LoRAModelRunnerMixin):
return
kv_cache_spec
def _get_slot_mapping_metadata(self, num_reqs, num_scheduled_tokens_per_req):
    """
    Computes metadata for mapping slots to blocks in the key-value (KV)
    cache for a batch of requests.

    For each request, the scheduled token range
    [num_computed_tokens, num_computed_tokens + num_scheduled_tokens) is
    split at block boundaries into slices, so that every slice lies inside
    a single KV cache block. Slices are emitted request by request, in
    token order.

    Args:
        num_reqs (int): Number of requests in the current batch.
        num_scheduled_tokens_per_req (np.ndarray): Number of tokens
            scheduled for each of the first `num_reqs` requests (anything
            that broadcasts against an array of length `num_reqs`). Each
            request is expected to have at least one scheduled token.

    Returns:
        np.ndarray: A 2D array of shape (total_block_len, 3), where each row
            contains:
            - kv_cache_start_index (int): The starting index in the flat
              KV cache (block_number * block_size + offset in block).
            - new_kv_start_index (int): The starting index of the slice in
              the packed new-KV tensor (exclusive prefix sum of the slice
              lengths).
            - slice_len (int): The length of the slice.
        The caller pads and transposes this array before handing it to the
        KV cache update kernel.
    """
    block_size = self.block_size
    # Token range [start, end) scheduled for each request.
    slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
    slices_end = slices_start + num_scheduled_tokens_per_req
    # First and last block (within the request) touched by that range.
    local_block_start_idx = slices_start // block_size
    local_block_end_idx = (slices_end - 1) // block_size
    no_repeat_req_indices = self.arange_np[:num_reqs]
    # Index of each request's first touched block in the flattened
    # [num_reqs, max_num_blocks_per_req] block table.
    global_block_start_idx = (
        no_repeat_req_indices * self.max_num_blocks_per_req +
        local_block_start_idx)
    # Number of blocks (= number of slices) per request.
    block_lens = local_block_end_idx - local_block_start_idx + 1
    global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
    # 0..n-1 for every request, concatenated: offset of each slice's block
    # from the request's first touched block.
    slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
    global_block_indices = global_block_start_idx + slice_arange
    block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
    block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
    total_block_len = np.sum(block_lens)
    # Start by assuming every slice spans a whole block: [0, block_size).
    slot_mapping_slices = np.repeat(np.array([[0, block_size]],
                                             dtype=np.int32),
                                    total_block_len,
                                    axis=0)
    cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
    np.cumsum(block_lens, out=cu_block_lens[1:])
    # Trim the first slice of every request to start at the request's
    # in-block offset, and the last slice to end after its final token.
    # Vectorized over requests; row indices are unique because every
    # request contributes at least one slice.
    first_slice_rows = cu_block_lens[:-1]
    last_slice_rows = cu_block_lens[1:] - 1
    slot_mapping_slices[first_slice_rows, 0] = slices_start % block_size
    slot_mapping_slices[last_slice_rows,
                        1] = (slices_end - 1) % block_size + 1
    slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
    cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
    np.cumsum(slice_lens, out=cu_slices_lens[1:])
    kv_cache_start_indices = slot_mapping_slices[:, 0] + \
        (block_numbers * block_size)
    new_kv_start_indices = cu_slices_lens[:-1]
    slot_mapping_metadata = np.stack(
        [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1)
    return slot_mapping_metadata
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
start_index
:
int
):
assert
scheduler_output
.
total_num_scheduled_tokens
>
0
...
...
@@ -603,26 +665,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
torch
.
from_numpy
(
token_indices
),
out
=
self
.
input_ids_cpu
[:
total_num_scheduled_tokens
])
# Calculate the slot mapping.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size` here
# because M (max_model_len) is not necessarily divisible by block_size.
# req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
block_table_indices
=
(
req_indices
*
self
.
max_num_blocks_per_req
+
positions_np
//
self
.
block_size
)
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_table_cpu
=
self
.
input_batch
.
block_table
[
0
].
get_cpu_tensor
()
block_numbers
=
block_table_cpu
.
flatten
()[
block_table_indices
].
numpy
()
block_offsets
=
positions_np
%
self
.
block_size
np
.
add
(
block_numbers
*
self
.
block_size
,
block_offsets
,
out
=
self
.
input_batch
.
block_table
[
0
].
slot_mapping_np
[:
total_num_scheduled_tokens
])
# Prepare the attention metadata.
self
.
query_start_loc_np
[
0
]
=
0
np
.
cumsum
(
num_scheduled_tokens_per_req
,
...
...
@@ -645,12 +687,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
position_ids
=
self
.
positions_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
self
.
device
)
self
.
input_batch
.
block_table
[
0
].
slot_mapping_cpu
[
total_num_scheduled_tokens
:]
=
_PAD_SLOT_ID
slot_mapping
=
(
self
.
input_batch
.
block_table
[
0
].
slot_mapping_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
self
.
device
))
if
use_max_model_len
:
block_tables
=
self
.
block_table_cpu
[:
self
.
num_reqs_max_model_len
,
:
self
.
max_num_blocks_per_req
]
...
...
@@ -675,6 +711,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
device
)
block_tables
=
block_tables
.
to
(
self
.
device
)
slot_mapping_metadata
=
self
.
_get_slot_mapping_metadata
(
num_reqs
,
num_scheduled_tokens_per_req
)
padded_num_slices
=
_get_padded_num_kv_cache_update_slices
(
padded_total_num_scheduled_tokens
,
self
.
max_num_reqs
,
self
.
block_size
)
slot_mapping_metadata
=
np
.
pad
(
slot_mapping_metadata
,
[[
0
,
padded_num_slices
-
len
(
slot_mapping_metadata
)],
[
0
,
0
]],
constant_values
=
0
)
slot_mapping_metadata
=
np
.
transpose
(
slot_mapping_metadata
)
slot_mapping_metadata
=
torch
.
tensor
(
slot_mapping_metadata
,
device
=
self
.
device
)
if
self
.
lora_config
is
not
None
:
# We need to respect padding when activating LoRA adapters
padded_num_scheduled_tokens_per_req
=
np
.
copy
(
...
...
@@ -687,13 +736,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
padded_num_scheduled_tokens_per_req
)
attn_metadata
=
PallasMetadata
(
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
_metadata
,
block_tables
=
block_tables
,
context_lens
=
seq_lens
,
query_start_loc
=
query_start_loc
,
num_seqs
=
torch
.
tensor
([
num_reqs
],
dtype
=
torch
.
int32
,
device
=
self
.
device
),
num_slices_per_kv_cache_update_block
=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
,
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
...
...
@@ -1119,8 +1170,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
actual_num_reqs
=
min
(
num_tokens
,
num_reqs
)
position_ids
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int32
).
to
(
self
.
device
)
slot_mapping
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int64
).
to
(
self
.
device
)
padded_num_slices
=
_get_padded_num_kv_cache_update_slices
(
num_tokens
,
self
.
max_num_reqs
,
self
.
block_size
)
slot_mapping
=
torch
.
zeros
((
3
,
padded_num_slices
),
dtype
=
torch
.
int32
).
to
(
self
.
device
)
block_tables
=
torch
.
zeros
((
num_reqs
,
num_blocks
),
dtype
=
torch
.
int32
).
to
(
self
.
device
)
query_lens
=
[
1
]
*
num_reqs
...
...
@@ -1138,6 +1191,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
context_lens
=
context_lens
,
query_start_loc
=
query_start_loc
,
num_seqs
=
num_seqs
,
num_slices_per_kv_cache_update_block
=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
,
)
if
self
.
is_multimodal_model
:
...
...
@@ -1742,6 +1797,19 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
return
paddings
[
index
]
def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                           page_size: int) -> int:
    """Return the padded KV cache update slice count for `num_tokens`.

    The count is `min(2 * max_num_reqs + num_tokens // page_size, num_tokens)`
    rounded up to a multiple of NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK. Padding
    to a value that depends only on these arguments avoids recompilation.
    """
    slices_per_block = NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
    slice_bound = min(2 * max_num_reqs + num_tokens // page_size, num_tokens)
    # Ceil-divide, then scale back up to a whole number of kernel blocks.
    num_blocks = -(-slice_bound // slices_per_block)
    return num_blocks * slices_per_block
def
replace_set_lora
(
model
):
def
_tpu_set_lora
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment