Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6078d5fc
Unverified
Commit
6078d5fc
authored
Aug 22, 2025
by
huangtingwei
Committed by
GitHub
Aug 22, 2025
Browse files
[HiCacheStorage] backup optimization for MLA model (#8865)
Co-authored-by:
Zhiqiang Xie
<
xiezhq@stanford.edu
>
parent
70cf4abc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
39 additions
and
20 deletions
+39
-20
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+26
-12
python/sglang/srt/mem_cache/hicache_storage.py
python/sglang/srt/mem_cache/hicache_storage.py
+2
-2
python/sglang/srt/mem_cache/memory_pool_host.py
python/sglang/srt/mem_cache/memory_pool_host.py
+3
-2
python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
...ng/srt/mem_cache/storage/mooncake_store/mooncake_store.py
+8
-4
No files found.
python/sglang/srt/managers/cache_controller.py
View file @
6078d5fc
...
...
@@ -26,6 +26,8 @@ if TYPE_CHECKING:
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool_host
import
HostKVCache
from
sglang.srt.distributed
import
get_tensor_model_parallel_rank
from
sglang.srt.mem_cache.memory_pool_host
import
MLATokenToKVPoolHost
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -238,13 +240,14 @@ class HiCacheController:
self
.
io_backend
=
io_backend
self
.
enable_storage
=
False
self
.
is_mla
=
isinstance
(
self
.
mem_pool_host
,
MLATokenToKVPoolHost
)
# todo: move backend initialization to storage backend module
if
storage_backend
is
not
None
:
self
.
storage_backend_type
=
storage_backend
from
sglang.srt.mem_cache.hicache_storage
import
HiCacheFile
,
get_hash_str
if
storage_backend
==
"file"
:
self
.
storage_backend
=
HiCacheFile
()
self
.
storage_backend
=
HiCacheFile
(
is_mla
=
self
.
is_mla
)
self
.
get_hash_str
=
get_hash_str
elif
storage_backend
==
"nixl"
:
from
sglang.srt.mem_cache.storage.nixl.hicache_nixl
import
HiCacheNixl
...
...
@@ -257,12 +260,11 @@ class HiCacheController:
get_hash_str_mooncake
,
)
self
.
storage_backend
=
MooncakeStore
()
self
.
storage_backend
=
MooncakeStore
(
is_mla
=
self
.
is_mla
)
self
.
get_hash_str
=
get_hash_str_mooncake
self
.
storage_backend
.
register_buffer
(
self
.
mem_pool_host
.
kv_buffer
)
assert
self
.
mem_pool_host
.
layout
==
"page_first"
elif
storage_backend
==
"hf3fs"
:
from
sglang.srt.distributed
import
get_tensor_model_parallel_rank
from
sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs
import
(
HiCacheHF3FS
,
)
...
...
@@ -399,6 +401,15 @@ class HiCacheController:
self
.
prefetch_thread
.
start
()
self
.
backup_thread
.
start
()
@
property
def
backup_skip
(
self
):
return
(
self
.
is_mla
and
get_tensor_model_parallel_rank
()
!=
0
# todo: only support file and mooncake
and
self
.
storage_backend_type
in
[
"file"
,
"mooncake"
]
)
def
write
(
self
,
device_indices
:
torch
.
Tensor
,
...
...
@@ -809,17 +820,20 @@ class HiCacheController:
if
operation
is
None
:
continue
if
self
.
is_mooncake_backend
():
self
.
mooncake_page_backup
(
operation
)
elif
self
.
storage_backend_type
==
"hf3fs"
:
if
self
.
mem_pool_host
.
layout
==
"page_first"
:
self
.
zerocopy_page_backup
(
operation
,
batch_size
=
128
)
elif
self
.
mem_pool_host
.
layout
==
"layer_first"
:
self
.
generic_page_backup
(
operation
,
batch_size
=
128
)
if
not
self
.
backup_skip
:
if
self
.
is_mooncake_backend
():
self
.
mooncake_page_backup
(
operation
)
elif
self
.
storage_backend_type
==
"hf3fs"
:
if
self
.
mem_pool_host
.
layout
==
"page_first"
:
self
.
zerocopy_page_backup
(
operation
,
batch_size
=
128
)
elif
self
.
mem_pool_host
.
layout
==
"layer_first"
:
self
.
generic_page_backup
(
operation
,
batch_size
=
128
)
else
:
self
.
generic_page_backup
(
operation
)
min_completed_tokens
=
operation
.
completed_tokens
else
:
self
.
generic_page_backup
(
operation
)
min_completed_tokens
=
len
(
operation
.
token_ids
)
min_completed_tokens
=
operation
.
completed_tokens
if
self
.
tp_world_size
>
1
:
completed_tokens_tensor
=
torch
.
tensor
(
min_completed_tokens
,
dtype
=
torch
.
int
...
...
python/sglang/srt/mem_cache/hicache_storage.py
View file @
6078d5fc
...
...
@@ -101,11 +101,11 @@ class HiCacheStorage(ABC):
class
HiCacheFile
(
HiCacheStorage
):
def
__init__
(
self
,
file_path
:
str
=
"/tmp/hicache"
):
def
__init__
(
self
,
file_path
:
str
=
"/tmp/hicache"
,
is_mla
:
bool
=
False
):
self
.
file_path
=
os
.
getenv
(
"SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR"
,
file_path
)
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_suffix
=
f
"_
{
tp_rank
}
_
{
tp_size
}
"
if
tp_size
>
1
else
""
self
.
tp_suffix
=
f
"_
{
tp_rank
}
_
{
tp_size
}
"
if
tp_size
>
1
and
not
is_mla
else
""
if
not
os
.
path
.
exists
(
self
.
file_path
)
and
tp_rank
==
0
:
os
.
makedirs
(
self
.
file_path
)
logger
.
info
(
f
"Created HiCacheFile storage directory at
{
self
.
file_path
}
"
)
...
...
python/sglang/srt/mem_cache/memory_pool_host.py
View file @
6078d5fc
...
...
@@ -7,6 +7,7 @@ from functools import wraps
import
psutil
import
torch
from
sglang.srt.distributed
import
get_tensor_model_parallel_rank
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
MHATokenToKVPool
,
MLATokenToKVPool
from
sglang.srt.utils
import
is_npu
...
...
@@ -487,8 +488,8 @@ class MHATokenToKVPoolHost(HostKVCache):
ptr_list
.
append
(
k_ptr
)
ptr_list
.
append
(
v_ptr
)
key_
=
keys
[
index
//
self
.
page_size
]
key_list
.
append
(
f
"
{
key_
}
_k"
)
key_list
.
append
(
f
"
{
key_
}
_v"
)
key_list
.
append
(
f
"
{
key_
}
_
{
get_tensor_model_parallel_rank
()
}
_
k"
)
key_list
.
append
(
f
"
{
key_
}
_
{
get_tensor_model_parallel_rank
()
}
_
v"
)
element_size
=
(
self
.
layer_num
*
self
.
dtype
.
itemsize
...
...
python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
View file @
6078d5fc
...
...
@@ -19,14 +19,13 @@ logger = logging.getLogger(__name__)
def
get_hash_str_mooncake
(
token_ids
:
List
[
int
],
prior_hash
:
str
=
None
):
local_rank
=
get_tensor_model_parallel_rank
()
prefix_str
=
""
if
prior_hash
:
prefix_str
=
hashlib
.
sha256
(
prior_hash
.
encode
()).
hexdigest
()
current_token_ids_bytes
=
np
.
array
(
token_ids
).
tobytes
()
current_hash_object
=
hashlib
.
sha256
(
current_token_ids_bytes
)
current_hash_hex
=
current_hash_object
.
hexdigest
()
return
f
"
{
prefix_str
}
_
{
int
(
current_hash_hex
[:
16
],
16
)
}
_
{
local_rank
}
"
return
f
"
{
prefix_str
}
_
{
int
(
current_hash_hex
[:
16
],
16
)
}
"
@
dataclass
...
...
@@ -97,7 +96,7 @@ class MooncakeStoreConfig:
class
MooncakeStore
(
HiCacheStorage
):
def
__init__
(
self
):
def
__init__
(
self
,
is_mla
:
bool
=
False
):
try
:
from
mooncake.store
import
MooncakeDistributedStore
except
ImportError
as
e
:
...
...
@@ -127,6 +126,7 @@ class MooncakeStore(HiCacheStorage):
logger
.
info
(
"Connect to Mooncake store successfully."
)
self
.
warmup
()
logger
.
info
(
"Mooncake store warmup successfully."
)
self
.
is_mla
=
is_mla
except
ValueError
as
e
:
logger
.
error
(
"Configuration loading failed: %s"
,
e
)
...
...
@@ -223,11 +223,15 @@ class MooncakeStore(HiCacheStorage):
def
exists
(
self
,
keys
)
->
bool
|
dict
:
_keys
=
[]
local_rank
=
get_tensor_model_parallel_rank
()
for
key
in
keys
:
if
key
is
None
:
return
None
_keys
.
append
(
f
"
{
key
}
_k"
)
if
self
.
is_mla
:
_keys
.
append
(
f
"
{
key
}
_k"
)
else
:
_keys
.
append
(
f
"
{
key
}
_
{
local_rank
}
_k"
)
result
=
{
k
:
v
for
k
,
v
in
zip
(
keys
,
self
.
store
.
batch_is_exist
(
_keys
))}
return
result
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment