Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2cc20699
Unverified
Commit
2cc20699
authored
Jun 25, 2025
by
Chengji Yao
Committed by
GitHub
Jun 25, 2025
Browse files
[TPU][Bugfix] fix kv cache padding (#20048)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
9f0608fc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
9 deletions
+14
-9
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+1
-7
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+13
-2
No files found.
vllm/v1/attention/backends/pallas.py
View file @
2cc20699
...
...
@@ -48,13 +48,7 @@ class PallasAttentionBackend(AttentionBackend):
)
->
tuple
[
int
,
...]:
padded_head_size
=
cdiv
(
head_size
,
TPU_HEAD_SIZE_ALIGNMENT
)
*
TPU_HEAD_SIZE_ALIGNMENT
num_blocks
=
num_blocks
*
head_size
//
padded_head_size
if
padded_head_size
!=
head_size
:
logger
.
warning_once
(
"head size is padded to %d, and num_blocks is adjusted to %d"
" accordingly"
,
padded_head_size
,
num_blocks
)
head_size
=
padded_head_size
return
(
num_blocks
,
block_size
,
num_kv_heads
*
2
,
head_size
)
return
(
num_blocks
,
block_size
,
num_kv_heads
*
2
,
padded_head_size
)
@
staticmethod
def
swap_blocks
(
...
...
vllm/v1/worker/tpu_worker.py
View file @
2cc20699
...
...
@@ -18,7 +18,8 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
from
vllm.v1.attention.backends.pallas
import
TPU_HEAD_SIZE_ALIGNMENT
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
...
...
@@ -221,7 +222,17 @@ class TPUWorker:
usable_memory_size
=
int
(
total_memory_size
*
self
.
cache_config
.
gpu_memory_utilization
)
tpu_kv_cache_bytes
=
max
(
usable_memory_size
-
profiled
,
0
)
head_size
=
self
.
model_config
.
get_head_size
()
if
head_size
>
0
:
padded_head_size
=
cdiv
(
head_size
,
TPU_HEAD_SIZE_ALIGNMENT
)
*
TPU_HEAD_SIZE_ALIGNMENT
if
padded_head_size
!=
head_size
:
logger
.
warning_once
(
"head size is padded to %d"
,
padded_head_size
)
# We adjust the usable memory size for the KV cache to prevent OOM
# errors, even after padding the head_size.
tpu_kv_cache_bytes
=
(
tpu_kv_cache_bytes
*
head_size
//
padded_head_size
)
return
int
(
tpu_kv_cache_bytes
)
def
execute_model
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment