Unverified Commit 70cf4abc authored by pansicheng's avatar pansicheng Committed by GitHub
Browse files

3fs zerocopy (#9109)


Co-authored-by: default avatarZhiqiang Xie <xiezhq@stanford.edu>
parent cebf4599
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
python3 benchmark/hf3fs/bench_client.py
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \ SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
python3 benchmark/hf3fs/bench_storage.py python3 benchmark/hf3fs/bench_storage.py
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json
echo '{"file_path_prefix": "/data/hf3fs-test-0", "file_size": 1099511627776, "numjobs": 16, "entries": 8}' > \
${SGLANG_HICACHE_HF3FS_CONFIG_PATH}
python3 benchmark/hf3fs/bench_zerocopy.py
#################################################################################################### ####################################################################################################
rm -rf nohup.out && \ rm -rf nohup.out && \
......
...@@ -8,6 +8,9 @@ from typing import List ...@@ -8,6 +8,9 @@ from typing import List
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
Hf3fsLocalMetadataClient,
)
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
...@@ -67,12 +70,15 @@ def test(): ...@@ -67,12 +70,15 @@ def test():
k = f"key_{i}" k = f"key_{i}"
v = torch.randn((numel,)).to(dtype=dtype) v = torch.randn((numel,)).to(dtype=dtype)
ok = hicache_hf3fs.set(k, v) ok = hicache_hf3fs.set(k, v)
assert ok, f"Failed to insert {k}" if i < (file_size // bytes_per_page):
assert ok, f"Failed to insert {k}"
else:
assert not ok
tensors[k] = v tensors[k] = v
assert hicache_hf3fs.get("key_0") is None assert hicache_hf3fs.get("key_8") is None
assert hicache_hf3fs.get("key_1") is None assert hicache_hf3fs.get("key_9") is None
start = num_pages - hicache_hf3fs.num_pages start = 0
for i in range(start, start + hicache_hf3fs.num_pages): for i in range(start, start + hicache_hf3fs.num_pages):
k = f"key_{i}" k = f"key_{i}"
assert hicache_hf3fs.exists(k) assert hicache_hf3fs.exists(k)
...@@ -83,13 +89,16 @@ def test(): ...@@ -83,13 +89,16 @@ def test():
assert not hicache_hf3fs.exists("not_exists") assert not hicache_hf3fs.exists("not_exists")
hicache_hf3fs.delete("key_9") hicache_hf3fs.delete("key_7")
v2 = torch.randn((numel,)).to(dtype=dtype) v2 = torch.randn((numel,)).to(dtype=dtype)
assert hicache_hf3fs.set("key_new", v2) assert hicache_hf3fs.set("key_new", v2)
assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3) assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)
hicache_hf3fs.clear() hicache_hf3fs.clear()
assert len(hicache_hf3fs.free_pages) == hicache_hf3fs.num_pages assert (
len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)
== hicache_hf3fs.metadata_client.rank_metadata.num_pages
)
# batch # batch
num_pages = 10 num_pages = 10
...@@ -134,12 +143,14 @@ def bench(): ...@@ -134,12 +143,14 @@ def bench():
entries = 8 entries = 8
dtype = store_dtype dtype = store_dtype
hicache_hf3fs = HiCacheHF3FS( hicache_hf3fs = HiCacheHF3FS(
rank=0,
file_path=file_path, file_path=file_path,
file_size=file_size, file_size=file_size,
numjobs=numjobs, numjobs=numjobs,
bytes_per_page=bytes_per_page, bytes_per_page=bytes_per_page,
entries=entries, entries=entries,
dtype=dtype, dtype=dtype,
metadata_client=Hf3fsLocalMetadataClient(),
) )
numel = 2 * tokens_per_page * layer_num * head_num * head_dim numel = 2 * tokens_per_page * layer_num * head_num * head_dim
...@@ -167,7 +178,10 @@ def bench(): ...@@ -167,7 +178,10 @@ def bench():
r_bw = [] r_bw = []
r_size = num_page * bytes_per_page / (1 << 30) r_size = num_page * bytes_per_page / (1 << 30)
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"): for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page) keys = random.sample(
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
num_page,
)
tik = time.perf_counter() tik = time.perf_counter()
results = hicache_hf3fs.batch_get(keys) results = hicache_hf3fs.batch_get(keys)
tok = time.perf_counter() tok = time.perf_counter()
...@@ -195,12 +209,14 @@ def allclose(): ...@@ -195,12 +209,14 @@ def allclose():
entries = 8 entries = 8
dtype = store_dtype dtype = store_dtype
hicache_hf3fs = HiCacheHF3FS( hicache_hf3fs = HiCacheHF3FS(
rank=0,
file_path=file_path, file_path=file_path,
file_size=file_size, file_size=file_size,
numjobs=numjobs, numjobs=numjobs,
bytes_per_page=bytes_per_page, bytes_per_page=bytes_per_page,
entries=entries, entries=entries,
dtype=dtype, dtype=dtype,
metadata_client=Hf3fsLocalMetadataClient(),
) )
numel = 2 * tokens_per_page * layer_num * head_num * head_dim numel = 2 * tokens_per_page * layer_num * head_num * head_dim
...@@ -218,7 +234,10 @@ def allclose(): ...@@ -218,7 +234,10 @@ def allclose():
read_keys, read_results = [], [] read_keys, read_results = [], []
for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"): for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page) keys = random.sample(
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
num_page,
)
results = hicache_hf3fs.batch_get(keys) results = hicache_hf3fs.batch_get(keys)
read_keys.extend(keys) read_keys.extend(keys)
read_results.extend(results) read_results.extend(results)
......
import threading
import time
import torch
from tqdm import tqdm
from sglang.srt.distributed import (
get_world_group,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.managers.cache_controller import (
HiCacheController,
PrefetchOperation,
StorageOperation,
)
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
# Single-rank distributed setup: HiCacheController takes a CPU process group,
# so even this standalone benchmark must initialize the distributed state.
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:23456",
    local_rank=0,
    backend="gloo",  # CPU backend is sufficient; no NCCL required here
)
initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
)
# CPU group passed to HiCacheController below.
group = get_world_group().cpu_group
# ---- Benchmark configuration -----------------------------------------------
max_total_num_tokens = 524288
page_size = 64  # tokens per KV page
kv_cache_dtype = torch.bfloat16
layer_num = 64
head_num, head_dim = 8, 128
device = "cuda"
hicache_ratio = 2
hicache_size = 0
# "page_first" exercises the zero-copy paths; flip to "layer_first" to
# benchmark the generic (copying) paths instead.
hicache_mem_layout = "page_first"
# hicache_mem_layout = "layer_first"
hicache_write_policy = "write_through"
hicache_io_backend = "kernel"
hicache_storage_backend = "hf3fs"
prefetch_threshold = 256
op_size = 1024  # tokens per storage/prefetch operation
op_num = 16  # number of operations issued per benchmark pass
# Device-side KV pool and its allocator, a host-side mirror pool, and the
# HiCacheController under benchmark.
token_to_kv_pool = MHATokenToKVPool(
    max_total_num_tokens,
    page_size=page_size,
    dtype=kv_cache_dtype,
    head_num=head_num,
    head_dim=head_dim,
    layer_num=layer_num,
    device=device,
    enable_memory_saver=True,
)
token_to_kv_pool_allocator = TokenToKVPoolAllocator(
    max_total_num_tokens,
    dtype=kv_cache_dtype,
    device=device,
    kvcache=token_to_kv_pool,
    need_sort=False,
)
kv_cache = token_to_kv_pool_allocator.get_kvcache()
# Host pool layout (page_first vs layer_first) decides which backup/transfer
# code path the controller dispatches to.
token_to_kv_pool_host = MHATokenToKVPoolHost(
    kv_cache,
    hicache_ratio,
    hicache_size,
    page_size,
    hicache_mem_layout,
)
load_cache_event = threading.Event()
cache_controller = HiCacheController(
    token_to_kv_pool_allocator,
    token_to_kv_pool_host,
    page_size,
    group,
    load_cache_event=load_cache_event,
    write_policy=hicache_write_policy,
    io_backend=hicache_io_backend,
    storage_backend=hicache_storage_backend,
    prefetch_threshold=prefetch_threshold,
)
# ---- Backup benchmark: host -> storage -------------------------------------
# Each StorageOperation covers op_size contiguous token indices, with one hash
# string per page (one hash every page_size tokens).
operations = [
    StorageOperation(
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        hash_value=[f"{j}" for j in range(i, i + op_size, page_size)],
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]
tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_backup(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_backup(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok-tik:.6f} s")
# ---- Prefetch benchmark: storage -> host -----------------------------------
operations = [
    PrefetchOperation(
        f"{i}",
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        f"{i}",
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]
# Regenerate per-page hashes from each operation's last_hash so reads target
# the same keys written by the backup pass above.
for operation in operations:
    operation.hash_value = [
        f"{j}"
        for j in range(
            int(operation.last_hash), int(operation.last_hash) + op_size, page_size
        )
    ]
tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_transfer(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_transfer(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok-tik:.6f} s")
...@@ -268,9 +268,14 @@ class HiCacheController: ...@@ -268,9 +268,14 @@ class HiCacheController:
) )
rank = get_tensor_model_parallel_rank() rank = get_tensor_model_parallel_rank()
bytes_per_page = ( if self.mem_pool_host.layout == "page_first":
mem_pool_host.get_size_per_token() * mem_pool_host.page_size bytes_per_page = (
) mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
)
elif self.mem_pool_host.layout == "layer_first":
bytes_per_page = (
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
)
dtype = mem_pool_host.dtype dtype = mem_pool_host.dtype
self.storage_backend = HiCacheHF3FS.from_env_config( self.storage_backend = HiCacheHF3FS.from_env_config(
rank, bytes_per_page, dtype rank, bytes_per_page, dtype
...@@ -555,13 +560,34 @@ class HiCacheController: ...@@ -555,13 +560,34 @@ class HiCacheController:
operation.mark_done() operation.mark_done()
return operation.completed_tokens, operation.hash_value return operation.completed_tokens, operation.hash_value
def zerocopy_page_transfer(self, operation, batch_size=8):
    """Prefetch pages from storage directly into host KV buffers (zero-copy).

    Resolves (storage key, host buffer view) pairs for the operation's pages
    via the host pool, then issues batched reads whose destinations are the
    host buffers themselves, avoiding an intermediate staging copy.

    Args:
        operation: prefetch operation carrying ``hash_value``, ``host_indices``,
            ``request_id`` and the ``increment`` progress callback.
        batch_size: number of (key, buffer) pairs per ``batch_get`` call.
    """
    hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
        operation.hash_value, operation.host_indices
    )
    for start in range(0, len(hashes), batch_size):
        page_hashes = hashes[start : start + batch_size]
        page_dsts = dsts[start : start + batch_size]
        page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
        if page_data is None:
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
            )
            break
        # NOTE(review): for MHA layouts get_buffer_with_hash yields two
        # entries (k and v) per logical page, so this token count may
        # double-count -- confirm against the host pool implementation.
        # increment() returning False means the operation was terminated
        # concurrently; stop issuing further reads.
        # (Removed: a dead local `completed_tokens` accumulator whose inner
        # loop also shadowed the outer loop index.)
        if not operation.increment(self.page_size * len(page_hashes)):
            break
def generic_page_transfer(self, operation, batch_size=8): def generic_page_transfer(self, operation, batch_size=8):
for i in range(0, len(operation.hash_value), batch_size): for i in range(0, len(operation.hash_value), batch_size):
page_hashes = operation.hash_value[i : i + batch_size] page_hashes = operation.hash_value[i : i + batch_size]
# todo: zero copy # todo: zero copy
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len( dummy_page_dst = [
page_hashes self.mem_pool_host.get_dummy_flat_data_page()
) for _ in range(len(page_hashes))
]
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst) page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
if page_data is None: if page_data is None:
logger.warning( logger.warning(
...@@ -599,7 +625,10 @@ class HiCacheController: ...@@ -599,7 +625,10 @@ class HiCacheController:
if self.is_mooncake_backend(): if self.is_mooncake_backend():
self.mooncake_page_transfer(operation) self.mooncake_page_transfer(operation)
elif self.storage_backend_type == "hf3fs": elif self.storage_backend_type == "hf3fs":
self.generic_page_transfer(operation, batch_size=128) if self.mem_pool_host.layout == "page_first":
self.zerocopy_page_transfer(operation, batch_size=128)
elif self.mem_pool_host.layout == "layer_first":
self.generic_page_transfer(operation, batch_size=128)
else: else:
self.generic_page_transfer(operation) self.generic_page_transfer(operation)
...@@ -716,6 +745,19 @@ class HiCacheController: ...@@ -716,6 +745,19 @@ class HiCacheController:
self.backup_queue.put(operation) self.backup_queue.put(operation)
return operation.id return operation.id
def zerocopy_page_backup(self, operation, batch_size=8):
    """Write host KV pages for *operation* to the storage backend in batches.

    The buffers handed to ``batch_set`` are views obtained straight from the
    host pool (zero-copy). Stops at the first failed batch; progress is
    tracked through ``operation.completed_tokens``.
    """
    keys, buffers = self.mem_pool_host.get_buffer_with_hash(
        operation.hash_value, operation.host_indices
    )
    for batch_start in range(0, len(keys), batch_size):
        batch_end = batch_start + batch_size
        page_hashes = keys[batch_start:batch_end]
        page_buffers = buffers[batch_start:batch_end]
        if not self.storage_backend.batch_set(page_hashes, page_buffers):
            logger.warning(f"Failed to write page {page_hashes} to storage.")
            break
        operation.completed_tokens += self.page_size * len(page_hashes)
def generic_page_backup(self, operation, batch_size=8): def generic_page_backup(self, operation, batch_size=8):
for i in range(0, len(operation.hash_value), batch_size): for i in range(0, len(operation.hash_value), batch_size):
page_hashes = operation.hash_value[i : i + batch_size] page_hashes = operation.hash_value[i : i + batch_size]
...@@ -770,7 +812,10 @@ class HiCacheController: ...@@ -770,7 +812,10 @@ class HiCacheController:
if self.is_mooncake_backend(): if self.is_mooncake_backend():
self.mooncake_page_backup(operation) self.mooncake_page_backup(operation)
elif self.storage_backend_type == "hf3fs": elif self.storage_backend_type == "hf3fs":
self.generic_page_backup(operation, batch_size=128) if self.mem_pool_host.layout == "page_first":
self.zerocopy_page_backup(operation, batch_size=128)
elif self.mem_pool_host.layout == "layer_first":
self.generic_page_backup(operation, batch_size=128)
else: else:
self.generic_page_backup(operation) self.generic_page_backup(operation)
......
...@@ -307,6 +307,9 @@ class MHATokenToKVPoolHost(HostKVCache): ...@@ -307,6 +307,9 @@ class MHATokenToKVPoolHost(HostKVCache):
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2 return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
def get_ksize_per_token(self):
    """Bytes of key cache (K only) per token: half the combined K+V size."""
    kv_bytes_per_token = self.get_size_per_token()
    return kv_bytes_per_token // 2
def init_kv_buffer(self): def init_kv_buffer(self):
if self.layout == "layer_first": if self.layout == "layer_first":
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim) dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
...@@ -496,6 +499,21 @@ class MHATokenToKVPoolHost(HostKVCache): ...@@ -496,6 +499,21 @@ class MHATokenToKVPoolHost(HostKVCache):
element_size_list = [element_size] * len(key_list) element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list return key_list, ptr_list, element_size_list
def get_buffer_with_hash(self, keys, indices):
    """Return (storage_key, host_buffer) pairs for zero-copy storage I/O.

    For each logical page hash in *keys*, emits two entries -- "<key>-k" and
    "<key>-v" -- each paired with a page-sized view of the host K / V buffer,
    so callers can read or write storage without an intermediate copy.

    Only valid for the "page_first" layout.
    """
    assert self.layout == "page_first"
    # One key per page: indices must cover exactly len(keys) full pages.
    assert len(keys) == (len(indices) // self.page_size)
    key_list = []
    buf_list = []
    # NOTE(review): the slice offset is the *position* within `indices`
    # (0, page_size, 2*page_size, ...), not the host index value itself
    # (`indices[i]`). This only addresses the intended pages if the
    # operation's host_indices start at 0 and are contiguous -- confirm
    # against callers.
    for key, i in zip(keys, range(0, len(indices), self.page_size)):
        key_list.append(f"{key}-k")
        buf_list.append(self.k_buffer[i : i + self.page_size])
        key_list.append(f"{key}-v")
        buf_list.append(self.v_buffer[i : i + self.page_size])
    return key_list, buf_list
class MLATokenToKVPoolHost(HostKVCache): class MLATokenToKVPoolHost(HostKVCache):
device_pool: MLATokenToKVPool device_pool: MLATokenToKVPool
...@@ -538,6 +556,9 @@ class MLATokenToKVPoolHost(HostKVCache): ...@@ -538,6 +556,9 @@ class MLATokenToKVPoolHost(HostKVCache):
* self.layer_num * self.layer_num
) )
def get_ksize_per_token(self):
    """MLA keeps one combined latent buffer, so the K size per token equals
    the full per-token size (no separate V half to subtract)."""
    per_token_bytes = self.get_size_per_token()
    return per_token_bytes
def init_kv_buffer(self): def init_kv_buffer(self):
if self.layout == "layer_first": if self.layout == "layer_first":
dims = ( dims = (
...@@ -704,3 +725,14 @@ class MLATokenToKVPoolHost(HostKVCache): ...@@ -704,3 +725,14 @@ class MLATokenToKVPoolHost(HostKVCache):
) )
element_size_list = [element_size] * len(key_list) element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list return key_list, ptr_list, element_size_list
def get_buffer_with_hash(self, keys, indices):
    """Return (storage_key, host_buffer) pairs for zero-copy storage I/O (MLA).

    MLA stores a single combined KV buffer, so each page hash maps to exactly
    one page-sized view of `kv_buffer` and *keys* is returned unchanged.

    Only valid for the "page_first" layout.
    """
    assert self.layout == "page_first"
    assert len(keys) == (len(indices) // self.page_size)
    buf_list = []
    # NOTE(review): slices use the position within `indices`, not the host
    # index values (`indices[i]`) -- same pattern as the MHA variant;
    # confirm the intended addressing against callers.
    for i in range(0, len(indices), self.page_size):
        buf_list.append(self.kv_buffer[i : i + self.page_size])
    return keys, buf_list
...@@ -34,6 +34,9 @@ apt-get update \ ...@@ -34,6 +34,9 @@ apt-get update \
python3 python3-pip \ python3 python3-pip \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# apt install python3.12 python3.12-venv python3.12-dev
# curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
# python3.12 get-pip.py
# Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl # Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
python3 setup.py bdist_wheel python3 setup.py bdist_wheel
...@@ -60,6 +63,6 @@ apt update && apt install -y \ ...@@ -60,6 +63,6 @@ apt update && apt install -y \
libuv1-dev libuv1-dev
# Install Python Package # Install Python Package
pip install hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl pip install hf3fs_py_usrbio-1.2.9+394583d-cp312-cp312-linux_x86_64.whl
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages
``` ```
...@@ -7,7 +7,7 @@ import signal ...@@ -7,7 +7,7 @@ import signal
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import wraps from functools import wraps
from typing import List, Optional, Tuple from typing import Any, List, Optional, Tuple
import torch import torch
...@@ -228,15 +228,23 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -228,15 +228,23 @@ class HiCacheHF3FS(HiCacheStorage):
) )
def get( def get(
self, key: str, target_location: Optional[torch.Tensor] = None self,
key: str,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
return self.batch_get([key], [target_location] if target_location else None)[0] return self.batch_get(
[key],
[target_location] if target_location is not None else None,
[target_sizes] if target_sizes is not None else None,
)[0]
@synchronized() @synchronized()
def batch_get( def batch_get(
self, self,
keys: List[str], keys: List[str],
target_locations: Optional[List[torch.Tensor]] = None, target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None]: ) -> List[torch.Tensor | None]:
page_indices = self.metadata_client.get_page_indices(self.rank, keys) page_indices = self.metadata_client.get_page_indices(self.rank, keys)
...@@ -246,9 +254,15 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -246,9 +254,15 @@ class HiCacheHF3FS(HiCacheStorage):
batch_indices.append(i) batch_indices.append(i)
file_offsets.append(page_index * self.bytes_per_page) file_offsets.append(page_index * self.bytes_per_page)
file_results = [ if target_locations is not None:
torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices)) for target_location in target_locations:
] assert target_location.is_contiguous()
file_results = target_locations
else:
file_results = [
torch.empty(self.numel, dtype=self.dtype)
for _ in range(len(batch_indices))
]
futures = [ futures = [
self.executor.submit( self.executor.submit(
...@@ -273,10 +287,27 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -273,10 +287,27 @@ class HiCacheHF3FS(HiCacheStorage):
return results return results
def set(self, key: str, value: torch.Tensor) -> bool: def set(
return self.batch_set([key], [value]) self,
key: str,
value: Optional[Any] = None,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
return self.batch_set(
[key],
[value] if value is not None else None,
[target_location] if target_location is not None else None,
[target_sizes] if target_sizes is not None else None,
)
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: def batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
# Todo: Add prefix block's hash key # Todo: Add prefix block's hash key
key_with_prefix = [(key, "") for key in keys] key_with_prefix = [(key, "") for key in keys]
indices = self.metadata_client.reserve_and_allocate_page_indices( indices = self.metadata_client.reserve_and_allocate_page_indices(
...@@ -292,7 +323,8 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -292,7 +323,8 @@ class HiCacheHF3FS(HiCacheStorage):
batch_indices.append(i) batch_indices.append(i)
file_offsets.append(page_index * self.bytes_per_page) file_offsets.append(page_index * self.bytes_per_page)
file_values.append(value.contiguous()) assert value.is_contiguous()
file_values.append(value)
futures = [ futures = [
self.executor.submit( self.executor.submit(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment