Unverified Commit 9434a0e5 authored by Johnsonms, committed by GitHub

[Refactor] Remove hardcoded KV cache dimension in MLATokenToKVPool (#12502)

parent 20315697
@@ -1303,9 +1303,11 @@ class MLATokenToKVPool(KVCache):
         self.qk_rope_head_dim = qk_rope_head_dim
         self.use_nsa = use_nsa
         self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
-        # TODO do not hardcode
+        assert not (
+            self.nsa_kv_cache_store_fp8 and override_kv_cache_dim is None
+        ), "override_kv_cache_dim must be provided when using NSA with FP8 kv cache storage"
         self.kv_cache_dim = (
-            656
+            override_kv_cache_dim
             if self.use_nsa and self.nsa_kv_cache_store_fp8
             else (kv_lora_rank + qk_rope_head_dim)
         )
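For context, here is a minimal sketch of the new selection logic in isolation. MiniPool is a hypothetical stand-in for illustration, not SGLang's actual class, and it requires a PyTorch build that defines torch.float8_e4m3fn:

import torch

class MiniPool:
    """Hypothetical stand-in for MLATokenToKVPool's dimension selection."""

    def __init__(self, kv_lora_rank, qk_rope_head_dim, dtype,
                 use_nsa=False, override_kv_cache_dim=None):
        self.use_nsa = use_nsa
        self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
        # The dimension must now be supplied by the caller instead of
        # falling back to a hardcoded constant.
        assert not (
            self.nsa_kv_cache_store_fp8 and override_kv_cache_dim is None
        ), "override_kv_cache_dim must be provided when using NSA with FP8 kv cache storage"
        self.kv_cache_dim = (
            override_kv_cache_dim
            if self.use_nsa and self.nsa_kv_cache_store_fp8
            else (kv_lora_rank + qk_rope_head_dim)
        )

# Non-NSA path: the dimension is derived from the MLA head shape.
assert MiniPool(512, 64, torch.bfloat16).kv_cache_dim == 512 + 64

# NSA + FP8 path without an override now fails fast instead of
# silently using a baked-in value.
try:
    MiniPool(512, 64, torch.float8_e4m3fn, use_nsa=True)
except AssertionError as e:
    print(e)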
@@ -1577,6 +1579,18 @@ class NSATokenToKVPool(MLATokenToKVPool):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
+        assert (
+            kv_lora_rank % self.quant_block_size == 0
+        ), f"kv_lora_rank {kv_lora_rank} must be multiple of quant_block_size {self.quant_block_size}"
+        # Calculate override_kv_cache_dim for FP8 storage:
+        # kv_lora_rank + scale storage (kv_lora_rank // quant_block_size * 4 bytes) + rope dimension storage
+        override_dim = (
+            kv_lora_rank
+            + kv_lora_rank // self.quant_block_size * 4
+            + qk_rope_head_dim * dtype.itemsize
+        )
         super().__init__(
             size,
             page_size,
@@ -1589,6 +1603,7 @@ class NSATokenToKVPool(MLATokenToKVPool):
             start_layer,
             end_layer,
             use_nsa=True,
+            override_kv_cache_dim=override_dim,
         )
         # self.index_k_dtype = torch.float8_e4m3fn
         # self.index_k_scale_dtype = torch.float32
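As a sanity check on the arithmetic, a hedged worked example: with DeepSeek-style MLA dimensions (kv_lora_rank = 512, qk_rope_head_dim = 64), quant_block_size = 128, and the rope part stored at 2 bytes per element, the new formula reproduces the previously hardcoded 656. All four values are illustrative assumptions, not read from this diff:

# Illustrative values only -- assumptions, not taken from this commit:
kv_lora_rank = 512      # MLA latent dimension, stored as FP8 (1 byte each)
qk_rope_head_dim = 64   # rope dimension
quant_block_size = 128  # per-block FP8 quantization granularity
rope_itemsize = 2       # bytes per rope element (e.g. bf16)

override_dim = (
    kv_lora_rank                            # 512 bytes of FP8 latent
    + kv_lora_rank // quant_block_size * 4  # 4 fp32 scales = 16 bytes
    + qk_rope_head_dim * rope_itemsize      # 64 * 2 = 128 bytes of rope
)
assert override_dim == 656  # matches the constant the old code hardcoded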