"tests/vscode:/vscode.git/clone" did not exist on "9a00cf194fcf994b2527cd927d691144f5e9c47b"
Unverified commit 29980334 authored by pansicheng, committed by GitHub

Add hf3fs support for hicache storage (based on #7704) (#7280)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent a79a5d70
SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
python3 benchmark/hf3fs/bench_storage.py
####################################################################################################
rm -rf nohup.out && \
nohup python3 -m sglang.launch_server \
--model-path /code/models/Qwen3-32B/ \
--host 0.0.0.0 --port 33301 \
--page-size 64 \
--enable-hierarchical-cache \
--hicache-ratio 2 --hicache-size 0 \
--hicache-write-policy write_through \
--hicache-storage-backend hf3fs &
rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
--model-path /code/models/Qwen3-32B \
--dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
--port 33301 \
--request-length 2048 --num-clients 512 --num-rounds 3 --max-parallel 8 \
> bench_multiturn.out &
####################################################################################################
rm -rf nohup.out && \
nohup python3 -m sglang.launch_server \
--model-path /code/models/DeepSeek-R1/ \
--tp 16 --nnodes 2 --node-rank 0 \
--dist-init-addr 10.74.249.153:5000 \
--host 0.0.0.0 --port 33301 \
--page-size 64 \
--enable-hierarchical-cache \
--hicache-ratio 2 --hicache-size 60 \
--hicache-write-policy write_through \
--hicache-storage-backend hf3fs &
rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
--model-path /code/models/Qwen3-32B \
--dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
--port 33301 \
--request-length 2048 --num-clients 1024 --num-rounds 3 --max-parallel 8 \
> bench_multiturn.out &
####################################################################################################
ps aux | grep "sglang.launch_server" | grep -v grep | awk '{print $2}' | xargs kill -9
ps aux | grep "bench_multiturn.py" | grep -v grep | awk '{print $2}' | xargs kill -9
import concurrent.futures
import logging
import random
import time
from typing import List
import torch
from tqdm import tqdm
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
def print_stats(x: List[float]):
x = sorted(x)
lenx = len(x)
print(
f"mean = {sum(x)/len(x):.2f}, "
f"min = {min(x):.2f}, "
f"p25 = {x[int(lenx*0.25)]:.2f}, "
f"p50 = {x[int(lenx*0.5)]:.2f}, "
f"p75 = {x[int(lenx*0.75)]:.2f}, "
f"max = {max(x):.2f}"
)
def test():
# /path/to/hf3fs
file_path = "/data/bench.bin"
file_size = 1 << 40
bytes_per_page = 16 << 20
entries = 32
file_ops = Hf3fsClient(file_path, file_size, bytes_per_page, entries)
print("test batch_read / batch_write")
num_pages = 128
dtype = torch.bfloat16
numel = bytes_per_page // dtype.itemsize
offsets = list(range(file_size // bytes_per_page))
random.shuffle(offsets)
offsets = offsets[:num_pages]
offsets = [i * bytes_per_page for i in offsets]
tensor_writes = [
torch.randn(numel, dtype=dtype)
for _ in tqdm(range(num_pages), desc="prepare tensor")
]
for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_write"):
results = file_ops.batch_write(
offsets[i : i + file_ops.entries], tensor_writes[i : i + file_ops.entries]
)
assert all([result == numel * dtype.itemsize for result in results])
tensor_reads = [
torch.empty(numel, dtype=dtype)
for _ in tqdm(range(num_pages), desc="prepare tensor")
]
for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_read"):
results = file_ops.batch_read(
offsets[i : i + file_ops.entries], tensor_reads[i : i + file_ops.entries]
)
assert all([result == numel * dtype.itemsize for result in results])
assert all([torch.allclose(r, w) for r, w in zip(tensor_reads, tensor_writes)])
file_ops.close()
print("test done")
def bench():
file_path = "/data/bench.bin"
file_size = 1 << 40
bytes_per_page = 16 << 20
entries = 8
numjobs = 16
dtype = torch.bfloat16
numel = bytes_per_page // dtype.itemsize
file_ops = [
Hf3fsClient(file_path, file_size, bytes_per_page, entries)
for _ in range(numjobs)
]
num_page = entries
offsets = list(range(file_size // bytes_per_page))
tensors_write = [torch.randn(numel, dtype=dtype)] * num_page
tensors_read = [torch.empty(numel, dtype=dtype)] * num_page
random.shuffle(offsets)
warmup = 50
iteration = 100
executor = concurrent.futures.ThreadPoolExecutor(max_workers=numjobs)
w_bw = []
w_size = num_page * numjobs * bytes_per_page / (1 << 30)
for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
_offsets = [
[
offset * bytes_per_page
for offset in offsets[
(i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
]
]
for j in range(numjobs)
]
tik = time.perf_counter()
futures = [
executor.submit(file_ops[j].batch_write, offset, tensors_write)
for j, offset in enumerate(_offsets)
]
results = [future.result() for future in futures]
tok = time.perf_counter()
if i < warmup:
continue
w_bw.append(w_size / (tok - tik))
results = [
_result == bytes_per_page for result in results for _result in result
]
assert all(results)
print_stats(w_bw)
r_bw = []
r_size = w_size
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
_offsets = [
[
offset * bytes_per_page
for offset in offsets[
(i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
]
]
for j in range(numjobs)
]
tik = time.perf_counter()
futures = [
executor.submit(file_ops[j].batch_read, offset, tensors_read)
for j, offset in enumerate(_offsets)
]
results = [future.result() for future in futures]
tok = time.perf_counter()
if i < warmup:
continue
r_bw.append(r_size / (tok - tik))
results = [
_result == bytes_per_page for result in results for _result in result
]
assert all(results)
print_stats(r_bw)
executor.shutdown(wait=True)
for _file_ops in file_ops:
_file_ops.close()
print("bench done")
def main():
logging.basicConfig(level=logging.INFO)
test()
bench()
if __name__ == "__main__":
main()
import json
import logging
import os
import random
import time
from typing import List
import torch
from tqdm import tqdm
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
def print_stats(x: List[float]):
x = sorted(x)
lenx = len(x)
print(
f"mean = {sum(x)/len(x):.2f}, "
f"min = {min(x):.2f}, "
f"p25 = {x[int(lenx*0.25)]:.2f}, "
f"p50 = {x[int(lenx*0.5)]:.2f}, "
f"p75 = {x[int(lenx*0.75)]:.2f}, "
f"max = {max(x):.2f}"
)
def test():
# Qwen3-32B
layer_num = 64
head_num, head_dim = 8, 128
kv_lora_rank, qk_rope_head_dim = 0, 0
store_dtype = torch.bfloat16
tokens_per_page = 64
file_path_prefix = "/data/test"
file_size = 128 << 20
numjobs = 16
bytes_per_page = 16 << 20
entries = 2
dtype = store_dtype
config_path = os.getenv(HiCacheHF3FS.default_env_var)
assert config_path
try:
with open(config_path, "w") as f:
json.dump(
{
"file_path_prefix": file_path_prefix,
"file_size": file_size,
"numjobs": numjobs,
"entries": entries,
},
f,
)
except Exception as e:
raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")
rank = 0
hicache_hf3fs = HiCacheHF3FS.from_env_config(rank, bytes_per_page, dtype)
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
assert numel * dtype.itemsize == bytes_per_page
num_pages = 10
tensors = {}
for i in range(num_pages):
k = f"key_{i}"
v = torch.randn((numel,)).to(dtype=dtype)
ok = hicache_hf3fs.set(k, v)
assert ok, f"Failed to insert {k}"
tensors[k] = v
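    # file_size (128 MB) / bytes_per_page (16 MB) = 8 slots, so after 10
    # inserts the two oldest keys were evicted by the LRU policy.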
assert hicache_hf3fs.get("key_0") is None
assert hicache_hf3fs.get("key_1") is None
start = num_pages - hicache_hf3fs.num_pages
for i in range(start, start + hicache_hf3fs.num_pages):
k = f"key_{i}"
assert hicache_hf3fs.exists(k)
out = hicache_hf3fs.get(k)
assert out is not None
v = tensors[k]
assert torch.allclose(v, out, atol=1e-3), f"Tensor mismatch for {k}"
assert not hicache_hf3fs.exists("not_exists")
hicache_hf3fs.delete("key_9")
v2 = torch.randn((numel,)).to(dtype=dtype)
assert hicache_hf3fs.set("key_new", v2)
assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)
hicache_hf3fs.clear()
assert len(hicache_hf3fs.free_pages) == hicache_hf3fs.num_pages
# batch
num_pages = 10
tensors = {}
keys = []
values = []
for i in range(num_pages):
k = f"key_{i}"
keys.append(k)
v = torch.randn((numel,)).to(dtype=dtype)
values.append(v)
ok = hicache_hf3fs.batch_set(keys, values)
assert not ok
assert hicache_hf3fs.get("key_8") is None
assert hicache_hf3fs.get("key_9") is None
results = hicache_hf3fs.batch_get(keys[: hicache_hf3fs.num_pages])
for result, key, value in zip(
results, keys[: hicache_hf3fs.num_pages], values[: hicache_hf3fs.num_pages]
):
assert torch.allclose(value, result, atol=1e-3), f"Tensor mismatch for {key}"
hicache_hf3fs.close()
os.remove(hicache_hf3fs.file_path)
print("All test cases passed.")
def bench():
# Qwen3-32B
layer_num = 64
head_num, head_dim = 8, 128
kv_lora_rank, qk_rope_head_dim = 0, 0
store_dtype = torch.bfloat16
tokens_per_page = 64
file_path = "/data/test.bin"
file_size = 1 << 40
numjobs = 16
bytes_per_page = 16 << 20
entries = 8
dtype = store_dtype
hicache_hf3fs = HiCacheHF3FS(
file_path=file_path,
file_size=file_size,
numjobs=numjobs,
bytes_per_page=bytes_per_page,
entries=entries,
dtype=dtype,
)
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
assert numel * dtype.itemsize == bytes_per_page
num_page = 128
values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]
warmup = 50
iteration = 100
w_bw = []
w_size = num_page * bytes_per_page / (1 << 30)
for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
tik = time.perf_counter()
ok = hicache_hf3fs.batch_set(keys, values)
tok = time.perf_counter()
if i < warmup:
continue
w_bw.append(w_size / (tok - tik))
assert ok
print_stats(w_bw)
r_bw = []
r_size = num_page * bytes_per_page / (1 << 30)
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
tik = time.perf_counter()
results = hicache_hf3fs.batch_get(keys)
tok = time.perf_counter()
if i < warmup:
continue
r_bw.append(r_size / (tok - tik))
assert all([r is not None for r in results])
print_stats(r_bw)
hicache_hf3fs.close()
def allclose():
# Qwen3-32B
layer_num = 64
head_num, head_dim = 8, 128
kv_lora_rank, qk_rope_head_dim = 0, 0
store_dtype = torch.bfloat16
tokens_per_page = 64
file_path = "/data/test.bin"
file_size = 1 << 40
numjobs = 16
bytes_per_page = 16 << 20
entries = 8
dtype = store_dtype
hicache_hf3fs = HiCacheHF3FS(
file_path=file_path,
file_size=file_size,
numjobs=numjobs,
bytes_per_page=bytes_per_page,
entries=entries,
dtype=dtype,
)
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
assert numel * dtype.itemsize == bytes_per_page
num_page = 128
values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]
iteration = 100
for i in tqdm(range(iteration), desc="Benchmarking write (GB/s)"):
keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
ok = hicache_hf3fs.batch_set(keys, values)
assert ok
read_keys, read_results = [], []
for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
results = hicache_hf3fs.batch_get(keys)
read_keys.extend(keys)
read_results.extend(results)
assert all([r is not None for r in results])
for key, result in tqdm(zip(read_keys, read_results)):
assert torch.allclose(values[int(key) % num_page], result, atol=1e-3)
hicache_hf3fs.close()
def main():
logging.basicConfig(level=logging.INFO)
test()
bench()
allclose()
if __name__ == "__main__":
main()
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
 from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

 logger = logging.getLogger(__name__)
@@ -250,17 +251,33 @@ class HiCacheController:
         self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
         if self.tp_world_size > 1:
             group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-            self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+            self.prefetch_tp_group = torch.distributed.new_group(
+                group_ranks, backend="gloo"
+            )
+            self.backup_tp_group = torch.distributed.new_group(
+                group_ranks, backend="gloo"
+            )

         if storage_backend == "file":
             self.storage_backend = HiCacheFile()
-            self.enable_storage = True
-            # todo: threshold policy for prefetching
-            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+        elif storage_backend == "hf3fs":
+            from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+            rank = get_tensor_model_parallel_rank()
+            bytes_per_page = (
+                mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+            )
+            dtype = mem_pool_host.dtype
+            self.storage_backend = HiCacheHF3FS.from_env_config(
+                rank, bytes_per_page, dtype
+            )
         else:
             raise NotImplementedError(
                 f"Unsupported storage backend: {storage_backend}"
             )
+        self.enable_storage = True
+        # todo: threshold policy for prefetching
+        self.prefetch_threshold = max(prefetch_threshold, self.page_size)

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -522,8 +539,8 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                for h in operation.hash_value:
-                    page_data = self.storage_backend.get(h)
+                page_datas = self.storage_backend.batch_get(operation.hash_value)
+                for h, page_data in zip(operation.hash_value, page_datas):
                     if page_data is None:
                         logger.warning(
                             f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
@@ -531,7 +548,9 @@ class HiCacheController:
                         break
                     if operation.increment(self.page_size):
                         self.mem_pool_host.set_from_flat_data_page(
-                            operation.host_indices[operation.completed_tokens],
+                            operation.host_indices[
+                                operation.completed_tokens - self.page_size
+                            ],
                             page_data,
                         )
                     else:
@@ -583,7 +602,7 @@ class HiCacheController:
                 torch.distributed.all_reduce(
                     storage_hit_count_tensor,
                     op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_group,
+                    group=self.prefetch_tp_group,
                 )
                 storage_hit_count = storage_hit_count_tensor.item()
@@ -635,21 +654,23 @@ class HiCacheController:
                 last_hash = operation.last_hash
                 tokens_to_backup = operation.token_ids

+                last_hashes, data_pages = [], []
                 for i in range(0, len(tokens_to_backup), self.page_size):
                     last_hash = get_hash_str(
                         tokens_to_backup[i : i + self.page_size], last_hash
                     )
-                    success = self.storage_backend.set(
-                        last_hash,
-                        self.mem_pool_host.get_flat_data_page(
-                            operation.host_indices[i]
-                        ),
+                    data_page = self.mem_pool_host.get_flat_data_page(
+                        operation.host_indices[i]
                     )
-                    if not success:
-                        logger.warning(f"Failed to write page {last_hash} to storage.")
-                        break
-                    operation.completed_tokens += self.page_size
-                    operation.hash_value.append(last_hash)
+                    last_hashes.append(last_hash)
+                    data_pages.append(data_page)
+
+                success = self.storage_backend.batch_set(last_hashes, data_pages)
+                if not success:
+                    logger.warning(f"Failed to write page {last_hashes} to storage.")
+                else:
+                    operation.completed_tokens += len(tokens_to_backup)
+                    operation.hash_value.extend(last_hashes)

                 min_completed_tokens = operation.completed_tokens
                 if self.tp_world_size > 1:
@@ -659,7 +680,7 @@ class HiCacheController:
                     torch.distributed.all_reduce(
                         completed_tokens_tensor,
                         op=torch.distributed.ReduceOp.MIN,
-                        group=self.tp_group,
+                        group=self.backup_tp_group,
                     )
                     min_completed_tokens = completed_tokens_tensor.item()
...
@@ -79,7 +79,9 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 3
         )
-        self.write_through_threshold_storage = 3
+        self.write_through_threshold_storage = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -388,10 +390,14 @@ class HiRadixCache(RadixCache):
                 self.cache_controller.ack_backup_queue.get()
             )
             host_node = self.ongoing_backup[ack_id]
-            if completed_tokens < len(host_node.key):
+            if completed_tokens == 0:
+                host_node.hash_value = None
+            elif completed_tokens < len(host_node.key):
                 # backup is only partially successful, split the node
                 new_node = self._split_node(host_node.key, host_node, completed_tokens)
                 new_node.hash_value = hash_value
+            else:
+                host_node.hash_value = hash_value
             host_node.release_host()
             del self.ongoing_backup[ack_id]
@@ -431,6 +437,8 @@ class HiRadixCache(RadixCache):
                 written_indices,
                 hash_value[:min_completed_tokens],
             )
+            if len(written_indices):
+                self.cache_controller.mem_pool_host.update_prefetch(written_indices)
             self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
             self.cache_controller.mem_pool_host.free(
...
@@ -25,7 +25,6 @@ def synchronized(debug_only=False):
         @wraps(func)
         def wrapper(self, *args, **kwargs):
             if (not debug_only) or self.debug:
-                return func(self, *args, **kwargs)
                 with self.lock:
                     return func(self, *args, **kwargs)
             else:
@@ -181,6 +180,15 @@ class HostKVCache(abc.ABC):
             )
         self.mem_state[indices] = MemoryStateInt.BACKUP

+    @synchronized(debug_only=True)
+    def update_prefetch(self, indices: torch.Tensor):
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
     @synchronized(debug_only=True)
     def update_synced(self, indices: torch.Tensor):
         self.mem_state[indices] = MemoryStateInt.SYNCED
...
# HiCacheHF3FS Setup
## Build & Package
### Source Code
https://github.com/deepseek-ai/3FS/blob/main/README.md#check-out-source-code
```sh
git clone https://github.com/deepseek-ai/3fs
cd 3fs
git submodule update --init --recursive
./patches/apply.sh
```
### Build Dev Container
https://github.com/deepseek-ai/3FS/blob/main/dockerfile/dev.dockerfile
```sh
cd 3fs/dockerfile
docker build -t hf3fs:dev -f dev.dockerfile .
```
### Generate Python Wheel
```sh
docker run -it hf3fs:dev bash
# Inside the development container
git clone https://github.com/deepseek-ai/3fs
cd 3fs
git submodule update --init --recursive
./patches/apply.sh
apt-get update \
&& apt-get install -y --no-install-recommends \
python3 python3-pip \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
python3 setup.py bdist_wheel
```
## Installation
```sh
# Install Dependencies
# https://github.com/deepseek-ai/3FS/blob/main/dockerfile/dev.dockerfile
apt update && apt install -y \
libaio-dev \
libboost-all-dev \
libdouble-conversion-dev \
libdwarf-dev \
libgflags-dev \
libgmock-dev \
libgoogle-glog-dev \
libgoogle-perftools-dev \
libgtest-dev \
liblz4-dev \
liblzma-dev \
libssl-dev \
libunwind-dev \
libuv1-dev
# Install Python Package
pip install hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
```
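## Configuration
`HiCacheHF3FS.from_env_config` reads a JSON config from the path in `SGLANG_HICACHE_HF3FS_CONFIG_PATH`; if the variable is unset it falls back to built-in defaults (`/data/hicache.{rank}.bin`, 1 TiB file, 16 jobs, 8 entries). A minimal sketch, assuming `/data` is a 3FS mount point and the paths below are placeholders:
```sh
# Write a config with the four required keys, then point SGLang at it.
cat > /sgl-workspace/hf3fs.json <<'EOF'
{
  "file_path_prefix": "/data/hicache",
  "file_size": 1099511627776,
  "numjobs": 16,
  "entries": 8
}
EOF
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/hf3fs.json
```
The backend appends `.{rank}.bin` to `file_path_prefix`, so each tensor-parallel rank gets its own backing file.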
import logging
import multiprocessing.shared_memory
import os
import threading
from functools import wraps
from pathlib import Path
from typing import List
import torch
from torch.utils.cpp_extension import load
root = Path(__file__).parent.resolve()
hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
logger = logging.getLogger(__name__)
try:
from hf3fs_fuse.io import (
deregister_fd,
extract_mount_point,
make_ioring,
make_iovec,
register_fd,
)
except ImportError as e:
logger.warning(f"hf3fs_fuse.io is not available: {e}")
def rsynchronized():
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self.rlock:
return func(self, *args, **kwargs)
return wrapper
return _decorator
def wsynchronized():
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self.wlock:
return func(self, *args, **kwargs)
return wrapper
return _decorator
class Hf3fsClient:
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
self.path = path
self.size = size
self.bytes_per_page = bytes_per_page
self.entries = entries
self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
os.ftruncate(self.file, size)
register_fd(self.file)
self.hf3fs_mount_point = extract_mount_point(path)
self.bs = self.bytes_per_page
self.shm_r = multiprocessing.shared_memory.SharedMemory(
size=self.bs * self.entries, create=True
)
self.shm_w = multiprocessing.shared_memory.SharedMemory(
size=self.bs * self.entries, create=True
)
self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8)
self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8)
self.numa = -1
self.ior_r = make_ioring(
self.hf3fs_mount_point,
self.entries,
for_read=True,
timeout=1,
numa=self.numa,
)
self.ior_w = make_ioring(
self.hf3fs_mount_point,
self.entries,
for_read=False,
timeout=1,
numa=self.numa,
)
self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
self.rlock = threading.RLock()
self.wlock = threading.RLock()
@rsynchronized()
def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
self.check(offsets, tensors)
# prepare
current = 0
for offset, tensor in zip(offsets, tensors):
size = tensor.numel() * tensor.itemsize
self.ior_r.prepare(
self.iov_r[current : current + size], True, self.file, offset
)
current += size
# submit
ionum = len(offsets)
resv = self.ior_r.submit().wait(min_results=ionum)
# results
hf3fs_utils.read_shm(self.shm_r_tensor, tensors)
results = [res.result for res in resv]
return results
@wsynchronized()
def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
self.check(offsets, tensors)
# prepare
hf3fs_utils.write_shm(tensors, self.shm_w_tensor)
current = 0
for offset, tensor in zip(offsets, tensors):
size = tensor.numel() * tensor.itemsize
self.ior_w.prepare(
self.iov_w[current : current + size], False, self.file, offset
)
current += size
# submit
ionum = len(offsets)
resv = self.ior_w.submit().wait(min_results=ionum)
# results
results = [res.result for res in resv]
return results
    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
        # Reject a batch if it exceeds the ring capacity, the offset and tensor
        # counts disagree, any page falls outside the file, or any tensor is
        # larger than a page.
        sizes = [t.numel() * t.itemsize for t in tensors]
        if (
            len(offsets) > self.entries
            or len(offsets) != len(sizes)
            or any(
                offset < 0 or offset + size > self.size
                for offset, size in zip(offsets, sizes)
            )
            or any(size > self.bytes_per_page for size in sizes)
        ):
            self.close()
            raise ValueError(f"Hf3fsClient.check: {offsets=}, {sizes=}")
def get_size(self) -> int:
return self.size
def close(self) -> None:
deregister_fd(self.file)
os.close(self.file)
del self.ior_r
del self.ior_w
del self.iov_r
del self.iov_w
self.shm_r.close()
self.shm_w.close()
self.shm_r.unlink()
self.shm_w.unlink()
def flush(self) -> None:
os.fsync(self.file)
#include <torch/extension.h>
#include <cstring>
#include <vector>
void read_shm(const torch::Tensor &shm, std::vector<torch::Tensor> dst) {
py::gil_scoped_release release;
char *src_ptr = static_cast<char *>(shm.data_ptr());
size_t current = 0;
for (size_t i = 0; i < dst.size(); ++i) {
auto &t = dst[i];
size_t t_bytes = t.numel() * t.element_size();
char *dst_ptr = static_cast<char *>(t.data_ptr());
std::memcpy(dst_ptr, src_ptr + current, t_bytes);
current += t_bytes;
}
}
void write_shm(const std::vector<torch::Tensor> src, torch::Tensor &shm) {
py::gil_scoped_release release;
char *dst_ptr = static_cast<char *>(shm.data_ptr());
size_t current = 0;
for (size_t i = 0; i < src.size(); ++i) {
auto &t = src[i];
size_t t_bytes = t.numel() * t.element_size();
char *src_ptr = static_cast<char *>(t.data_ptr());
std::memcpy(dst_ptr + current, src_ptr, t_bytes);
current += t_bytes;
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("read_shm", &read_shm, "Read tensors from shared memory");
m.def("write_shm", &write_shm, "Write tensors to shared memory");
}
import atexit
import concurrent.futures
import json
import logging
import os
import signal
import threading
from collections import OrderedDict
from functools import wraps
from typing import List, Optional
import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
logger = logging.getLogger(__name__)
class AtomicCounter:
def __init__(self, n: int):
assert n > 0
self.n = n
self._value = 0
self._lock = threading.Lock()
def next(self) -> int:
with self._lock:
current = self._value
self._value = (current + 1) % self.n
return current
def synchronized():
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self.lock:
return func(self, *args, **kwargs)
return wrapper
return _decorator
class HiCacheHF3FS(HiCacheStorage):
default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
def __init__(
self,
file_path: str,
file_size: int,
numjobs: int,
bytes_per_page: int,
entries: int,
dtype: torch.dtype,
):
self.file_path = file_path
self.file_size = file_size
self.numjobs = numjobs
self.bytes_per_page = bytes_per_page
self.entries = entries
self.dtype = dtype
self.numel = self.bytes_per_page // self.dtype.itemsize
self.num_pages = self.file_size // self.bytes_per_page
logger.info(
"HiCacheHF3FS "
f"file_path = {self.file_path}, "
f"file_size = {self.file_size/(2**30):.2f} GB, "
f"numjobs = {self.numjobs}, "
f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
f"entries = {self.entries}, "
f"num_pages = {self.num_pages}"
)
self.ac = AtomicCounter(self.numjobs)
self.clients = [
Hf3fsClient(
self.file_path, self.file_size, self.bytes_per_page, self.entries
)
for _ in range(numjobs)
]
self.executor = concurrent.futures.ThreadPoolExecutor(
max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
)
# Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
# Future iterations may adopt a global KVCache manager to coordinate external cache instances
# through centralized metadata orchestration.
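        # key_to_index is an OrderedDict used as an LRU: hits are refreshed
        # with move_to_end(), and when free_pages is exhausted the least
        # recently used entry is reclaimed via popitem(last=False).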
self.lock = threading.RLock()
self.free_pages = list(range(self.num_pages))
self.key_to_index = OrderedDict()
atexit.register(self.close)
signal.signal(signal.SIGINT, lambda sig, frame: self.close())
signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
@staticmethod
def from_env_config(
rank: int, bytes_per_page: int, dtype: torch.dtype
) -> "HiCacheHF3FS":
config_path = os.getenv(HiCacheHF3FS.default_env_var)
if not config_path:
return HiCacheHF3FS(
file_path=f"/data/hicache.{rank}.bin",
file_size=1 << 40,
numjobs=16,
bytes_per_page=bytes_per_page,
entries=8,
dtype=dtype,
)
try:
with open(config_path, "r") as f:
config = json.load(f)
except Exception as e:
raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
required_keys = {
"file_path_prefix",
"file_size",
"numjobs",
"entries",
}
missing_keys = required_keys - set(config.keys())
if missing_keys:
raise ValueError(f"Missing required keys in config: {missing_keys}")
return HiCacheHF3FS(
file_path=f"{config['file_path_prefix']}.{rank}.bin",
file_size=int(config["file_size"]),
numjobs=int(config["numjobs"]),
bytes_per_page=bytes_per_page,
entries=int(config["entries"]),
dtype=dtype,
)
    def get(
        self, key: str, target_location: Optional[torch.Tensor] = None
    ) -> torch.Tensor | None:
        # batch_get expects lists, so wrap the single key/location pair.
        return self.batch_get(
            [key], [target_location] if target_location is not None else None
        )[0]
@synchronized()
def batch_get(
self,
keys: List[str],
target_locations: Optional[List[torch.Tensor]] = None,
) -> List[torch.Tensor | None]:
batch_indices, file_offsets = [], []
for i, key in enumerate(keys):
if key not in self.key_to_index:
continue
batch_indices.append(i)
file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
self.key_to_index.move_to_end(key)
# TODO: target_locations
file_results = [
torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
]
futures = [
self.executor.submit(
self.clients[self.ac.next()].batch_read,
file_offsets[i : i + self.entries],
file_results[i : i + self.entries],
)
for i in range(0, len(batch_indices), self.entries)
]
read_results = [result for future in futures for result in future.result()]
results = [None] * len(keys)
for batch_index, file_result, read_result in zip(
batch_indices, file_results, read_results
):
if read_result == self.bytes_per_page:
results[batch_index] = file_result
else:
logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
return results
def set(self, key: str, value: torch.Tensor) -> bool:
return self.batch_set([key], [value])
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
indices = self.get_batch_set_indices(keys)
batch_indices, file_offsets, file_values = [], [], []
for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
if is_written or index == -1:
continue
batch_indices.append(i)
file_offsets.append(index * self.bytes_per_page)
file_values.append(value.contiguous())
futures = [
self.executor.submit(
self.clients[self.ac.next()].batch_write,
file_offsets[i : i + self.entries],
file_values[i : i + self.entries],
)
for i in range(0, len(batch_indices), self.entries)
]
write_results = [
result == self.bytes_per_page
for future in futures
for result in future.result()
]
results = [index[0] for index in indices]
for batch_index, write_result in zip(batch_indices, write_results):
key = keys[batch_index]
index = indices[batch_index][1]
if write_result:
self.key_to_index[key] = index
self.key_to_index.move_to_end(key)
else:
logger.error(f"HiCacheHF3FS set {key} failed")
self.free_pages.append(index)
results[batch_index] = write_result
return all(results)
@synchronized()
def get_batch_set_indices(self, keys: List[str]) -> list:
ionum = len(keys)
# results: tuples of (is_written: bool, page_idx: int)
# - is_written: True = hit (no I/O), False = write (miss)
# - page_idx: page storing data
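        # Illustrative example (hypothetical keys): with num_pages = 2 and "a"
        # already cached, keys = ["a", "b", "c"] yields
        # [(True, <a's page>), (False, <free or evicted page>), (False, -1)].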
results = [None] * min(ionum, self.num_pages)
if ionum > self.num_pages:
results.extend([(False, -1)] * (ionum - self.num_pages))
new_keys = []
for batch_index, key in enumerate(keys[: self.num_pages]):
if key in self.key_to_index:
results[batch_index] = (True, self.key_to_index[key])
self.key_to_index.move_to_end(key)
else:
new_keys.append((batch_index, key))
for batch_index, _ in new_keys:
index = (
self.free_pages.pop()
if len(self.free_pages) > 0
else self.key_to_index.popitem(last=False)[1]
)
results[batch_index] = (False, index)
return results
@synchronized()
def delete(self, key: str) -> None:
if key not in self.key_to_index:
return
index = self.key_to_index.pop(key)
self.free_pages.append(index)
@synchronized()
def exists(self, key: str) -> bool:
return key in self.key_to_index
@synchronized()
def clear(self) -> None:
self.free_pages = list(range(self.num_pages))
self.key_to_index.clear()
def close(self) -> None:
try:
for c in self.clients:
c.close()
self.executor.shutdown(wait=True)
except Exception as e:
logger.error(f"close HiCacheHF3FS: {e}")
logger.info("close HiCacheHF3FS")
import multiprocessing.shared_memory
from pathlib import Path
import pytest
import torch
from torch.utils.cpp_extension import load
from tqdm import tqdm
root = Path(__file__).parent.resolve()
hf3fs_utils = load(
name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"], verbose=True
)
def test_rw_shm():
numel = 8 << 20
dtype = torch.bfloat16
page_num = 128
page_bytes = numel * dtype.itemsize
shm = multiprocessing.shared_memory.SharedMemory(
size=page_num * page_bytes, create=True
)
tshm = torch.frombuffer(shm.buf, dtype=torch.uint8)
a = [
torch.randn(numel, dtype=dtype)
for _ in tqdm(range(page_num), desc="prepare input")
]
b = [
torch.empty(numel, dtype=dtype)
for _ in tqdm(range(page_num), desc="prepare output")
]
hf3fs_utils.write_shm(a, tshm)
hf3fs_utils.read_shm(tshm, b)
for _a, _b in tqdm(zip(a, b), desc="assert_close"):
torch.testing.assert_close(_a, _b)
del tshm
shm.close()
shm.unlink()
if __name__ == "__main__":
pytest.main([__file__])
@@ -1476,7 +1476,7 @@ class ServerArgs:
         parser.add_argument(
             "--hicache-storage-backend",
             type=str,
-            choices=["file"],  # todo, mooncake
+            choices=["file", "hf3fs"],  # todo, mooncake
             default=ServerArgs.hicache_storage_backend,
             help="The storage backend for hierarchical KV cache.",
         )
...