Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d0730487
Unverified
Commit
d0730487
authored
Sep 05, 2025
by
pansicheng
Committed by
GitHub
Sep 04, 2025
Browse files
fix 3fs zerocopy (#9938)
Co-authored-by:
Zhiqiang Xie
<
xiezhq@stanford.edu
>
parent
b32ab070
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
50 additions
and
54 deletions
+50
-54
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+28
-27
python/sglang/srt/mem_cache/memory_pool_host.py
python/sglang/srt/mem_cache/memory_pool_host.py
+16
-11
python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
+6
-16
No files found.
python/sglang/srt/managers/cache_controller.py
View file @
d0730487
...
@@ -324,6 +324,22 @@ class HiCacheController:
...
@@ -324,6 +324,22 @@ class HiCacheController:
group_ranks
,
backend
=
"gloo"
group_ranks
,
backend
=
"gloo"
)
)
# Select the get and set functions
self
.
page_get_func
=
self
.
_generic_page_get
self
.
page_set_func
=
self
.
_generic_page_set
self
.
batch_exists_func
=
self
.
storage_backend
.
batch_exists
self
.
is_3fs_zerocopy
=
(
self
.
storage_backend_type
==
"hf3fs"
and
self
.
mem_pool_host
.
layout
==
"page_first"
)
if
self
.
storage_backend_type
==
"mooncake"
:
self
.
page_get_func
=
self
.
_mooncake_page_get
self
.
page_set_func
=
self
.
_mooncake_page_set
elif
self
.
is_3fs_zerocopy
:
self
.
page_get_func
=
self
.
_3fs_zero_copy_page_get
self
.
page_set_func
=
self
.
_3fs_zero_copy_page_set
self
.
batch_exists_func
=
self
.
_3fs_zero_copy_batch_exists
self
.
load_cache_event
=
load_cache_event
self
.
load_cache_event
=
load_cache_event
self
.
layer_done_counter
=
LayerDoneCounter
(
self
.
mem_pool_device
.
layer_num
)
self
.
layer_done_counter
=
LayerDoneCounter
(
self
.
mem_pool_device
.
layer_num
)
self
.
mem_pool_device
.
register_layer_transfer_counter
(
self
.
layer_done_counter
)
self
.
mem_pool_device
.
register_layer_transfer_counter
(
self
.
layer_done_counter
)
...
@@ -617,13 +633,19 @@ class HiCacheController:
...
@@ -617,13 +633,19 @@ class HiCacheController:
for
chunk
in
chunks
:
for
chunk
in
chunks
:
self
.
host_mem_release_queue
.
put
(
chunk
)
self
.
host_mem_release_queue
.
put
(
chunk
)
def
_3fs_zero_copy_batch_exists
(
self
,
batch_hashes
):
_batch_hashes
,
_
,
factor
=
self
.
mem_pool_host
.
get_buffer_with_hash
(
batch_hashes
)
hit_page_num
=
self
.
storage_backend
.
batch_exists
(
_batch_hashes
)
//
factor
return
hit_page_num
def
_3fs_zero_copy_page_get
(
self
,
operation
,
hash_values
,
host_indices
):
def
_3fs_zero_copy_page_get
(
self
,
operation
,
hash_values
,
host_indices
):
hashes
,
dsts
=
self
.
mem_pool_host
.
get_buffer_with_hash
(
hashes
,
dsts
,
factor
=
self
.
mem_pool_host
.
get_buffer_with_hash
(
hash_values
,
host_indices
hash_values
,
host_indices
)
)
page_data
=
self
.
storage_backend
.
batch_get
(
hashes
,
dsts
)
page_data
=
self
.
storage_backend
.
batch_get
(
hashes
,
dsts
)
if
page_data
:
if
page_data
:
operation
.
increment
(
self
.
page_size
*
len
(
hashes
))
inc
=
self
.
page_size
*
len
(
hashes
)
//
factor
operation
.
increment
(
inc
)
else
:
else
:
logger
.
warning
(
logger
.
warning
(
f
"Prefetch operation
{
operation
.
request_id
}
failed to retrieve page
{
hashes
}
."
f
"Prefetch operation
{
operation
.
request_id
}
failed to retrieve page
{
hashes
}
."
...
@@ -670,17 +692,6 @@ class HiCacheController:
...
@@ -670,17 +692,6 @@ class HiCacheController:
break
# Operation terminated by controller
break
# Operation terminated by controller
def
_page_transfer
(
self
,
operation
):
def
_page_transfer
(
self
,
operation
):
# Select the get function and batch size
if
self
.
storage_backend_type
==
"mooncake"
:
get_func
=
self
.
_mooncake_page_get
elif
(
self
.
storage_backend_type
==
"hf3fs"
and
self
.
mem_pool_host
.
layout
==
"page_first"
):
get_func
=
self
.
_3fs_zero_copy_page_get
else
:
get_func
=
self
.
_generic_page_get
# Transfer batch by batch
# Transfer batch by batch
for
i
in
range
(
0
,
len
(
operation
.
hash_value
),
self
.
storage_batch_size
):
for
i
in
range
(
0
,
len
(
operation
.
hash_value
),
self
.
storage_batch_size
):
batch_hashes
=
operation
.
hash_value
[
i
:
i
+
self
.
storage_batch_size
]
batch_hashes
=
operation
.
hash_value
[
i
:
i
+
self
.
storage_batch_size
]
...
@@ -689,7 +700,7 @@ class HiCacheController:
...
@@ -689,7 +700,7 @@ class HiCacheController:
]
]
prev_completed_tokens
=
operation
.
completed_tokens
prev_completed_tokens
=
operation
.
completed_tokens
# Get one batch token, and update the completed_tokens if succeed
# Get one batch token, and update the completed_tokens if succeed
get_func
(
operation
,
batch_hashes
,
batch_host_indices
)
self
.
page_
get_func
(
operation
,
batch_hashes
,
batch_host_indices
)
# Check termination
# Check termination
if
(
if
(
operation
.
completed_tokens
operation
.
completed_tokens
...
@@ -746,7 +757,7 @@ class HiCacheController:
...
@@ -746,7 +757,7 @@ class HiCacheController:
batch_tokens
[
i
:
i
+
self
.
page_size
],
last_hash
batch_tokens
[
i
:
i
+
self
.
page_size
],
last_hash
)
)
batch_hashes
.
append
(
last_hash
)
batch_hashes
.
append
(
last_hash
)
hit_page_num
=
self
.
storage_backend
.
batch_exists
(
batch_hashes
)
hit_page_num
=
self
.
batch_exists
_func
(
batch_hashes
)
hash_value
.
extend
(
batch_hashes
[:
hit_page_num
])
hash_value
.
extend
(
batch_hashes
[:
hit_page_num
])
storage_query_count
+=
hit_page_num
*
self
.
page_size
storage_query_count
+=
hit_page_num
*
self
.
page_size
if
hit_page_num
<
len
(
batch_hashes
):
if
hit_page_num
<
len
(
batch_hashes
):
...
@@ -839,23 +850,13 @@ class HiCacheController:
...
@@ -839,23 +850,13 @@ class HiCacheController:
# zero copy
# zero copy
def
_3fs_zero_copy_page_set
(
self
,
hash_values
,
host_indices
)
->
bool
:
def
_3fs_zero_copy_page_set
(
self
,
hash_values
,
host_indices
)
->
bool
:
hashes
,
dsts
=
self
.
mem_pool_host
.
get_buffer_with_hash
(
hashes
,
dsts
,
_
=
self
.
mem_pool_host
.
get_buffer_with_hash
(
hash_values
,
host_indices
hash_values
,
host_indices
)
)
return
self
.
storage_backend
.
batch_set
(
hashes
,
dsts
)
return
self
.
storage_backend
.
batch_set
(
hashes
,
dsts
)
# Backup batch by batch
# Backup batch by batch
def
_page_backup
(
self
,
operation
):
def
_page_backup
(
self
,
operation
):
# Select the set function and batch size
if
self
.
storage_backend_type
==
"mooncake"
:
backup_set_func
=
self
.
_mooncake_page_set
elif
(
self
.
storage_backend_type
==
"hf3fs"
and
self
.
mem_pool_host
.
layout
==
"page_first"
):
backup_set_func
=
self
.
_3fs_zero_copy_page_set
else
:
backup_set_func
=
self
.
_generic_page_set
# Backup batch by batch
# Backup batch by batch
for
i
in
range
(
0
,
len
(
operation
.
hash_value
),
self
.
storage_batch_size
):
for
i
in
range
(
0
,
len
(
operation
.
hash_value
),
self
.
storage_batch_size
):
batch_hashes
=
operation
.
hash_value
[
i
:
i
+
self
.
storage_batch_size
]
batch_hashes
=
operation
.
hash_value
[
i
:
i
+
self
.
storage_batch_size
]
...
@@ -864,7 +865,7 @@ class HiCacheController:
...
@@ -864,7 +865,7 @@ class HiCacheController:
]
]
# Set one batch token, and record if success.
# Set one batch token, and record if success.
# todo: allow partial success
# todo: allow partial success
success
=
backup
_set_func
(
batch_hashes
,
batch_host_indices
)
success
=
self
.
page
_set_func
(
batch_hashes
,
batch_host_indices
)
if
not
success
:
if
not
success
:
logger
.
warning
(
logger
.
warning
(
f
"Write page to storage:
{
len
(
batch_hashes
)
}
pages failed."
f
"Write page to storage:
{
len
(
batch_hashes
)
}
pages failed."
...
...
python/sglang/srt/mem_cache/memory_pool_host.py
View file @
d0730487
...
@@ -500,20 +500,23 @@ class MHATokenToKVPoolHost(HostKVCache):
...
@@ -500,20 +500,23 @@ class MHATokenToKVPoolHost(HostKVCache):
element_size_list
=
[
element_size
]
*
len
(
key_list
)
element_size_list
=
[
element_size
]
*
len
(
key_list
)
return
key_list
,
ptr_list
,
element_size_list
return
key_list
,
ptr_list
,
element_size_list
def
get_buffer_with_hash
(
self
,
keys
,
indices
):
def
get_buffer_with_hash
(
self
,
keys
,
indices
=
None
):
assert
self
.
layout
==
"page_first"
assert
self
.
layout
==
"page_first"
assert
len
(
keys
)
==
(
len
(
indices
)
//
self
.
page_size
)
assert
indices
is
None
or
(
len
(
keys
)
==
(
len
(
indices
)
//
self
.
page_size
)
)
key_list
=
[]
key_list
=
[]
buf_list
=
[]
buf_list
=
[]
for
key
,
i
in
zip
(
keys
,
range
(
0
,
len
(
indices
),
self
.
page_size
)):
for
i
in
range
(
len
(
keys
)):
key
=
keys
[
i
]
key_list
.
append
(
f
"
{
key
}
-k"
)
key_list
.
append
(
f
"
{
key
}
-k"
)
buf_list
.
append
(
self
.
k_buffer
[
i
:
i
+
self
.
page_size
])
key_list
.
append
(
f
"
{
key
}
-v"
)
key_list
.
append
(
f
"
{
key
}
-v"
)
buf_list
.
append
(
self
.
v_buffer
[
i
:
i
+
self
.
page_size
])
if
indices
is
not
None
:
index
=
indices
[
i
*
self
.
page_size
]
buf_list
.
append
(
self
.
k_buffer
[
index
:
index
+
self
.
page_size
])
buf_list
.
append
(
self
.
v_buffer
[
index
:
index
+
self
.
page_size
])
return
key_list
,
buf_list
return
key_list
,
buf_list
,
2
class
MLATokenToKVPoolHost
(
HostKVCache
):
class
MLATokenToKVPoolHost
(
HostKVCache
):
...
@@ -728,13 +731,15 @@ class MLATokenToKVPoolHost(HostKVCache):
...
@@ -728,13 +731,15 @@ class MLATokenToKVPoolHost(HostKVCache):
element_size_list
=
[
element_size
]
*
len
(
key_list
)
element_size_list
=
[
element_size
]
*
len
(
key_list
)
return
key_list
,
ptr_list
,
element_size_list
return
key_list
,
ptr_list
,
element_size_list
def
get_buffer_with_hash
(
self
,
keys
,
indices
):
def
get_buffer_with_hash
(
self
,
keys
,
indices
=
None
):
assert
self
.
layout
==
"page_first"
assert
self
.
layout
==
"page_first"
assert
len
(
keys
)
==
(
len
(
indices
)
//
self
.
page_size
)
assert
indices
is
None
or
(
len
(
keys
)
==
(
len
(
indices
)
//
self
.
page_size
)
)
buf_list
=
[]
buf_list
=
[]
for
i
in
range
(
0
,
len
(
indices
),
self
.
page_size
):
if
indices
is
not
None
:
buf_list
.
append
(
self
.
kv_buffer
[
i
:
i
+
self
.
page_size
])
for
i
in
range
(
len
(
keys
)):
index
=
indices
[
i
*
self
.
page_size
]
buf_list
.
append
(
self
.
kv_buffer
[
index
:
index
+
self
.
page_size
])
return
keys
,
buf_list
return
keys
,
buf_list
,
1
python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
View file @
d0730487
...
@@ -415,22 +415,12 @@ class HiCacheHF3FS(HiCacheStorage):
...
@@ -415,22 +415,12 @@ class HiCacheHF3FS(HiCacheStorage):
return
result
[
0
]
if
result
else
False
return
result
[
0
]
if
result
else
False
def
batch_exists
(
self
,
keys
:
List
[
str
])
->
int
:
def
batch_exists
(
self
,
keys
:
List
[
str
])
->
int
:
if
self
.
is_page_first_layout
and
not
self
.
is_mla_model
:
results
=
self
.
metadata_client
.
exists
(
self
.
rank
,
keys
)
query_keys
=
[]
for
i
in
range
(
len
(
keys
)):
# Compatible with page_first layout's key format, Refer to memory_pool_host.py#get_buffer_with_hash
if
not
results
[
i
]:
for
key
in
keys
:
return
i
query_keys
.
append
(
f
"
{
key
}
-k"
)
query_keys
.
append
(
f
"
{
key
}
-v"
)
return
len
(
keys
)
key_multiplier
=
2
else
:
query_keys
=
keys
key_multiplier
=
1
exist_result
=
self
.
metadata_client
.
exists
(
self
.
rank
,
query_keys
)
for
i
in
range
(
len
(
query_keys
)):
if
not
exist_result
[
i
]:
return
i
//
key_multiplier
return
len
(
query_keys
)
//
key_multiplier
def
clear
(
self
)
->
bool
:
def
clear
(
self
)
->
bool
:
try
:
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment