Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
85431bd9
Unverified
Commit
85431bd9
authored
Jul 15, 2025
by
Chengji Yao
Committed by
GitHub
Jul 16, 2025
Browse files
[TPU] fix kv_cache_update kernel block size choosing logic (#21007)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
c11013db
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
3 deletions
+51
-3
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+48
-1
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+3
-2
No files found.
vllm/v1/attention/backends/pallas.py
View file @
85431bd9
...
...
@@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
return
kv_cache
# We can move this function to a common utils file if it's also useful for other
# hardware.
def
dtype_bits
(
dtype
:
torch
.
dtype
):
if
dtype
.
is_floating_point
:
try
:
return
torch
.
finfo
(
dtype
).
bits
except
TypeError
:
pass
elif
dtype
.
is_complex
:
if
dtype
is
torch
.
complex32
:
return
32
elif
dtype
is
torch
.
complex64
:
return
64
elif
dtype
is
torch
.
complex128
:
return
128
else
:
try
:
return
torch
.
iinfo
(
dtype
).
bits
# torch.iinfo cannot support int4, int2, bits8...
except
TypeError
:
pass
str_dtype
=
str
(
dtype
)
# support torch.int4, torch.int5, torch.uint5...
if
str_dtype
.
startswith
(
"torch.int"
)
or
str_dtype
.
startswith
(
"torch.uint"
):
return
int
(
str_dtype
[
-
1
])
raise
TypeError
(
f
"Getting the bit width of
{
dtype
}
is not supported"
)
def
get_dtype_packing
(
dtype
):
bits
=
dtype_bits
(
dtype
)
if
32
%
bits
!=
0
:
raise
ValueError
(
f
"The bit width must be divisible by 32, but got bits=
{
bits
}
, "
"dtype={dtype}"
)
return
32
//
bits
def
get_page_size_bytes
(
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
kv_cache_dtype
:
torch
.
dtype
)
->
int
:
"""Returns the size in bytes of one page of the KV cache."""
return
block_size
*
num_kv_heads
*
head_size
*
kv_cache_dtype
.
itemsize
padded_head_size
=
cdiv
(
head_size
,
TPU_HEAD_SIZE_ALIGNMENT
)
*
TPU_HEAD_SIZE_ALIGNMENT
num_combined_kv_heads
=
num_kv_heads
*
2
# NOTE: for the implicit padding in XLA
packing
=
get_dtype_packing
(
kv_cache_dtype
)
num_combined_kv_heads
=
cdiv
(
num_combined_kv_heads
,
packing
)
*
packing
kv_cache_dtype_bits
=
dtype_bits
(
kv_cache_dtype
)
return
(
block_size
*
num_combined_kv_heads
*
padded_head_size
*
kv_cache_dtype_bits
//
8
)
vllm/v1/worker/tpu_model_runner.py
View file @
85431bd9
...
...
@@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
out of scalar registers. Thus this function will limit the number of
slices to 64.
"""
# Conservative VMEM usage limit: 32 MiB
vmem_limit
=
32
*
1024
*
1024
# The default vmem_limit_bytes of a pallas kernel is 32MB. Here we
# calculate num_slices_per_block based on 16MB in case any register spills.
vmem_limit
=
16
*
1024
*
1024
num_slices_per_block
=
vmem_limit
//
page_size_bytes
assert
num_slices_per_block
>
0
,
"Number of slices should be positive"
num_slices_per_block
=
prev_power_of_2
(
num_slices_per_block
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment