Unverified Commit a4b424c6 authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

[DeepSeek-V3.2] Include indexer kv cache when estimating kv cache size (#11309)

parent a0557642
...@@ -1177,7 +1177,9 @@ class MLATokenToKVPool(KVCache): ...@@ -1177,7 +1177,9 @@ class MLATokenToKVPool(KVCache):
dtype=torch.uint64, dtype=torch.uint64,
device=self.device, device=self.device,
) )
self._finalize_allocation_log(size) if not use_nsa:
# NSA will allocate indexer KV cache later and then log the total size
self._finalize_allocation_log(size)
def get_kv_size_bytes(self): def get_kv_size_bytes(self):
assert hasattr(self, "kv_buffer") assert hasattr(self, "kv_buffer")
...@@ -1298,6 +1300,9 @@ class MLATokenToKVPool(KVCache): ...@@ -1298,6 +1300,9 @@ class MLATokenToKVPool(KVCache):
class NSATokenToKVPool(MLATokenToKVPool): class NSATokenToKVPool(MLATokenToKVPool):
quant_block_size = 128
index_k_with_scale_buffer_dtype = torch.uint8
def __init__( def __init__(
self, self,
size: int, size: int,
...@@ -1331,8 +1336,6 @@ class NSATokenToKVPool(MLATokenToKVPool): ...@@ -1331,8 +1336,6 @@ class NSATokenToKVPool(MLATokenToKVPool):
# num head == 1 and head dim == 128 for index_k in NSA # num head == 1 and head dim == 128 for index_k in NSA
assert index_head_dim == 128 assert index_head_dim == 128
self.quant_block_size = 128
assert self.page_size == 64 assert self.page_size == 64
self.index_k_with_scale_buffer = [ self.index_k_with_scale_buffer = [
torch.zeros( torch.zeros(
...@@ -1347,11 +1350,12 @@ class NSATokenToKVPool(MLATokenToKVPool): ...@@ -1347,11 +1350,12 @@ class NSATokenToKVPool(MLATokenToKVPool):
self.page_size self.page_size
* (index_head_dim + index_head_dim // self.quant_block_size * 4), * (index_head_dim + index_head_dim // self.quant_block_size * 4),
), ),
dtype=torch.uint8, dtype=self.index_k_with_scale_buffer_dtype,
device=device, device=device,
) )
for _ in range(layer_num) for _ in range(layer_num)
] ]
self._finalize_allocation_log(size)
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
if self.layer_transfer_counter is not None: if self.layer_transfer_counter is not None:
...@@ -1393,6 +1397,12 @@ class NSATokenToKVPool(MLATokenToKVPool): ...@@ -1393,6 +1397,12 @@ class NSATokenToKVPool(MLATokenToKVPool):
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
) )
def get_kv_size_bytes(self):
    """Return the total KV-cache footprint in bytes for this NSA pool.

    Extends the parent MLA pool's size with the per-layer
    ``index_k_with_scale_buffer`` tensors, so memory estimation also
    accounts for the indexer KV cache used by DeepSeek NSA.
    """
    # Start from the base MLA KV buffers, then add every indexer buffer.
    total = super().get_kv_size_bytes()
    total += sum(
        get_tensor_size_bytes(buf) for buf in self.index_k_with_scale_buffer
    )
    return total
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool): class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
def __init__( def __init__(
......
...@@ -1280,6 +1280,17 @@ class ModelRunner: ...@@ -1280,6 +1280,17 @@ class ModelRunner:
* num_layers * num_layers
* torch._utils._element_size(self.kv_cache_dtype) * torch._utils._element_size(self.kv_cache_dtype)
) )
# Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
if is_deepseek_nsa(self.model_config.hf_config):
index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
indexer_size_per_token = (
index_head_dim
+ index_head_dim // NSATokenToKVPool.quant_block_size * 4
)
element_size = torch._utils._element_size(
NSATokenToKVPool.index_k_with_scale_buffer_dtype
)
cell_size += indexer_size_per_token * num_layers * element_size
else: else:
cell_size = ( cell_size = (
self.model_config.get_num_kv_heads(get_attention_tp_size()) self.model_config.get_num_kv_heads(get_attention_tp_size())
......
...@@ -863,9 +863,6 @@ class ServerArgs: ...@@ -863,9 +863,6 @@ class ServerArgs:
self.page_size = 64 self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek NSA.") logger.warning("Setting page size to 64 for DeepSeek NSA.")
self.mem_fraction_static = 0.8
logger.warning("Setting mem fraction static to 0.8 for DeepSeek NSA.")
# For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently # For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
import torch import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment