Unverified Commit 1b7afad0 authored by hzh0425, committed by GitHub

feature(hicache): Support hf3fs-hicache reusing kvcache across different instances (#8673)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent f29aba8c
# HF3FS as L3 KV Cache
This document describes how to use DeepSeek 3FS (hf3fs) as the L3 KV cache for SGLang.
## Step 1: Install DeepSeek 3FS via the 3FS Operator (Coming Soon)
## Step 2: Set up the USRBIO client
Please follow [setup_usrbio_client.md](setup_usrbio_client.md) to set up the USRBIO client.
## Step 3: Deployment
### Single-node deployment
```bash
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
python3 -m sglang.launch_server \
--model-path /code/models/Qwen3-32B/ \
--host 0.0.0.0 --port 10000 \
--page-size 64 \
--enable-hierarchical-cache \
--hicache-ratio 2 --hicache-size 0 \
--hicache-write-policy write_through \
--hicache-storage-backend hf3fs
```
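Once the server is up, a quick smoke test confirms it is serving requests. This is a plain SGLang `/generate` call; the prompt is arbitrary, and the port must match `--port` above:
```python
# Smoke test for the server launched above (assumes it listens on port 10000).
import requests

resp = requests.post(
    "http://localhost:10000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
print(resp.json())
```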
### Multi-node deployment with shared KV cache
Please follow [deploy_sglang_3fs_multinode.md](deploy_sglang_3fs_multinode.md) to deploy SGLang with 3FS on multiple nodes that share a KV cache.
# 1. Start the 3FS metadata service
The service listens on `0.0.0.0:18000` by default; pass `--persistence-path` (and optionally `--save-interval`, 60 seconds by default) to periodically snapshot metadata to disk.
```bash
nohup python3 -m sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server > meta.out &
```
# 2. Start the SGLang engines
## HF3FS configuration
Create the following configuration file on every node. `metaServerIp` is a placeholder for the address of the metadata server started in step 1.
```bash
vim /sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
```
```json
{
  "file_path_prefix": "/data/hicache",
  "file_size": 1099511627776,
  "numjobs": 16,
  "entries": 8,
  "metadata_server_url": "http://metaServerIp:18000"
}
```
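`file_path_prefix` expands to `<prefix>.<rank>.bin` for each rank, `file_size` is the per-rank backing-file size in bytes (1 TiB above), `numjobs` sets the number of parallel I/O clients, and `entries` is passed through to each 3FS client. If `metadata_server_url` is omitted or empty, the instance falls back to a node-local metadata client and no cross-instance sharing occurs. The page pool is sized directly from the file size; SGLang derives `bytes_per_page` at runtime, so the value in this sketch is only an assumption:
```python
# How HiCacheHF3FS sizes its page pool (see storage_hf3fs.py):
# num_pages = file_size // bytes_per_page
file_size = 1099511627776    # 1 TiB, as in hf3fs_config.json
bytes_per_page = 2 * 2**20   # assumed 2 MiB per page; SGLang computes the real value
num_pages = file_size // bytes_per_page
print(f"{num_pages} pages of {bytes_per_page / 2**20:.0f} MiB each")  # 524288 pages
```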
## Node 1
```bash
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
rm -rf instance1.out && \
nohup python3 -m sglang.launch_server \
--model-path /code/models/Qwen3-32B/ \
--host 0.0.0.0 --port 10000 \
--page-size 64 \
--enable-hierarchical-cache \
--hicache-ratio 2 --hicache-size 0 \
--hicache-write-policy write_through \
--hicache-storage-backend hf3fs --tp 2 > instance1.out &
```
## Node 2
```bash
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs_config.json
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
rm -rf instance2.out && \
nohup python3 -m sglang.launch_server \
--model-path /code/models/Qwen3-32B/ \
--host 0.0.0.0 --port 10000 \
--page-size 64 \
--enable-hierarchical-cache \
--hicache-ratio 2 --hicache-size 0 \
--hicache-write-policy write_through \
--hicache-storage-backend hf3fs --tp 2 > instance2.out &
```
# 3. Start the router
By default the router listens on port 30000, which is the port the benchmark below targets.
```bash
rm -rf router.out && \
nohup python -m sglang_router.launch_router --worker-urls http://node1:10000 http://node2:10000 > router.out &
```
# 4. Start the multi-turn benchmark
```bash
rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
--model-path /code/models/Qwen3-32B \
--dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
--port 30000 \
--request-length 2048 --num-clients 512 --num-rounds 5 --max-parallel 8 \
> bench_multiturn.out &
```
# sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py (new file)
import argparse
import atexit
import json
import logging
import threading
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import requests
from fastapi import FastAPI, HTTPException, Request, status
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import Hf3fsMetadataInterface

# --- Configuration ---
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


# --- Data Models ---
class RankMetadata:
    """Holds all metadata for a single rank."""

    def __init__(self, num_pages: int):
        self.lock = threading.RLock()
        self.num_pages = num_pages
        self.free_pages: List[int] = list(range(num_pages))
        self.key_to_index: Dict[str, int] = {}
        # Todo: Support multiple files for HF3FS

    def exists_keys(self, keys: List[str]) -> List[bool]:
        """Check if keys exist in metadata."""
        with self.lock:
            return [key in self.key_to_index for key in keys]

    def reserve_and_allocate_page_indices(
        self, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        """Reserve and allocate page indices for keys."""
        with self.lock:
            results = [None] * len(keys)
            new_keys_to_process = []
            for i, (key, prefix_key) in enumerate(keys):
                if key in self.key_to_index:
                    results[i] = (True, self.key_to_index[key])
                else:
                    new_keys_to_process.append((i, key, prefix_key))
            # Todo: Implement data eviction logic after HiCache supports prefix information pass-through
            for i, key, prefix_key in new_keys_to_process:
                if len(self.free_pages) > 0:
                    page_idx = self.free_pages.pop()
                    results[i] = (False, page_idx)
                else:
                    results[i] = (False, -1)
            return results

    def confirm_write(
        self,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        """Confirm write operations and release pages."""
        with self.lock:
            for key, page_index in written_keys_to_confirm:
                self.key_to_index[key] = page_index
            for page_index in pages_to_release:
                if page_index not in self.free_pages:
                    self.free_pages.append(page_index)

    def delete_keys(self, keys: List[str]) -> int:
        """Delete keys and return the count of deleted keys."""
        with self.lock:
            count = 0
            for key in keys:
                if key in self.key_to_index:
                    page_index = self.key_to_index.pop(key)
                    if page_index not in self.free_pages:
                        self.free_pages.append(page_index)
                    count += 1
            return count

    def clear_all(self) -> None:
        """Clear all metadata."""
        with self.lock:
            self.free_pages = list(range(self.num_pages))
            self.key_to_index.clear()

    def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
        """Get page indices for keys."""
        with self.lock:
            return [self.key_to_index.get(key) for key in keys]
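

# Illustrative only: a minimal walkthrough of the page lifecycle implemented by
# RankMetadata (reserve -> write to 3FS -> confirm). The keys are arbitrary and
# this function is not called anywhere in the module.
def _demo_rank_metadata_lifecycle() -> None:
    meta = RankMetadata(num_pages=2)
    # Each entry is (key, prefix_key); free pages are popped from the end.
    reservations = meta.reserve_and_allocate_page_indices([("k1", ""), ("k2", "")])
    assert reservations == [(False, 1), (False, 0)]
    # After the payload is written to 3FS, publish the key -> page mapping.
    meta.confirm_write(
        written_keys_to_confirm=[("k1", 1), ("k2", 0)], pages_to_release=[]
    )
    assert meta.exists_keys(["k1", "k2"]) == [True, True]
    # With the pool exhausted and eviction not yet implemented, allocation yields -1.
    assert meta.reserve_and_allocate_page_indices([("k3", "")]) == [(False, -1)]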
class GlobalMetadataState:
    """Manages the state for all ranks and persistence."""

    def __init__(self, persistence_path: Optional[str], save_interval: int):
        self.global_lock = threading.RLock()
        self.ranks: Dict[int, RankMetadata] = {}
        self.persistence_path = Path(persistence_path) if persistence_path else None
        self.save_interval = save_interval
        self.save_timer: Optional[threading.Timer] = None
        self.is_shutting_down = False

    def load_from_disk(self):
        if not self.persistence_path or not self.persistence_path.exists():
            logging.info("Persistence file not found. Starting with a clean state.")
            return
        logging.info(f"Loading state from {self.persistence_path}")
        try:
            with open(self.persistence_path, "r") as f:
                persisted_data = json.load(f)
            with self.global_lock:
                for rank_id_str, data in persisted_data.items():
                    rank_id = int(rank_id_str)
                    num_pages = data["num_pages"]
                    rank_meta = RankMetadata(num_pages)
                    rank_meta.free_pages = data["free_pages"]
                    rank_meta.key_to_index = dict(data["key_to_index"])
                    self.ranks[rank_id] = rank_meta
                logging.info(
                    f"Successfully loaded metadata for {len(self.ranks)} ranks."
                )
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            logging.error(
                f"Failed to load or parse persistence file: {e}. Starting fresh.",
                exc_info=True,
            )
            self.ranks.clear()

    def save_to_disk(self):
        if not self.persistence_path:
            return
        logging.info("Persisting metadata to disk...")
        with self.global_lock:
            serializable_state = {}
            for rank_id, rank_meta in self.ranks.items():
                with rank_meta.lock:
                    serializable_state[rank_id] = {
                        "num_pages": rank_meta.num_pages,
                        "free_pages": rank_meta.free_pages,
                        "key_to_index": list(rank_meta.key_to_index.items()),
                    }
        try:
            temp_path = self.persistence_path.with_suffix(".tmp")
            with open(temp_path, "w") as f:
                json.dump(serializable_state, f, indent=4)
            temp_path.rename(self.persistence_path)
            logging.info(f"Metadata successfully persisted to {self.persistence_path}")
        except Exception as e:
            logging.error(f"Failed to save metadata to disk: {e}", exc_info=True)

    def schedule_save(self):
        if self.is_shutting_down or not self.persistence_path:
            return
        self.save_to_disk()
        self.save_timer = threading.Timer(self.save_interval, self.schedule_save)
        self.save_timer.start()

    def shutdown(self):
        logging.info("Shutting down metadata server...")
        self.is_shutting_down = True
        if self.save_timer:
            self.save_timer.cancel()
        self.save_to_disk()
        logging.info("Shutdown complete.")


# --- Global MetadataServer implementation ---
class Hf3fsMetadataServer:
    """HF3FS metadata server that manages metadata for multiple ranks."""

    def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
        self.state = GlobalMetadataState(persistence_path, save_interval)
        self.app = FastAPI()
        self._setup_routes()

    def _setup_routes(self):
        """Set up FastAPI routes."""
        self.app.post("/{rank}/initialize")(self.initialize)
        self.app.post("/{rank}/exists")(self.exists)
        self.app.post("/{rank}/reserve_and_allocate_page_indices")(
            self.reserve_and_allocate_page_indices
        )
        self.app.post("/{rank}/confirm_write")(self.confirm_write)
        self.app.post("/{rank}/delete_keys")(self.delete_keys)
        self.app.post("/{rank}/clear")(self.clear)
        self.app.post("/{rank}/get_page_indices")(self.get_page_indices)

    def get_rank_metadata(self, rank: int) -> RankMetadata:
        """Get rank metadata, raising a 404 if the rank is not initialized."""
        with self.state.global_lock:
            if rank not in self.state.ranks:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
                )
            return self.state.ranks[rank]

    async def initialize(self, rank: int, request: Request):
        """Initialize a rank with the specified number of pages."""
        data = await request.json()
        num_pages = data["num_pages"]
        with self.state.global_lock:
            if rank in self.state.ranks:
                logging.info(
                    f"Rank {rank} already exists. Initialization request ignored."
                )
                if self.state.ranks[rank].num_pages != num_pages:
                    logging.warning(
                        f"Rank {rank} initialized with different num_pages. "
                        f"Existing: {self.state.ranks[rank].num_pages}, New: {num_pages}"
                    )
            else:
                logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
                self.state.ranks[rank] = RankMetadata(num_pages)
        return {"message": f"Rank {rank} is ready."}

    async def exists(self, rank: int, request: Request):
        """Check if keys exist in metadata."""
        data = await request.json()
        keys = data["keys"]
        metadata = self.get_rank_metadata(rank)
        results = metadata.exists_keys(keys)
        return {"exists": results}

    async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
        """Reserve and allocate page indices for keys."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        keys = data["keys"]
        results = metadata.reserve_and_allocate_page_indices(keys)
        return {"indices": results}

    async def confirm_write(self, rank: int, request: Request):
        """Confirm write operations and release pages."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        success_written_keys = data.get("written_keys_to_confirm", [])
        released_pages = data.get("pages_to_release", [])
        metadata.confirm_write(success_written_keys, released_pages)
        return {
            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
        }

    async def delete_keys(self, rank: int, request: Request):
        """Delete keys from metadata."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        count = metadata.delete_keys(data["keys"])
        return {"message": f"Rank {rank}: {count} keys deleted."}

    async def clear(self, rank: int):
        """Clear all metadata for a rank."""
        metadata = self.get_rank_metadata(rank)
        metadata.clear_all()
        return {"message": f"Rank {rank}: Metadata cleared."}

    async def get_page_indices(self, rank: int, request: Request):
        """Get page indices for keys."""
        data = await request.json()
        metadata = self.get_rank_metadata(rank)
        keys = data["keys"]
        results = metadata.get_page_indices(keys)
        return {"indices": results}

    def run(self, host: str = "0.0.0.0", port: int = 18000):
        """Run the metadata server."""
        self.state.load_from_disk()
        if self.state.persistence_path:
            self.state.schedule_save()
        atexit.register(self.state.shutdown)

        import uvicorn

        logging.info(f"Starting metadata server on http://{host}:{port}")
        if self.state.persistence_path:
            logging.info(
                f"Persistence is ENABLED. Saving to '{self.state.persistence_path}' "
                f"every {self.state.save_interval} seconds."
            )
        else:
            logging.info("Persistence is DISABLED.")
        uvicorn.run(self.app, host=host, port=port)
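

# Illustrative only: the HTTP surface can be exercised directly with `requests`
# against a locally running server. The rank (0), port (18000), and key below
# are arbitrary, and this function is not called anywhere in the module.
def _demo_http_api(base_url: str = "http://localhost:18000") -> None:
    rank_base = f"{base_url}/0"  # endpoints are namespaced by rank
    requests.post(f"{rank_base}/initialize", json={"num_pages": 1024})
    resp = requests.post(f"{rank_base}/exists", json={"keys": ["some_page_hash"]})
    print(resp.json())  # {'exists': [False]} on a fresh server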


# --- Client implementation ---
class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
    """Global HTTP metadata client for HF3FS."""

    def __init__(self, base_url: str, max_retries: int = 3):
        self.base_url = base_url.rstrip("/")
        self._session = requests.Session()
        retry_strategy = Retry(
            total=max_retries,
            backoff_factor=0.3,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=["GET", "POST"],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self._session.mount("http://", adapter)

    def _post(self, endpoint: str, json_data: dict) -> dict:
        try:
            response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logging.error(f"Failed to POST to {endpoint} after retries: {e}")
            raise RuntimeError(f"Failed to connect to metadata server: {e}") from e

    def initialize(self, rank: int, num_pages: int) -> None:
        self._post(f"{rank}/initialize", {"num_pages": num_pages})

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        response = self._post(
            f"{rank}/reserve_and_allocate_page_indices", {"keys": keys}
        )
        return [tuple(item) for item in response.get("indices")]

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        self._post(
            f"{rank}/confirm_write",
            {
                "written_keys_to_confirm": written_keys_to_confirm,
                "pages_to_release": pages_to_release,
            },
        )

    def delete_keys(self, rank: int, keys: List[str]) -> None:
        self._post(f"{rank}/delete_keys", {"keys": keys})

    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        response = self._post(f"{rank}/exists", {"keys": keys})
        return response.get("exists", [])

    def clear(self, rank: int) -> None:
        self._post(f"{rank}/clear", {})

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        response = self._post(f"{rank}/get_page_indices", {"keys": keys})
        return response.get("indices")


class Hf3fsLocalMetadataClient(Hf3fsMetadataInterface):
    """Local metadata client that operates directly on a single in-memory RankMetadata, without a metadata server."""

    def __init__(self):
        self.rank_metadata: Optional[RankMetadata] = None

    def initialize(self, rank: int, num_pages: int) -> None:
        self.rank_metadata = RankMetadata(num_pages)

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        """Reserve and allocate page indices for keys."""
        return self.rank_metadata.reserve_and_allocate_page_indices(keys)

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        """Confirm write operations."""
        self.rank_metadata.confirm_write(written_keys_to_confirm, pages_to_release)

    def delete_keys(self, rank: int, keys: List[str]) -> None:
        """Delete keys."""
        self.rank_metadata.delete_keys(keys)

    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        """Check if keys exist."""
        return self.rank_metadata.exists_keys(keys)

    def clear(self, rank: int) -> None:
        """Clear all metadata for the rank."""
        self.rank_metadata.clear_all()

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        """Get page indices for keys."""
        return self.rank_metadata.get_page_indices(keys)
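

# Illustrative only: both clients implement Hf3fsMetadataInterface, so storage
# code can swap the in-process client for the HTTP one without other changes.
# The keys below are arbitrary, and this function is not called in the module.
def _demo_metadata_client() -> None:
    client = Hf3fsLocalMetadataClient()
    client.initialize(rank=0, num_pages=16)
    ((is_written, page_idx),) = client.reserve_and_allocate_page_indices(
        0, [("h0", "")]
    )
    assert is_written is False
    client.confirm_write(0, [("h0", page_idx)], pages_to_release=[])
    assert client.exists(0, ["h0"]) == [True]
    assert client.get_page_indices(0, ["h0", "missing"]) == [page_idx, None]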


def run_metadata_server(
    host: str = "0.0.0.0",
    port: int = 18000,
    persistence_path: Optional[str] = None,
    save_interval: int = 60,
):
    """Run the HF3FS metadata server."""
    global server
    server = Hf3fsMetadataServer(
        persistence_path=persistence_path, save_interval=save_interval
    )
    server.run(host=host, port=port)


# --- Main Execution ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="HF3FS Metadata Server")
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to bind the server to."
    )
    parser.add_argument(
        "--port", type=int, default=18000, help="Port to run the server on."
    )
    parser.add_argument(
        "--persistence-path",
        type=str,
        default=None,
        help="Path to the file for persisting metadata. If not provided, persistence is disabled.",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=60,
        help="Interval in seconds for periodically saving metadata to disk.",
    )
    args = parser.parse_args()

    run_metadata_server(args.host, args.port, args.persistence_path, args.save_interval)
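
# Note: run_metadata_server can also be invoked programmatically instead of via
# the CLI shown in the deployment guide; a sketch, where the persistence path
# is an assumed example:
#
#     from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
#         run_metadata_server,
#     )
#
#     run_metadata_server(
#         host="0.0.0.0",
#         port=18000,
#         persistence_path="/data/hicache_meta.json",
#         save_interval=60,
#     )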
--- a/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
+++ b/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -5,9 +5,9 @@ import logging
 import os
 import signal
 import threading
-from collections import OrderedDict
+from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
@@ -17,6 +17,75 @@ from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 logger = logging.getLogger(__name__)
 
+
+class Hf3fsMetadataInterface(ABC):
+    """Interface for HF3FS metadata operations."""
+
+    @abstractmethod
+    def initialize(self, rank: int, num_pages: int) -> None:
+        """Initialize the metadata service with the specified number of pages."""
+        pass
+
+    @abstractmethod
+    def reserve_and_allocate_page_indices(
+        self,
+        rank: int,
+        keys: List[Tuple[str, str]],
+    ) -> List[Tuple[bool, int]]:
+        """
+        Reserve and allocate page indices for the specified keys.
+
+        Args:
+            rank: The rank of the process.
+            keys: The keys to reserve and allocate page indices for. Each tuple
+                contains a key and the key of its prefix block.
+
+        Returns:
+            List[Tuple[bool, int]]: A list of tuples, where each tuple contains
+            a boolean indicating whether the key already existed and an integer
+            indicating the allocated page index.
+        """
+        pass
+
+    @abstractmethod
+    def confirm_write(
+        self,
+        rank: int,
+        written_keys_to_confirm: List[Tuple[str, int]],
+        pages_to_release: List[int],
+    ) -> None:
+        """
+        Confirm that key-value pairs have been successfully written to storage.
+
+        Args:
+            rank: The rank of the process.
+            written_keys_to_confirm: A list of tuples, where each tuple contains
+                a key and its corresponding page index.
+            pages_to_release: A list of page indices to be released.
+        """
+        pass
+
+    @abstractmethod
+    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
+        """
+        Get page indices for the specified keys.
+
+        Args:
+            rank: The rank of the process.
+            keys: A list of keys.
+
+        Returns:
+            List[Optional[int]]: A list of page indices for the specified keys.
+            If a key is not found, the corresponding index is None.
+        """
+        pass
+
+    @abstractmethod
+    def delete_keys(self, rank: int, keys: List[str]) -> None:
+        """Delete the specified keys and their associated pages."""
+        pass
+
+    @abstractmethod
+    def exists(self, rank: int, keys: List[str]) -> List[bool]:
+        """Check whether the specified keys exist."""
+        pass
+
+    @abstractmethod
+    def clear(self, rank: int) -> None:
+        """Clear all key-value pairs and page allocations for the specified rank."""
+        pass
+
+
 class AtomicCounter:
     def __init__(self, n: int):
         assert n > 0
@@ -48,32 +117,32 @@ class HiCacheHF3FS(HiCacheStorage):
     def __init__(
         self,
+        rank: int,
         file_path: str,
         file_size: int,
         numjobs: int,
         bytes_per_page: int,
         entries: int,
         dtype: torch.dtype,
+        metadata_client: Hf3fsMetadataInterface,
     ):
+        self.rank = rank
        self.file_path = file_path
        self.file_size = file_size
        self.numjobs = numjobs
        self.bytes_per_page = bytes_per_page
        self.entries = entries
        self.dtype = dtype
+        self.metadata_client = metadata_client
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
 
         logger.info(
-            "HiCacheHF3FS "
-            f"file_path = {self.file_path}, "
-            f"file_size = {self.file_size/(2**30):.2f} GB, "
-            f"numjobs = {self.numjobs}, "
-            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
-            f"entries = {self.entries}, "
-            f"num_pages = {self.num_pages}"
+            f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
+            f"file_path={self.file_path}, "
+            f"file_size={self.file_size / (2 ** 30):.2f} GB, "
+            f"num_pages={self.num_pages}"
         )
 
         self.ac = AtomicCounter(self.numjobs)
@@ -84,15 +153,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for _ in range(numjobs)
         ]
         self.executor = concurrent.futures.ThreadPoolExecutor(
-            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
+            max_workers=self.numjobs, thread_name_prefix=f"HiCacheHF3FS-Rank{self.rank}"
         )
 
-        # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
-        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
-        # through centralized metadata orchestration.
+        self.metadata_client.initialize(self.rank, self.num_pages)
         self.lock = threading.RLock()
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index = OrderedDict()
 
         atexit.register(self.close)
@@ -104,15 +169,22 @@ class HiCacheHF3FS(HiCacheStorage):
     def from_env_config(
         rank: int, bytes_per_page: int, dtype: torch.dtype
     ) -> "HiCacheHF3FS":
+        from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
+            Hf3fsGlobalMetadataClient,
+            Hf3fsLocalMetadataClient,
+        )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
+                rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
                 file_size=1 << 40,
                 numjobs=16,
                 bytes_per_page=bytes_per_page,
                 entries=8,
                 dtype=dtype,
+                metadata_client=Hf3fsLocalMetadataClient(),
             )
 
         try:
@@ -121,6 +193,7 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
 
+        # Check required keys (metadata_server_url is now optional)
         required_keys = {
             "file_path_prefix",
             "file_size",
@@ -131,19 +204,33 @@ class HiCacheHF3FS(HiCacheStorage):
         if missing_keys:
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
+        # Choose the metadata client based on the configuration
+        if "metadata_server_url" in config and config["metadata_server_url"]:
+            # Use the global metadata client to connect to the metadata server
+            metadata_server_url = config["metadata_server_url"]
+            metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+            logger.info(
+                f"Using global metadata client with server url: {metadata_server_url}"
+            )
+        else:
+            # Use the local metadata client for single-machine deployment
+            metadata_client = Hf3fsLocalMetadataClient()
+
         return HiCacheHF3FS(
+            rank=rank,
             file_path=f"{config['file_path_prefix']}.{rank}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
+            metadata_client=metadata_client,
         )
 
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
-        return self.batch_get([key], target_location)[0]
+        return self.batch_get(
+            [key], [target_location] if target_location is not None else None
+        )[0]
 
     @synchronized()
     def batch_get(
@@ -151,14 +238,14 @@ class HiCacheHF3FS(HiCacheStorage):
         keys: List[str],
         target_locations: Optional[List[torch.Tensor]] = None,
     ) -> List[torch.Tensor | None]:
+        page_indices = self.metadata_client.get_page_indices(self.rank, keys)
+
         batch_indices, file_offsets = [], []
-        for i, key in enumerate(keys):
-            if key not in self.key_to_index:
-                continue
-            batch_indices.append(i)
-            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
-            self.key_to_index.move_to_end(key)
+        for i, page_index in enumerate(page_indices):
+            if page_index is not None:
+                batch_indices.append(i)
+                file_offsets.append(page_index * self.bytes_per_page)
 
+        # TODO: target_locations
         file_results = [
             torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
         ]
@@ -180,7 +267,9 @@ class HiCacheHF3FS(HiCacheStorage):
             if read_result == self.bytes_per_page:
                 results[batch_index] = file_result
             else:
-                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
+                logger.error(
+                    f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
+                )
 
         return results
@@ -188,13 +277,21 @@ class HiCacheHF3FS(HiCacheStorage):
         return self.batch_set([key], [value])
 
     def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
-        indices = self.get_batch_set_indices(keys)
+        # Todo: Add the prefix block's hash key
+        key_with_prefix = [(key, "") for key in keys]
+        indices = self.metadata_client.reserve_and_allocate_page_indices(
+            self.rank, key_with_prefix
+        )
+
         batch_indices, file_offsets, file_values = [], [], []
-        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
-            if is_written or index == -1:
+        pages_to_release = []
+
+        for i, (value, (is_written, page_index)) in enumerate(zip(values, indices)):
+            if is_written or page_index == -1:
                 continue
             batch_indices.append(i)
-            file_offsets.append(index * self.bytes_per_page)
+            file_offsets.append(page_index * self.bytes_per_page)
             file_values.append(value.contiguous())
 
         futures = [
@@ -211,62 +308,37 @@ class HiCacheHF3FS(HiCacheStorage):
                 for result in future.result()
             ]
 
+        written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
             key = keys[batch_index]
-            index = indices[batch_index][1]
+            page_index = indices[batch_index][1]
             if write_result:
-                self.key_to_index[key] = index
-                self.key_to_index.move_to_end(key)
+                written_keys_to_confirm.append((key, page_index))
             else:
-                logger.error(f"HiCacheHF3FS set {key} failed")
-                self.free_pages.append(index)
+                logger.error(f"[Rank {self.rank}] HiCacheHF3FS set {key} failed")
+                pages_to_release.append(page_index)
             results[batch_index] = write_result
-        return all(results)
 
-    @synchronized()
-    def get_batch_set_indices(self, keys: List[str]) -> list:
-        ionum = len(keys)
-        # results: tuples of (is_written: bool, page_idx: int)
-        # - is_written: True = hit (no I/O), False = write (miss)
-        # - page_idx: page storing data
-        results = [None] * min(ionum, self.num_pages)
-        if ionum > self.num_pages:
-            results.extend([(False, -1)] * (ionum - self.num_pages))
-        new_keys = []
-        for batch_index, key in enumerate(keys[: self.num_pages]):
-            if key in self.key_to_index:
-                results[batch_index] = (True, self.key_to_index[key])
-                self.key_to_index.move_to_end(key)
-            else:
-                new_keys.append((batch_index, key))
-        for batch_index, _ in new_keys:
-            index = (
-                self.free_pages.pop()
-                if len(self.free_pages) > 0
-                else self.key_to_index.popitem(last=False)[1]
-            )
-            results[batch_index] = (False, index)
-        return results
+        if len(written_keys_to_confirm) > 0 or len(pages_to_release) > 0:
+            self.metadata_client.confirm_write(
+                self.rank, written_keys_to_confirm, pages_to_release
+            )
+
+        return all(results)
 
     @synchronized()
     def delete(self, key: str) -> None:
-        if key not in self.key_to_index:
-            return
-        index = self.key_to_index.pop(key)
-        self.free_pages.append(index)
+        self.metadata_client.delete_keys(self.rank, [key])
 
     @synchronized()
     def exists(self, key: str) -> bool:
-        return key in self.key_to_index
+        result = self.metadata_client.exists(self.rank, [key])
+        return result[0] if result else False
 
     @synchronized()
     def clear(self) -> None:
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index.clear()
+        self.metadata_client.clear(self.rank)
 
     def close(self) -> None:
         try:
...