Unverified Commit c586b556 authored by Yifei Teng's avatar Yifei Teng Committed by GitHub
Browse files

[TPU] Optimize kv cache update kernel (#20415)


Signed-off-by: default avatarYifei Teng <tengyifei88@gmail.com>
parent 33d56000
......@@ -947,6 +947,13 @@ def next_power_of_2(n) -> int:
return 1 << (n - 1).bit_length()
def prev_power_of_2(n: int) -> int:
"""The previous power of 2 (inclusive)"""
if n <= 0:
return 0
return 1 << (n.bit_length() - 1)
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
......
......@@ -324,3 +324,9 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor:
return kv_cache
def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype) -> int:
"""Returns the size in bytes of one page of the KV cache."""
return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize
......@@ -31,9 +31,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
is_pin_memory_available)
is_pin_memory_available, prev_power_of_2)
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata)
PallasMetadata,
get_page_size_bytes)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec,
......@@ -56,8 +57,6 @@ logger = init_logger(__name__)
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
# Block size used for kv cache updating kernel
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
#########################################################
......@@ -139,7 +138,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
model_dtype = self.dtype
if isinstance(model_dtype, str):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
else:
self.kv_cache_dtype = model_dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
......@@ -192,6 +195,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
self._num_slices_per_kv_cache_update_block = \
_get_num_slices_per_kv_cache_update_block(get_page_size_bytes(
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
kv_cache_dtype=self.kv_cache_dtype,
))
# Lazy initialization
self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = []
......@@ -719,7 +730,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
num_kv_update_slices = slot_mapping_metadata.shape[0]
padded_num_slices = _get_padded_num_kv_cache_update_slices(
padded_total_num_scheduled_tokens, self.max_num_reqs,
self.block_size)
self.block_size, self._num_slices_per_kv_cache_update_block)
slot_mapping_metadata = np.pad(
slot_mapping_metadata,
[[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
......@@ -750,8 +761,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
num_kv_update_slices=torch.tensor([num_kv_update_slices],
dtype=torch.int32,
device=self.device),
num_slices_per_kv_cache_update_block=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
num_slices_per_kv_cache_update_block=self.
_num_slices_per_kv_cache_update_block,
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
......@@ -1197,7 +1208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
position_ids = torch.zeros(num_tokens,
dtype=torch.int32).to(self.device)
padded_num_slices = _get_padded_num_kv_cache_update_slices(
num_tokens, self.max_num_reqs, self.block_size)
num_tokens, self.max_num_reqs, self.block_size,
self._num_slices_per_kv_cache_update_block)
num_kv_update_slices = torch.tensor([padded_num_slices],
dtype=torch.int32).to(self.device)
slot_mapping = torch.zeros((3, padded_num_slices),
......@@ -1220,8 +1232,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
query_start_loc=query_start_loc,
num_seqs=num_seqs,
num_kv_update_slices=num_kv_update_slices,
num_slices_per_kv_cache_update_block=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
num_slices_per_kv_cache_update_block=self.
_num_slices_per_kv_cache_update_block,
)
if self.is_multimodal_model:
......@@ -1826,19 +1838,41 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
return paddings[index]
def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
page_size: int) -> int:
def _get_padded_num_kv_cache_update_slices(
num_tokens: int, max_num_reqs: int, page_size: int,
num_slices_per_kv_cache_update_block: int) -> int:
"""Calculates the padded number of KV cache update slices to avoid
recompilation."""
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = min(padded_num_slices, num_tokens)
padded_num_slices = (
padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
padded_num_slices + num_slices_per_kv_cache_update_block - 1
) // num_slices_per_kv_cache_update_block * \
num_slices_per_kv_cache_update_block
return padded_num_slices
def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
"""Find the optimum number of slices to copy per Pallas program instance.
Increasing the number of slices copied in one instance of the kernel program
will increase HBM bandwidth utilization via more in-flight DMAs.
However, it will also use more VMEM, and experimentally, we observed
performance regression at 128 slices on v6e, likely due to running
out of scalar registers. Thus this function will limit the number of
slices to 64.
"""
# Conservative VMEM usage limit: 32 MiB
vmem_limit = 32 * 1024 * 1024
num_slices_per_block = vmem_limit // page_size_bytes
assert num_slices_per_block > 0, "Number of slices should be positive"
num_slices_per_block = prev_power_of_2(num_slices_per_block)
if num_slices_per_block > 64:
num_slices_per_block = 64
return num_slices_per_block
def replace_set_lora(model):
def _tpu_set_lora(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment