Unverified commit 9434a0e5, authored by Johnsonms, committed by GitHub

[Refact] Remove hardcoded KV cache dimension in MLATokenToKVPool (#12502)

parent 20315697
```diff
@@ -1303,9 +1303,11 @@ class MLATokenToKVPool(KVCache):
         self.qk_rope_head_dim = qk_rope_head_dim
         self.use_nsa = use_nsa
         self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
-        # TODO do not hardcode
+        assert not (
+            self.nsa_kv_cache_store_fp8 and override_kv_cache_dim is None
+        ), "override_kv_cache_dim must be provided when using NSA with FP8 kv cache storage"
         self.kv_cache_dim = (
-            656
+            override_kv_cache_dim
             if self.use_nsa and self.nsa_kv_cache_store_fp8
             else (kv_lora_rank + qk_rope_head_dim)
         )
```
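Note: the FP8 NSA path now takes its packed width from the caller instead of the old hardcoded 656. A minimal standalone sketch of the resolved behavior, assuming DeepSeek-style dimensions; the helper name and example values below are illustrative, not part of the patch:

```python
from typing import Optional

import torch


def resolve_kv_cache_dim(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    dtype: torch.dtype,
    use_nsa: bool,
    override_kv_cache_dim: Optional[int] = None,
) -> int:
    # FP8 NSA storage packs values, per-block scales, and rope bytes into one
    # flat buffer, so the packed width must be supplied by the caller.
    nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
    assert not (nsa_kv_cache_store_fp8 and override_kv_cache_dim is None), (
        "override_kv_cache_dim must be provided when using NSA with FP8 kv cache storage"
    )
    if nsa_kv_cache_store_fp8:
        return override_kv_cache_dim
    return kv_lora_rank + qk_rope_head_dim


# Plain MLA path: the dimension is derived directly from the model config.
assert resolve_kv_cache_dim(512, 64, torch.bfloat16, use_nsa=False) == 576
# NSA + FP8 path: the caller supplies the packed byte width.
assert resolve_kv_cache_dim(512, 64, torch.float8_e4m3fn, True, 656) == 656
```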
```diff
@@ -1577,6 +1579,18 @@ class NSATokenToKVPool(MLATokenToKVPool):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
+        assert (
+            kv_lora_rank % self.quant_block_size == 0
+        ), f"kv_lora_rank {kv_lora_rank} must be a multiple of quant_block_size {self.quant_block_size}"
+        # Compute override_kv_cache_dim for FP8 storage:
+        # kv_lora_rank FP8 values + scale storage (kv_lora_rank // quant_block_size * 4 bytes)
+        # + rope storage (qk_rope_head_dim * dtype.itemsize bytes)
+        override_dim = (
+            kv_lora_rank
+            + kv_lora_rank // self.quant_block_size * 4
+            + qk_rope_head_dim * dtype.itemsize
+        )
         super().__init__(
             size,
             page_size,
```
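Sanity check: with DeepSeek-style MLA dimensions (assumed here: kv_lora_rank = 512, qk_rope_head_dim = 64, quant_block_size = 128, and bf16 rope entries with itemsize 2), the formula reproduces the 656 that the first hunk removes:

```python
kv_lora_rank = 512        # assumed MLA latent dimension
qk_rope_head_dim = 64     # assumed rope head dimension
quant_block_size = 128    # assumed FP8 per-block quantization granularity
rope_itemsize = 2         # bf16 rope entries, 2 bytes each

override_dim = (
    kv_lora_rank                            # 512 FP8 values, 1 byte each
    + kv_lora_rank // quant_block_size * 4  # 4 float32 scales -> 16 bytes
    + qk_rope_head_dim * rope_itemsize      # 64 rope entries -> 128 bytes
)
assert override_dim == 656  # the constant removed from MLATokenToKVPool
```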
```diff
@@ -1589,6 +1603,7 @@ class NSATokenToKVPool(MLATokenToKVPool):
             start_layer,
             end_layer,
             use_nsa=True,
+            override_kv_cache_dim=override_dim,
         )
         # self.index_k_dtype = torch.float8_e4m3fn
         # self.index_k_scale_dtype = torch.float32
```
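For reference, a hypothetical sketch of the packed per-token layout this width implies on the NSA + FP8 path. The value/scale/rope ordering and slicing below are illustrative under the same assumed dimensions; the patch itself only sizes the buffer:

```python
import torch

# Assumed DeepSeek-style dimensions; quant_block_size = 128 is an assumption.
kv_lora_rank, qk_rope_head_dim, quant_block_size = 512, 64, 128
rope_dtype = torch.bfloat16

scale_bytes = kv_lora_rank // quant_block_size * 4      # 4 float32 scales
rope_bytes = qk_rope_head_dim * rope_dtype.itemsize     # 64 bf16 entries
kv_cache_dim = kv_lora_rank + scale_bytes + rope_bytes  # 512 + 16 + 128 = 656

# One token's entry in the flat uint8 KV buffer.
entry = torch.empty(kv_cache_dim, dtype=torch.uint8)
fp8_vals = entry[:kv_lora_rank]  # bytes [0, 512): FP8 latent values
scales = entry[kv_lora_rank : kv_lora_rank + scale_bytes].view(torch.float32)
rope = entry[kv_lora_rank + scale_bytes :].view(rope_dtype)
assert (fp8_vals.numel(), scales.numel(), rope.numel()) == (512, 4, 64)
```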