Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
85431bd9
Unverified
Commit
85431bd9
authored
Jul 15, 2025
by
Chengji Yao
Committed by
GitHub
Jul 16, 2025
Browse files
[TPU] fix kv_cache_update kernel block size choosing logic (#21007)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
c11013db
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
3 deletions
+51
-3
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+48
-1
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+3
-2
No files found.
vllm/v1/attention/backends/pallas.py
View file @
85431bd9
...
@@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
...
@@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
return
kv_cache
return
kv_cache
# We can move this function to a common utils file if it's also useful for other
# hardware.
def
dtype_bits
(
dtype
:
torch
.
dtype
):
if
dtype
.
is_floating_point
:
try
:
return
torch
.
finfo
(
dtype
).
bits
except
TypeError
:
pass
elif
dtype
.
is_complex
:
if
dtype
is
torch
.
complex32
:
return
32
elif
dtype
is
torch
.
complex64
:
return
64
elif
dtype
is
torch
.
complex128
:
return
128
else
:
try
:
return
torch
.
iinfo
(
dtype
).
bits
# torch.iinfo cannot support int4, int2, bits8...
except
TypeError
:
pass
str_dtype
=
str
(
dtype
)
# support torch.int4, torch.int5, torch.uint5...
if
str_dtype
.
startswith
(
"torch.int"
)
or
str_dtype
.
startswith
(
"torch.uint"
):
return
int
(
str_dtype
[
-
1
])
raise
TypeError
(
f
"Getting the bit width of
{
dtype
}
is not supported"
)
def
get_dtype_packing
(
dtype
):
bits
=
dtype_bits
(
dtype
)
if
32
%
bits
!=
0
:
raise
ValueError
(
f
"The bit width must be divisible by 32, but got bits=
{
bits
}
, "
"dtype={dtype}"
)
return
32
//
bits
def
get_page_size_bytes
(
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
def
get_page_size_bytes
(
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
kv_cache_dtype
:
torch
.
dtype
)
->
int
:
kv_cache_dtype
:
torch
.
dtype
)
->
int
:
"""Returns the size in bytes of one page of the KV cache."""
"""Returns the size in bytes of one page of the KV cache."""
return
block_size
*
num_kv_heads
*
head_size
*
kv_cache_dtype
.
itemsize
padded_head_size
=
cdiv
(
head_size
,
TPU_HEAD_SIZE_ALIGNMENT
)
*
TPU_HEAD_SIZE_ALIGNMENT
num_combined_kv_heads
=
num_kv_heads
*
2
# NOTE: for the implicit padding in XLA
packing
=
get_dtype_packing
(
kv_cache_dtype
)
num_combined_kv_heads
=
cdiv
(
num_combined_kv_heads
,
packing
)
*
packing
kv_cache_dtype_bits
=
dtype_bits
(
kv_cache_dtype
)
return
(
block_size
*
num_combined_kv_heads
*
padded_head_size
*
kv_cache_dtype_bits
//
8
)
vllm/v1/worker/tpu_model_runner.py
View file @
85431bd9
...
@@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
...
@@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
out of scalar registers. Thus this function will limit the number of
out of scalar registers. Thus this function will limit the number of
slices to 64.
slices to 64.
"""
"""
# Conservative VMEM usage limit: 32 MiB
# The default vmem_limit_bytes of a pallas kernel is 32MB. Here we
vmem_limit
=
32
*
1024
*
1024
# calculate num_slices_per_block based on 16MB in case any register spills.
vmem_limit
=
16
*
1024
*
1024
num_slices_per_block
=
vmem_limit
//
page_size_bytes
num_slices_per_block
=
vmem_limit
//
page_size_bytes
assert
num_slices_per_block
>
0
,
"Number of slices should be positive"
assert
num_slices_per_block
>
0
,
"Number of slices should be positive"
num_slices_per_block
=
prev_power_of_2
(
num_slices_per_block
)
num_slices_per_block
=
prev_power_of_2
(
num_slices_per_block
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment