change / sglang · Commits · 70cf4abc

3fs zerocopy (#9109)

Unverified commit 70cf4abc, authored Aug 22, 2025 by pansicheng; committed by GitHub, Aug 22, 2025.
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Parent: cebf4599

Showing 7 changed files with 310 additions and 29 deletions (+310 -29)
Files changed:

  benchmark/hf3fs/bench.sh                                                 +10   -0
  benchmark/hf3fs/bench_storage.py                                         +27   -8
  benchmark/hf3fs/bench_zerocopy.py (new)                                 +140   -0
  python/sglang/srt/managers/cache_controller.py                           +53   -8
  python/sglang/srt/mem_cache/memory_pool_host.py                          +32   -0
  python/sglang/srt/mem_cache/storage/hf3fs/docs/setup_usrbio_client.md     +5   -2
  python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py               +43  -11
benchmark/hf3fs/bench.sh (+10 -0)

```sh
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
python3 benchmark/hf3fs/bench_client.py

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
python3 benchmark/hf3fs/bench_storage.py

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json
echo '{"file_path_prefix": "/data/hf3fs-test-0", "file_size": 1099511627776, "numjobs": 16, "entries": 8}' > \
${SGLANG_HICACHE_HF3FS_CONFIG_PATH}
python3 benchmark/hf3fs/bench_zerocopy.py

####################################################################################################
rm -rf nohup.out && \
...
```
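The third stanza is new in this commit: it writes a minimal HF3FS config file and runs the new zero-copy benchmark. For orientation, a minimal sketch of reading that config from Python follows; the field names are exactly the ones `bench.sh` writes, but how `HiCacheHF3FS.from_env_config` consumes them internally is an assumption here.

```python
# Sketch only: load the JSON written by bench.sh. The same keys appear as
# HiCacheHF3FS constructor arguments in bench_storage.py (file_size, numjobs,
# entries); file_path_prefix presumably names the per-rank backing files.
import json
import os

config_path = os.environ["SGLANG_HICACHE_HF3FS_CONFIG_PATH"]
with open(config_path) as f:
    config = json.load(f)

print(config["file_path_prefix"])  # /data/hf3fs-test-0
print(config["file_size"])         # 1099511627776 bytes (1 TiB)
print(config["numjobs"])           # 16
print(config["entries"])           # 8
```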
benchmark/hf3fs/bench_storage.py (+27 -8)

```diff
@@ -8,6 +8,9 @@ from typing import List
 import torch
 from tqdm import tqdm
 
+from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
+    Hf3fsLocalMetadataClient,
+)
 from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
@@ -67,12 +70,15 @@ def test():
         k = f"key_{i}"
         v = torch.randn((numel,)).to(dtype=dtype)
         ok = hicache_hf3fs.set(k, v)
-        assert ok, f"Failed to insert {k}"
+        if i < (file_size // bytes_per_page):
+            assert ok, f"Failed to insert {k}"
+        else:
+            assert not ok
         tensors[k] = v
 
-    assert hicache_hf3fs.get("key_0") is None
-    assert hicache_hf3fs.get("key_1") is None
+    assert hicache_hf3fs.get("key_8") is None
+    assert hicache_hf3fs.get("key_9") is None
 
-    start = num_pages - hicache_hf3fs.num_pages
+    start = 0
     for i in range(start, start + hicache_hf3fs.num_pages):
         k = f"key_{i}"
         assert hicache_hf3fs.exists(k)
@@ -83,13 +89,16 @@ def test():
     assert not hicache_hf3fs.exists("not_exists")
 
-    hicache_hf3fs.delete("key_9")
+    hicache_hf3fs.delete("key_7")
     v2 = torch.randn((numel,)).to(dtype=dtype)
     assert hicache_hf3fs.set("key_new", v2)
     assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)
 
     hicache_hf3fs.clear()
-    assert len(hicache_hf3fs.free_pages) == hicache_hf3fs.num_pages
+    assert (
+        len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)
+        == hicache_hf3fs.metadata_client.rank_metadata.num_pages
+    )
 
     # batch
     num_pages = 10
@@ -134,12 +143,14 @@ def bench():
     entries = 8
     dtype = store_dtype
     hicache_hf3fs = HiCacheHF3FS(
+        rank=0,
         file_path=file_path,
         file_size=file_size,
         numjobs=numjobs,
         bytes_per_page=bytes_per_page,
         entries=entries,
         dtype=dtype,
+        metadata_client=Hf3fsLocalMetadataClient(),
     )
     numel = 2 * tokens_per_page * layer_num * head_num * head_dim
@@ -167,7 +178,10 @@ def bench():
     r_bw = []
     r_size = num_page * bytes_per_page / (1 << 30)
     for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
-        keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
+        keys = random.sample(
+            list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
+            num_page,
+        )
         tik = time.perf_counter()
         results = hicache_hf3fs.batch_get(keys)
         tok = time.perf_counter()
@@ -195,12 +209,14 @@ def allclose():
     entries = 8
     dtype = store_dtype
     hicache_hf3fs = HiCacheHF3FS(
+        rank=0,
         file_path=file_path,
         file_size=file_size,
         numjobs=numjobs,
         bytes_per_page=bytes_per_page,
         entries=entries,
         dtype=dtype,
+        metadata_client=Hf3fsLocalMetadataClient(),
     )
     numel = 2 * tokens_per_page * layer_num * head_num * head_dim
@@ -218,7 +234,10 @@ def allclose():
     read_keys, read_results = [], []
     for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
-        keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
+        keys = random.sample(
+            list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
+            num_page,
+        )
         results = hicache_hf3fs.batch_get(keys)
         read_keys.extend(keys)
         read_results.extend(results)
```
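The pattern running through this file: per-rank page bookkeeping (`key_to_index`, `free_pages`, `num_pages`) no longer lives on `HiCacheHF3FS` itself but behind a metadata client, reached as `metadata_client.rank_metadata`. A toy sketch of that shape, which is a mental model inferred from the updated test and bench code rather than the real `Hf3fsLocalMetadataClient` implementation:

```python
# Toy model of the metadata indirection (attribute names follow the diff above;
# the classes themselves are hypothetical stand-ins).
class RankMetadata:
    def __init__(self, num_pages: int):
        self.num_pages = num_pages
        self.free_pages = list(range(num_pages))  # page slots not yet assigned
        self.key_to_index = {}                    # key -> page index

class LocalMetadataClient:
    def __init__(self, num_pages: int):
        self.rank_metadata = RankMetadata(num_pages)

meta = LocalMetadataClient(num_pages=8)
# Mirrors the updated assertion in test() after clear():
assert len(meta.rank_metadata.free_pages) == meta.rank_metadata.num_pages
```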
benchmark/hf3fs/bench_zerocopy.py (new file, +140)

```python
import threading
import time

import torch
from tqdm import tqdm

from sglang.srt.distributed import (
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
)
from sglang.srt.managers.cache_controller import (
    HiCacheController,
    PrefetchOperation,
    StorageOperation,
)
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost

init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:23456",
    local_rank=0,
    backend="gloo",
)
initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
)
group = get_world_group().cpu_group

max_total_num_tokens = 524288
page_size = 64
kv_cache_dtype = torch.bfloat16
layer_num = 64
head_num, head_dim = 8, 128
device = "cuda"
hicache_ratio = 2
hicache_size = 0
hicache_mem_layout = "page_first"
# hicache_mem_layout = "layer_first"
hicache_write_policy = "write_through"
hicache_io_backend = "kernel"
hicache_storage_backend = "hf3fs"
prefetch_threshold = 256
op_size = 1024
op_num = 16

token_to_kv_pool = MHATokenToKVPool(
    max_total_num_tokens,
    page_size=page_size,
    dtype=kv_cache_dtype,
    head_num=head_num,
    head_dim=head_dim,
    layer_num=layer_num,
    device=device,
    enable_memory_saver=True,
)
token_to_kv_pool_allocator = TokenToKVPoolAllocator(
    max_total_num_tokens,
    dtype=kv_cache_dtype,
    device=device,
    kvcache=token_to_kv_pool,
    need_sort=False,
)
kv_cache = token_to_kv_pool_allocator.get_kvcache()
token_to_kv_pool_host = MHATokenToKVPoolHost(
    kv_cache,
    hicache_ratio,
    hicache_size,
    page_size,
    hicache_mem_layout,
)
load_cache_event = threading.Event()
cache_controller = HiCacheController(
    token_to_kv_pool_allocator,
    token_to_kv_pool_host,
    page_size,
    group,
    load_cache_event=load_cache_event,
    write_policy=hicache_write_policy,
    io_backend=hicache_io_backend,
    storage_backend=hicache_storage_backend,
    prefetch_threshold=prefetch_threshold,
)

operations = [
    StorageOperation(
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        hash_value=[f"{j}" for j in range(i, i + op_size, page_size)],
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]

tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_backup(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_backup(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok - tik:.6f} s")

operations = [
    PrefetchOperation(
        f"{i}",
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        f"{i}",
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]
for operation in operations:
    operation.hash_value = [
        f"{j}"
        for j in range(
            int(operation.last_hash), int(operation.last_hash) + op_size, page_size
        )
    ]

tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_transfer(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_transfer(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok - tik:.6f} s")
```
python/sglang/srt/managers/cache_controller.py (+53 -8)

```diff
@@ -268,9 +268,14 @@ class HiCacheController:
             )
             rank = get_tensor_model_parallel_rank()
-            bytes_per_page = (
-                mem_pool_host.get_size_per_token() * mem_pool_host.page_size
-            )
+            if self.mem_pool_host.layout == "page_first":
+                bytes_per_page = (
+                    mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
+                )
+            elif self.mem_pool_host.layout == "layer_first":
+                bytes_per_page = (
+                    mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+                )
             dtype = mem_pool_host.dtype
             self.storage_backend = HiCacheHF3FS.from_env_config(
                 rank, bytes_per_page, dtype
@@ -555,13 +560,34 @@ class HiCacheController:
         operation.mark_done()
         return operation.completed_tokens, operation.hash_value
 
+    def zerocopy_page_transfer(self, operation, batch_size=8):
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            operation.hash_value, operation.host_indices
+        )
+        for i in range(0, len(hashes), batch_size):
+            page_hashes = hashes[i : i + batch_size]
+            page_dsts = dsts[i : i + batch_size]
+            page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
+            if page_data is None:
+                logger.warning(
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
+                )
+                break
+            completed_tokens = operation.completed_tokens
+            if operation.increment(self.page_size * len(page_hashes)):
+                for i in range(len(page_hashes)):
+                    completed_tokens += self.page_size
+            else:
+                break
+
     def generic_page_transfer(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
             # todo: zero copy
-            dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(page_hashes)
+            dummy_page_dst = [
+                self.mem_pool_host.get_dummy_flat_data_page()
+                for _ in range(len(page_hashes))
+            ]
             page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
             if page_data is None:
                 logger.warning(
@@ -599,7 +625,10 @@ class HiCacheController:
         if self.is_mooncake_backend():
             self.mooncake_page_transfer(operation)
         elif self.storage_backend_type == "hf3fs":
-            self.generic_page_transfer(operation, batch_size=128)
+            if self.mem_pool_host.layout == "page_first":
+                self.zerocopy_page_transfer(operation, batch_size=128)
+            elif self.mem_pool_host.layout == "layer_first":
+                self.generic_page_transfer(operation, batch_size=128)
         else:
             self.generic_page_transfer(operation)
@@ -716,6 +745,19 @@ class HiCacheController:
         self.backup_queue.put(operation)
         return operation.id
 
+    def zerocopy_page_backup(self, operation, batch_size=8):
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            operation.hash_value, operation.host_indices
+        )
+        for i in range(0, len(hashes), batch_size):
+            page_hashes = hashes[i : i + batch_size]
+            page_data = dsts[i : i + batch_size]
+            success = self.storage_backend.batch_set(page_hashes, page_data)
+            if not success:
+                logger.warning(f"Failed to write page {page_hashes} to storage.")
+                break
+            operation.completed_tokens += self.page_size * len(page_hashes)
+
     def generic_page_backup(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
@@ -770,7 +812,10 @@ class HiCacheController:
         if self.is_mooncake_backend():
             self.mooncake_page_backup(operation)
         elif self.storage_backend_type == "hf3fs":
-            self.generic_page_backup(operation, batch_size=128)
+            if self.mem_pool_host.layout == "page_first":
+                self.zerocopy_page_backup(operation, batch_size=128)
+            elif self.mem_pool_host.layout == "layer_first":
+                self.generic_page_backup(operation, batch_size=128)
         else:
             self.generic_page_backup(operation)
```
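The difference between the two paths, reduced to data movement: the generic path reads each page into a dummy staging tensor and copies it into the host pool afterwards, while the zero-copy path hands slices of the host pool itself (from `get_buffer_with_hash`) to the storage backend as read destinations, and as write sources on backup. A toy illustration of that contrast, not the controller's actual code:

```python
# Toy model: storage_read_into stands in for an HF3FS read into a tensor.
import torch

host_pool = torch.zeros(4, 16)  # stand-in for the page_first host KV buffer

def storage_read_into(dst: torch.Tensor) -> None:
    dst.copy_(torch.ones_like(dst))  # pretend this is a disk read

# Generic path: read into scratch, then copy into the pool (two passes over the data).
scratch = torch.empty(16)
storage_read_into(scratch)
host_pool[0] = scratch

# Zero-copy path: the pool slice itself is the read destination (one pass).
storage_read_into(host_pool[1])

assert torch.equal(host_pool[0], host_pool[1])
```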
python/sglang/srt/mem_cache/memory_pool_host.py (+32 -0)

```diff
@@ -307,6 +307,9 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
 
+    def get_ksize_per_token(self):
+        return self.get_size_per_token() // 2
+
     def init_kv_buffer(self):
         if self.layout == "layer_first":
             dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
@@ -496,6 +499,21 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
+    def get_buffer_with_hash(self, keys, indices):
+        assert self.layout == "page_first"
+        assert len(keys) == (len(indices) // self.page_size)
+
+        key_list = []
+        buf_list = []
+
+        for key, i in zip(keys, range(0, len(indices), self.page_size)):
+            key_list.append(f"{key}-k")
+            buf_list.append(self.k_buffer[i : i + self.page_size])
+            key_list.append(f"{key}-v")
+            buf_list.append(self.v_buffer[i : i + self.page_size])
+
+        return key_list, buf_list
+
 
 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
@@ -538,6 +556,9 @@ class MLATokenToKVPoolHost(HostKVCache):
             * self.layer_num
         )
 
+    def get_ksize_per_token(self):
+        return self.get_size_per_token()
+
     def init_kv_buffer(self):
         if self.layout == "layer_first":
             dims = (
@@ -704,3 +725,14 @@ class MLATokenToKVPoolHost(HostKVCache):
         )
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
+
+    def get_buffer_with_hash(self, keys, indices):
+        assert self.layout == "page_first"
+        assert len(keys) == (len(indices) // self.page_size)
+
+        buf_list = []
+
+        for i in range(0, len(indices), self.page_size):
+            buf_list.append(self.kv_buffer[i : i + self.page_size])
+
+        return keys, buf_list
```
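`get_ksize_per_token` halves the MHA per-token footprint because, in the `page_first` layout, `get_buffer_with_hash` splits each logical page into two storage objects, "{key}-k" and "{key}-v", so each object covers only the K (or V) half; MLA keeps a single combined object, hence the full size. A worked example of the `bytes_per_page` the controller derives, using `bench_zerocopy.py`'s shapes and assuming bfloat16's 2-byte itemsize:

```python
# Worked example with bench_zerocopy.py's model shape (MHA, bfloat16).
layer_num, head_num, head_dim, itemsize = 64, 8, 128, 2
page_size = 64  # tokens per page

size_per_token = head_dim * head_num * layer_num * itemsize * 2  # K and V together
ksize_per_token = size_per_token // 2                            # K (or V) half only

bytes_per_page = ksize_per_token * page_size  # what page_first passes to from_env_config
print(size_per_token)  # 262144 bytes per token (256 KiB)
print(bytes_per_page)  # 8388608 bytes per storage object (8 MiB)
```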
python/sglang/srt/mem_cache/storage/hf3fs/docs/setup_usrbio_client.md (+5 -2)

````diff
@@ -34,6 +34,9 @@ apt-get update \
     python3 python3-pip \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
 
+# apt install python3.12 python3.12-venv python3.12-dev
+# curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+# python3.12 get-pip.py
 
 # Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
 python3 setup.py bdist_wheel
@@ -60,6 +63,6 @@ apt update && apt install -y \
     libuv1-dev
 
 # Install Python Package
-pip install hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
-export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
+pip install hf3fs_py_usrbio-1.2.9+394583d-cp312-cp312-linux_x86_64.whl
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages
 ```
````
python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py (+43 -11)

```diff
@@ -7,7 +7,7 @@ import signal
 import threading
 from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
@@ -228,15 +228,23 @@ class HiCacheHF3FS(HiCacheStorage):
         )
 
     def get(
-        self, key: str, target_location: Optional[torch.Tensor] = None
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
-        return self.batch_get([key], [target_location] if target_location else None)[0]
+        return self.batch_get(
+            [key],
+            [target_location] if target_location is not None else None,
+            [target_sizes] if target_sizes is not None else None,
+        )[0]
 
     @synchronized()
     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         page_indices = self.metadata_client.get_page_indices(self.rank, keys)
@@ -246,9 +254,15 @@ class HiCacheHF3FS(HiCacheStorage):
                 batch_indices.append(i)
                 file_offsets.append(page_index * self.bytes_per_page)
 
-        file_results = [
-            torch.empty(self.numel, dtype=self.dtype)
-            for _ in range(len(batch_indices))
-        ]
+        if target_locations is not None:
+            for target_location in target_locations:
+                assert target_location.is_contiguous()
+            file_results = target_locations
+        else:
+            file_results = [
+                torch.empty(self.numel, dtype=self.dtype)
+                for _ in range(len(batch_indices))
+            ]
 
         futures = [
             self.executor.submit(
@@ -273,10 +287,27 @@ class HiCacheHF3FS(HiCacheStorage):
         return results
 
-    def set(self, key: str, value: torch.Tensor) -> bool:
-        return self.batch_set([key], [value])
+    def set(
+        self,
+        key: str,
+        value: Optional[Any] = None,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
+        return self.batch_set(
+            [key],
+            [value] if value is not None else None,
+            [target_location] if target_location is not None else None,
+            [target_sizes] if target_sizes is not None else None,
+        )
 
-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -292,7 +323,8 @@ class HiCacheHF3FS(HiCacheStorage):
                 batch_indices.append(i)
                 file_offsets.append(page_index * self.bytes_per_page)
-                file_values.append(value.contiguous())
+                assert value.is_contiguous()
+                file_values.append(value)
```
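Two consequences of the widened signatures are worth noting. First, `get` now tests `target_location is not None` instead of truthiness, which matters for tensor arguments. Second, when `target_locations` is supplied, `batch_get` reads directly into those buffers rather than allocating `torch.empty` staging tensors, and `batch_set` now requires already-contiguous values instead of calling `.contiguous()` (which could silently copy). A hedged usage sketch of the caller's side; `store` is assumed to be a constructed HiCacheHF3FS instance and the helper names are hypothetical:

```python
# Sketch only: pass preallocated contiguous host-pool slices as I/O buffers,
# the way zerocopy_page_transfer / zerocopy_page_backup do in cache_controller.py.
from typing import List

import torch

def read_pages_zero_copy(store, keys: List[str], host_slices: List[torch.Tensor]):
    for dst in host_slices:
        assert dst.is_contiguous()  # batch_get asserts this as well
    return store.batch_get(keys, host_slices)  # reads land in host_slices

def write_pages_zero_copy(store, keys: List[str], host_slices: List[torch.Tensor]) -> bool:
    for src in host_slices:
        assert src.is_contiguous()  # batch_set no longer copies for you
    return store.batch_set(keys, host_slices)
```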