Unverified Commit 85431bd9 authored by Chengji Yao's avatar Chengji Yao Committed by GitHub
Browse files

[TPU] fix kv_cache_update kernel block size choosing logic (#21007)


Signed-off-by: Chengji Yao <chengjiyao@google.com>
parent c11013db
...@@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, ...@@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
return kv_cache return kv_cache
# We can move this function to a common utils file if it's also useful for other
# hardware.
def dtype_bits(dtype: torch.dtype) -> int:
    """Return the bit width of a single element of ``dtype``.

    Floating-point dtypes are resolved through ``torch.finfo`` and integer
    dtypes through ``torch.iinfo``. Complex dtypes are mapped explicitly.
    Dtypes that ``torch.iinfo`` rejects (e.g. ``torch.int4``) fall back to
    parsing the numeric suffix of the dtype's string name.

    Args:
        dtype: The dtype whose element width is requested.

    Returns:
        The number of bits occupied by one element.

    Raises:
        TypeError: If the bit width cannot be determined for ``dtype``.
    """
    if dtype.is_floating_point:
        try:
            return torch.finfo(dtype).bits
        except TypeError:
            # Fall through to the name-based lookup / final TypeError below.
            pass
    elif dtype.is_complex:
        if dtype is torch.complex32:
            return 32
        elif dtype is torch.complex64:
            return 64
        elif dtype is torch.complex128:
            return 128
    else:
        try:
            return torch.iinfo(dtype).bits
        # torch.iinfo cannot support int4, int2, bits8...
        except TypeError:
            pass

    str_dtype = str(dtype)
    # support torch.int4, torch.int5, torch.uint5...
    # Parse the whole numeric suffix rather than only the last character so
    # that a multi-digit width is not silently truncated, and so that a name
    # without a numeric suffix ends in the TypeError below instead of a
    # ValueError from int().
    for prefix in ("torch.uint", "torch.int"):
        if str_dtype.startswith(prefix):
            width = str_dtype[len(prefix):]
            if width.isdecimal():
                return int(width)
    raise TypeError(f"Getting the bit width of {dtype} is not supported")
def get_dtype_packing(dtype: torch.dtype) -> int:
    """Return how many elements of ``dtype`` fit into 32 bits.

    Args:
        dtype: The dtype whose packing factor is requested.

    Returns:
        ``32 // bits`` where ``bits`` is the value of ``dtype_bits(dtype)``.

    Raises:
        ValueError: If 32 is not an exact multiple of the dtype's bit width.
        TypeError: Propagated from ``dtype_bits`` when the bit width of
            ``dtype`` cannot be determined.
    """
    bits = dtype_bits(dtype)
    if 32 % bits != 0:
        # Both fragments must be f-strings; otherwise "{dtype}" is emitted
        # literally instead of the offending dtype.
        raise ValueError(
            f"32 must be divisible by the bit width, but got bits={bits}, "
            f"dtype={dtype}")
    return 32 // bits
def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        kv_cache_dtype: torch.dtype) -> int:
    """Returns the size in bytes of one page of the KV cache.

    The head size is rounded up to a multiple of ``TPU_HEAD_SIZE_ALIGNMENT``
    and the combined head count (``num_kv_heads * 2``) is rounded up to a
    multiple of the dtype packing factor before the byte size is computed.

    Args:
        block_size: First factor of the page size product.
        num_kv_heads: Number of KV heads; doubled to form the combined head
            count (presumably K and V stored together; confirm with caller).
        head_size: Per-head size before alignment padding.
        kv_cache_dtype: Element dtype of the KV cache.

    Returns:
        The page size in bytes.
    """
    padded_head_size = cdiv(head_size,
                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    num_combined_kv_heads = num_kv_heads * 2

    # NOTE: for the implicit padding in XLA
    packing = get_dtype_packing(kv_cache_dtype)
    num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing

    # Work in bits and convert to bytes once at the end.
    kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
    return (block_size * num_combined_kv_heads * padded_head_size *
            kv_cache_dtype_bits // 8)
...@@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int: ...@@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
out of scalar registers. Thus this function will limit the number of out of scalar registers. Thus this function will limit the number of
slices to 64. slices to 64.
""" """
# Conservative VMEM usage limit: 32 MiB # The default vmem_limit_bytes of a pallas kernel is 32MB. Here we
vmem_limit = 32 * 1024 * 1024 # calculate num_slices_per_block based on 16MB in case any register spills.
vmem_limit = 16 * 1024 * 1024
num_slices_per_block = vmem_limit // page_size_bytes num_slices_per_block = vmem_limit // page_size_bytes
assert num_slices_per_block > 0, "Number of slices should be positive" assert num_slices_per_block > 0, "Number of slices should be positive"
num_slices_per_block = prev_power_of_2(num_slices_per_block) num_slices_per_block = prev_power_of_2(num_slices_per_block)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment