Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c586b556
Unverified
Commit
c586b556
authored
Jul 15, 2025
by
Yifei Teng
Committed by
GitHub
Jul 15, 2025
Browse files
[TPU] Optimize kv cache update kernel (#20415)
Signed-off-by:
Yifei Teng
<
tengyifei88@gmail.com
>
parent
33d56000
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
16 deletions
+63
-16
vllm/utils/__init__.py
vllm/utils/__init__.py
+7
-0
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+6
-0
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+50
-16
No files found.
vllm/utils/__init__.py
View file @
c586b556
...
@@ -947,6 +947,13 @@ def next_power_of_2(n) -> int:
...
@@ -947,6 +947,13 @@ def next_power_of_2(n) -> int:
return
1
<<
(
n
-
1
).
bit_length
()
return
1
<<
(
n
-
1
).
bit_length
()
def
prev_power_of_2
(
n
:
int
)
->
int
:
"""The previous power of 2 (inclusive)"""
if
n
<=
0
:
return
0
return
1
<<
(
n
.
bit_length
()
-
1
)
def
round_up
(
x
:
int
,
y
:
int
)
->
int
:
def
round_up
(
x
:
int
,
y
:
int
)
->
int
:
return
((
x
+
y
-
1
)
//
y
)
*
y
return
((
x
+
y
-
1
)
//
y
)
*
y
...
...
vllm/v1/attention/backends/pallas.py
View file @
c586b556
...
@@ -324,3 +324,9 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
...
@@ -324,3 +324,9 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
page_size
:
int
,
page_size
:
int
,
num_slices_per_block
:
int
)
->
torch
.
Tensor
:
num_slices_per_block
:
int
)
->
torch
.
Tensor
:
return
kv_cache
return
kv_cache
def
get_page_size_bytes
(
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
kv_cache_dtype
:
torch
.
dtype
)
->
int
:
"""Returns the size in bytes of one page of the KV cache."""
return
block_size
*
num_kv_heads
*
head_size
*
kv_cache_dtype
.
itemsize
vllm/v1/worker/tpu_model_runner.py
View file @
c586b556
...
@@ -31,9 +31,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
...
@@ -31,9 +31,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
LayerBlockType
,
cdiv
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
is_pin_memory_available
,
prev_power_of_2
)
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
PallasMetadata
)
PallasMetadata
,
get_page_size_bytes
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
KVCacheConfig
,
KVCacheSpec
,
...
@@ -56,8 +57,6 @@ logger = init_logger(__name__)
...
@@ -56,8 +57,6 @@ logger = init_logger(__name__)
INVALID_TOKEN_ID
=
-
1
INVALID_TOKEN_ID
=
-
1
# Smallest output size
# Smallest output size
MIN_NUM_SEQS
=
8
MIN_NUM_SEQS
=
8
# Block size used for kv cache updating kernel
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
=
8
#########################################################
#########################################################
...
@@ -139,7 +138,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -139,7 +138,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
self
.
model_config
.
dtype
self
.
dtype
=
self
.
model_config
.
dtype
if
cache_config
.
cache_dtype
==
"auto"
:
if
cache_config
.
cache_dtype
==
"auto"
:
self
.
kv_cache_dtype
=
self
.
dtype
model_dtype
=
self
.
dtype
if
isinstance
(
model_dtype
,
str
):
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
model_dtype
]
else
:
self
.
kv_cache_dtype
=
model_dtype
else
:
else
:
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
cache_config
.
cache_dtype
]
...
@@ -192,6 +195,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -192,6 +195,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
max_num_encoder_input_tokens
=
encoder_compute_budget
self
.
max_num_encoder_input_tokens
=
encoder_compute_budget
self
.
encoder_cache_size
=
encoder_cache_size
self
.
encoder_cache_size
=
encoder_cache_size
self
.
_num_slices_per_kv_cache_update_block
=
\
_get_num_slices_per_kv_cache_update_block
(
get_page_size_bytes
(
block_size
=
self
.
block_size
,
num_kv_heads
=
self
.
num_kv_heads
,
head_size
=
self
.
head_size
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
))
# Lazy initialization
# Lazy initialization
self
.
model
:
nn
.
Module
# Set after load_model
self
.
model
:
nn
.
Module
# Set after load_model
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
...
@@ -719,7 +730,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -719,7 +730,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
num_kv_update_slices
=
slot_mapping_metadata
.
shape
[
0
]
num_kv_update_slices
=
slot_mapping_metadata
.
shape
[
0
]
padded_num_slices
=
_get_padded_num_kv_cache_update_slices
(
padded_num_slices
=
_get_padded_num_kv_cache_update_slices
(
padded_total_num_scheduled_tokens
,
self
.
max_num_reqs
,
padded_total_num_scheduled_tokens
,
self
.
max_num_reqs
,
self
.
block_size
)
self
.
block_size
,
self
.
_num_slices_per_kv_cache_update_block
)
slot_mapping_metadata
=
np
.
pad
(
slot_mapping_metadata
=
np
.
pad
(
slot_mapping_metadata
,
slot_mapping_metadata
,
[[
0
,
padded_num_slices
-
len
(
slot_mapping_metadata
)],
[
0
,
0
]],
[[
0
,
padded_num_slices
-
len
(
slot_mapping_metadata
)],
[
0
,
0
]],
...
@@ -750,8 +761,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -750,8 +761,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
num_kv_update_slices
=
torch
.
tensor
([
num_kv_update_slices
],
num_kv_update_slices
=
torch
.
tensor
([
num_kv_update_slices
],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
device
=
self
.
device
),
num_slices_per_kv_cache_update_block
=
num_slices_per_kv_cache_update_block
=
self
.
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
,
_num_slices_per_kv_cache_update_block
,
)
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
# request in the batch. While we should not sample any token from this
...
@@ -1197,7 +1208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1197,7 +1208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
position_ids
=
torch
.
zeros
(
num_tokens
,
position_ids
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int32
).
to
(
self
.
device
)
dtype
=
torch
.
int32
).
to
(
self
.
device
)
padded_num_slices
=
_get_padded_num_kv_cache_update_slices
(
padded_num_slices
=
_get_padded_num_kv_cache_update_slices
(
num_tokens
,
self
.
max_num_reqs
,
self
.
block_size
)
num_tokens
,
self
.
max_num_reqs
,
self
.
block_size
,
self
.
_num_slices_per_kv_cache_update_block
)
num_kv_update_slices
=
torch
.
tensor
([
padded_num_slices
],
num_kv_update_slices
=
torch
.
tensor
([
padded_num_slices
],
dtype
=
torch
.
int32
).
to
(
self
.
device
)
dtype
=
torch
.
int32
).
to
(
self
.
device
)
slot_mapping
=
torch
.
zeros
((
3
,
padded_num_slices
),
slot_mapping
=
torch
.
zeros
((
3
,
padded_num_slices
),
...
@@ -1220,8 +1232,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1220,8 +1232,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
num_seqs
=
num_seqs
,
num_seqs
=
num_seqs
,
num_kv_update_slices
=
num_kv_update_slices
,
num_kv_update_slices
=
num_kv_update_slices
,
num_slices_per_kv_cache_update_block
=
num_slices_per_kv_cache_update_block
=
self
.
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
,
_num_slices_per_kv_cache_update_block
,
)
)
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
...
@@ -1826,19 +1838,41 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
...
@@ -1826,19 +1838,41 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
return
paddings
[
index
]
return
paddings
[
index
]
def
_get_padded_num_kv_cache_update_slices
(
num_tokens
:
int
,
max_num_reqs
:
int
,
def
_get_padded_num_kv_cache_update_slices
(
page_size
:
int
)
->
int
:
num_tokens
:
int
,
max_num_reqs
:
int
,
page_size
:
int
,
num_slices_per_kv_cache_update_block
:
int
)
->
int
:
"""Calculates the padded number of KV cache update slices to avoid
"""Calculates the padded number of KV cache update slices to avoid
recompilation."""
recompilation."""
padded_num_slices
=
2
*
max_num_reqs
+
num_tokens
//
page_size
padded_num_slices
=
2
*
max_num_reqs
+
num_tokens
//
page_size
padded_num_slices
=
min
(
padded_num_slices
,
num_tokens
)
padded_num_slices
=
min
(
padded_num_slices
,
num_tokens
)
padded_num_slices
=
(
padded_num_slices
=
(
padded_num_slices
+
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
-
1
padded_num_slices
+
num_slices_per_kv_cache_update_block
-
1
)
//
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
*
\
)
//
num_slices_per_kv_cache_update_block
*
\
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
num_slices_per_kv_cache_update_block
return
padded_num_slices
return
padded_num_slices
def
_get_num_slices_per_kv_cache_update_block
(
page_size_bytes
:
int
)
->
int
:
"""Find the optimum number of slices to copy per Pallas program instance.
Increasing the number of slices copied in one instance of the kernel program
will increase HBM bandwidth utilization via more in-flight DMAs.
However, it will also use more VMEM, and experimentally, we observed
performance regression at 128 slices on v6e, likely due to running
out of scalar registers. Thus this function will limit the number of
slices to 64.
"""
# Conservative VMEM usage limit: 32 MiB
vmem_limit
=
32
*
1024
*
1024
num_slices_per_block
=
vmem_limit
//
page_size_bytes
assert
num_slices_per_block
>
0
,
"Number of slices should be positive"
num_slices_per_block
=
prev_power_of_2
(
num_slices_per_block
)
if
num_slices_per_block
>
64
:
num_slices_per_block
=
64
return
num_slices_per_block
def
replace_set_lora
(
model
):
def
replace_set_lora
(
model
):
def
_tpu_set_lora
(
def
_tpu_set_lora
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment