Unverified commit a4b424c6, authored by Trevor Morris, committed by GitHub
Browse files

[DeepSeek-V3.2] Include indexer kv cache when estimating kv cache size (#11309)

parent a0557642
......@@ -1177,6 +1177,8 @@ class MLATokenToKVPool(KVCache):
dtype=torch.uint64,
device=self.device,
)
if not use_nsa:
# NSA will allocate indexer KV cache later and then log the total size
self._finalize_allocation_log(size)
def get_kv_size_bytes(self):
......@@ -1298,6 +1300,9 @@ class MLATokenToKVPool(KVCache):
class NSATokenToKVPool(MLATokenToKVPool):
quant_block_size = 128
index_k_with_scale_buffer_dtype = torch.uint8
def __init__(
self,
size: int,
......@@ -1331,8 +1336,6 @@ class NSATokenToKVPool(MLATokenToKVPool):
# num head == 1 and head dim == 128 for index_k in NSA
assert index_head_dim == 128
self.quant_block_size = 128
assert self.page_size == 64
self.index_k_with_scale_buffer = [
torch.zeros(
......@@ -1347,11 +1350,12 @@ class NSATokenToKVPool(MLATokenToKVPool):
self.page_size
* (index_head_dim + index_head_dim // self.quant_block_size * 4),
),
dtype=torch.uint8,
dtype=self.index_k_with_scale_buffer_dtype,
device=device,
)
for _ in range(layer_num)
]
self._finalize_allocation_log(size)
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
if self.layer_transfer_counter is not None:
......@@ -1393,6 +1397,12 @@ class NSATokenToKVPool(MLATokenToKVPool):
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
)
def get_kv_size_bytes(self):
    """Return the total KV cache footprint in bytes.

    Extends the base MLA accounting with the NSA indexer buffers
    (`index_k_with_scale_buffer`), so capacity estimates include the
    per-layer indexer KV cache as well.
    """
    base_bytes = super().get_kv_size_bytes()
    indexer_bytes = sum(
        get_tensor_size_bytes(buf) for buf in self.index_k_with_scale_buffer
    )
    return base_bytes + indexer_bytes
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
def __init__(
......
......@@ -1280,6 +1280,17 @@ class ModelRunner:
* num_layers
* torch._utils._element_size(self.kv_cache_dtype)
)
# Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
if is_deepseek_nsa(self.model_config.hf_config):
index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
indexer_size_per_token = (
index_head_dim
+ index_head_dim // NSATokenToKVPool.quant_block_size * 4
)
element_size = torch._utils._element_size(
NSATokenToKVPool.index_k_with_scale_buffer_dtype
)
cell_size += indexer_size_per_token * num_layers * element_size
else:
cell_size = (
self.model_config.get_num_kv_heads(get_attention_tp_size())
......
......@@ -863,9 +863,6 @@ class ServerArgs:
self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek NSA.")
self.mem_fraction_static = 0.8
logger.warning("Setting mem fraction static to 0.8 for DeepSeek NSA.")
# For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment