Unverified commit f8a1a2d1, authored by Chen Zhang, committed by GitHub
Browse files

[v1] Hybrid Memory Allocator (#17996)


Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent 3465b87e
......@@ -1365,14 +1365,20 @@ class TPUModelRunner(LoRAModelRunnerMixin):
assert self.block_table_cpu.dtype == self.input_batch.block_table[
0].get_cpu_tensor().dtype
kv_caches: dict[str, torch.Tensor] = {}
kv_cache_sizes = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
assert len(kv_cache_tensor.shared_by) == 1, (
"KV cache tensor shared by multiple layers is not supported in "
"TPU.")
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec
for layer_name in kv_cache_group.layer_names:
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
tensor_size = kv_cache_sizes[layer_name]
assert tensor_size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_size // kv_cache_spec.page_size_bytes # noqa
if isinstance(kv_cache_spec, AttentionSpec):
if self.use_spmd:
num_kv_heads = kv_cache_spec.num_kv_heads
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment