Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9434a0e5
Unverified
Commit
9434a0e5
authored
Nov 02, 2025
by
Johnsonms
Committed by
GitHub
Nov 02, 2025
Browse files
[Refact] Remove hardcoded KV cache dimension in MLATokenToKVPool (#12502)
parent
20315697
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
2 deletions
+17
-2
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+17
-2
No files found.
python/sglang/srt/mem_cache/memory_pool.py
View file @
9434a0e5
...
@@ -1303,9 +1303,11 @@ class MLATokenToKVPool(KVCache):
...
@@ -1303,9 +1303,11 @@ class MLATokenToKVPool(KVCache):
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
use_nsa
=
use_nsa
self
.
use_nsa
=
use_nsa
self
.
nsa_kv_cache_store_fp8
=
use_nsa
and
dtype
==
torch
.
float8_e4m3fn
self
.
nsa_kv_cache_store_fp8
=
use_nsa
and
dtype
==
torch
.
float8_e4m3fn
# TODO do not hardcode
assert
not
(
self
.
nsa_kv_cache_store_fp8
and
override_kv_cache_dim
is
None
),
"override_kv_cache_dim must be provided when using NSA with FP8 kv cache storage"
self
.
kv_cache_dim
=
(
self
.
kv_cache_dim
=
(
656
override_kv_cache_dim
if
self
.
use_nsa
and
self
.
nsa_kv_cache_store_fp8
if
self
.
use_nsa
and
self
.
nsa_kv_cache_store_fp8
else
(
kv_lora_rank
+
qk_rope_head_dim
)
else
(
kv_lora_rank
+
qk_rope_head_dim
)
)
)
...
@@ -1577,6 +1579,18 @@ class NSATokenToKVPool(MLATokenToKVPool):
...
@@ -1577,6 +1579,18 @@ class NSATokenToKVPool(MLATokenToKVPool):
start_layer
:
Optional
[
int
]
=
None
,
start_layer
:
Optional
[
int
]
=
None
,
end_layer
:
Optional
[
int
]
=
None
,
end_layer
:
Optional
[
int
]
=
None
,
):
):
assert
(
kv_lora_rank
%
self
.
quant_block_size
==
0
),
f
"kv_lora_rank
{
kv_lora_rank
}
must be multiple of quant_block_size
{
self
.
quant_block_size
}
"
# Calculate override_kv_cache_dim for FP8 storage:
# kv_lora_rank + scale storage (kv_lora_rank // quant_block_size * 4 bytes) + rope dimension storage
override_dim
=
(
kv_lora_rank
+
kv_lora_rank
//
self
.
quant_block_size
*
4
+
qk_rope_head_dim
*
dtype
.
itemsize
)
super
().
__init__
(
super
().
__init__
(
size
,
size
,
page_size
,
page_size
,
...
@@ -1589,6 +1603,7 @@ class NSATokenToKVPool(MLATokenToKVPool):
...
@@ -1589,6 +1603,7 @@ class NSATokenToKVPool(MLATokenToKVPool):
start_layer
,
start_layer
,
end_layer
,
end_layer
,
use_nsa
=
True
,
use_nsa
=
True
,
override_kv_cache_dim
=
override_dim
,
)
)
# self.index_k_dtype = torch.float8_e4m3fn
# self.index_k_dtype = torch.float8_e4m3fn
# self.index_k_scale_dtype = torch.float32
# self.index_k_scale_dtype = torch.float32
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment