Unverified Commit 58d06fdc authored by hzh0425's avatar hzh0425 Committed by GitHub
Browse files

[HiCacheStorage]: Improve 3fs kvstore's performance and resolve MLA issues (#9876)

parent cb9e0e41
...@@ -4,10 +4,12 @@ import json ...@@ -4,10 +4,12 @@ import json
import logging import logging
import threading import threading
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, OrderedDict, Tuple
import orjson
import requests import requests
from fastapi import FastAPI, HTTPException, Request, status from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import ORJSONResponse
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry from urllib3.util.retry import Retry
...@@ -24,10 +26,10 @@ class RankMetadata: ...@@ -24,10 +26,10 @@ class RankMetadata:
"""Holds all metadata for a single rank.""" """Holds all metadata for a single rank."""
def __init__(self, num_pages: int):
    # Serializes every mutation of the allocation state below.
    self.lock = threading.Lock()
    self.num_pages = num_pages
    # Every page starts out free.
    self.free_pages: List[int] = [*range(num_pages)]
    # Insertion-ordered map used as an LRU: least-recently-used keys sit at the front.
    self.key_to_index: OrderedDict[str, int] = OrderedDict()
# Todo: Support multi files for HF3FS # Todo: Support multi files for HF3FS
def exists_keys(self, keys: List[str]) -> List[bool]: def exists_keys(self, keys: List[str]) -> List[bool]:
...@@ -46,16 +48,18 @@ class RankMetadata: ...@@ -46,16 +48,18 @@ class RankMetadata:
for i, (key, prefix_key) in enumerate(keys): for i, (key, prefix_key) in enumerate(keys):
if key in self.key_to_index: if key in self.key_to_index:
results[i] = (True, self.key_to_index[key]) results[i] = (True, self.key_to_index[key])
self.key_to_index.move_to_end(key)
else: else:
new_keys_to_process.append((i, key, prefix_key)) new_keys_to_process.append((i, key, prefix_key))
# Todo: Implementing data eviction logic after HiCache supports prefix information pass-through # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
for i, key, prefix_key in new_keys_to_process: for i, key, prefix_key in new_keys_to_process:
if len(self.free_pages) > 0: if len(self.free_pages) > 0:
page_idx = self.free_pages.pop() page_index = self.free_pages.pop()
results[i] = (False, page_idx)
else: else:
results[i] = (False, -1) page_index = self.key_to_index.popitem(last=False)[1]
results[i] = (False, page_index)
return results return results
...@@ -68,6 +72,7 @@ class RankMetadata: ...@@ -68,6 +72,7 @@ class RankMetadata:
with self.lock: with self.lock:
for key, page_index in written_keys_to_confirm: for key, page_index in written_keys_to_confirm:
self.key_to_index[key] = page_index self.key_to_index[key] = page_index
self.key_to_index.move_to_end(key)
for page_index in pages_to_release: for page_index in pages_to_release:
if page_index not in self.free_pages: if page_index not in self.free_pages:
...@@ -94,7 +99,14 @@ class RankMetadata: ...@@ -94,7 +99,14 @@ class RankMetadata:
def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
    """Look up the page index for each key, refreshing LRU recency on hits.

    Returns a list aligned with ``keys``; a miss is reported as ``None``.
    """
    with self.lock:
        indices: List[Optional[int]] = []
        for key in keys:
            if key not in self.key_to_index:
                indices.append(None)
                continue
            indices.append(self.key_to_index[key])
            # A successful lookup counts as a "use" for LRU purposes.
            self.key_to_index.move_to_end(key)
        return indices
class GlobalMetadataState: class GlobalMetadataState:
...@@ -182,7 +194,8 @@ class Hf3fsMetadataServer: ...@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:
def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
    # Shared per-rank metadata state, optionally persisted on an interval.
    self.state = GlobalMetadataState(persistence_path, save_interval)
    # ORJSONResponse by default skips FastAPI's slower jsonable_encoder path.
    self.app = FastAPI(default_response_class=ORJSONResponse)
    self._setup_routes()
def _setup_routes(self): def _setup_routes(self):
...@@ -199,17 +212,25 @@ class Hf3fsMetadataServer: ...@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:
def get_rank_metadata(self, rank: int) -> "RankMetadata":
    """Return the metadata object for ``rank``; 404 if it was never initialized."""
    ranks = self.state.ranks
    if rank in ranks:
        return ranks[rank]
    raise HTTPException(
        status_code=404,
        detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
    )
async def _read_json(self, request: "Request") -> dict:
    """Read the raw request body and decode it with orjson (bypasses request.json())."""
    return orjson.loads(await request.body())
def _json_response(self, content: dict):
    """Wrap ``content`` in an ORJSONResponse to skip FastAPI's jsonable_encoder."""
    return ORJSONResponse(content)
async def initialize(self, rank: int, request: Request): async def initialize(self, rank: int, request: Request):
"""Initialize a rank with specified number of pages.""" """Initialize a rank with specified number of pages."""
data = await request.json() data = await self._read_json(request)
num_pages = data["num_pages"] num_pages = data["num_pages"]
with self.state.global_lock: with self.state.global_lock:
if rank in self.state.ranks: if rank in self.state.ranks:
...@@ -223,57 +244,55 @@ class Hf3fsMetadataServer: ...@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
else: else:
logging.info(f"Initializing new Rank {rank} with {num_pages} pages.") logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
self.state.ranks[rank] = RankMetadata(num_pages) self.state.ranks[rank] = RankMetadata(num_pages)
return {"message": f"Rank {rank} is ready."} return Response(status_code=204)
async def exists(self, rank: int, request: "Request"):
    """Report, for each requested key, whether this rank's metadata tracks it."""
    payload = await self._read_json(request)
    keys = payload["keys"]
    rank_meta = self.get_rank_metadata(rank)
    return self._json_response({"exists": rank_meta.exists_keys(keys)})
async def reserve_and_allocate_page_indices(self, rank: int, request: "Request"):
    """Reserve existing pages (or allocate new ones) for the requested keys."""
    payload = await self._read_json(request)
    rank_meta = self.get_rank_metadata(rank)
    reserved = rank_meta.reserve_and_allocate_page_indices(payload["keys"])
    return self._json_response({"indices": reserved})
async def confirm_write(self, rank: int, request: "Request"):
    """Commit successfully written keys and release freed pages; reply 204 (no body)."""
    payload = await self._read_json(request)
    rank_meta = self.get_rank_metadata(rank)
    written = payload.get("written_keys_to_confirm", [])
    released = payload.get("pages_to_release", [])
    rank_meta.confirm_write(written, released)
    return Response(status_code=204)
async def delete_keys(self, rank: int, request: "Request"):
    """Remove the given keys from this rank's metadata; reply 204 (no body)."""
    payload = await self._read_json(request)
    rank_meta = self.get_rank_metadata(rank)
    rank_meta.delete_keys(payload["keys"])
    return Response(status_code=204)
async def clear(self, rank: int):
    """Drop all metadata tracked for this rank; reply 204 (no body)."""
    self.get_rank_metadata(rank).clear_all()
    return Response(status_code=204)
async def get_page_indices(self, rank: int, request: "Request"):
    """Look up the page index (or None) for each requested key on this rank."""
    payload = await self._read_json(request)
    rank_meta = self.get_rank_metadata(rank)
    found = rank_meta.get_page_indices(payload["keys"])
    return self._json_response({"indices": found})
def run(self, host: str = "0.0.0.0", port: int = 18000): def run(self, host: str = "0.0.0.0", port: int = 18000):
"""Run the metadata server.""" """Run the metadata server."""
...@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface): ...@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
status_forcelist=[500, 502, 503, 504], status_forcelist=[500, 502, 503, 504],
allowed_methods=["GET", "POST"], allowed_methods=["GET", "POST"],
) )
adapter = HTTPAdapter(max_retries=retry_strategy) adapter = HTTPAdapter(
max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
)
self._session.mount("http://", adapter) self._session.mount("http://", adapter)
def _post(self, endpoint: str, json_data: dict) -> dict:
    """POST ``json_data`` (orjson-encoded) to the metadata server and decode the reply.

    Returns ``{}`` for 204 / empty responses. Raises ``RuntimeError`` when the
    request still fails after the session's configured retries.
    """
    try:
        response = self._session.post(
            f"{self.base_url}/{endpoint}",
            data=orjson.dumps(json_data),  # type: ignore[union-attr]
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        # 204 replies (and any empty body) carry no JSON payload.
        if response.status_code == 204 or not response.content:
            return {}
        return orjson.loads(response.content)  # type: ignore[union-attr]
    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to POST to {endpoint} after retries: {e}")
        raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
......
...@@ -113,6 +113,8 @@ def synchronized(): ...@@ -113,6 +113,8 @@ def synchronized():
class HiCacheHF3FS(HiCacheStorage): class HiCacheHF3FS(HiCacheStorage):
"""HiCache backend that stores KV cache pages in HF3FS files."""
default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH" default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
def __init__( def __init__(
...@@ -176,15 +178,32 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -176,15 +178,32 @@ class HiCacheHF3FS(HiCacheStorage):
dtype: torch.dtype, dtype: torch.dtype,
storage_config: HiCacheStorageConfig = None, storage_config: HiCacheStorageConfig = None,
) -> "HiCacheHF3FS": ) -> "HiCacheHF3FS":
"""Create a HiCacheHF3FS instance from environment configuration.
Environment:
- Uses env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config.
- Falls back to a local single-machine config when the env var is not set.
Raises:
ValueError: If MLA Model is requested without global metadata server or required keys are missing.
"""
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import ( from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
Hf3fsGlobalMetadataClient, Hf3fsGlobalMetadataClient,
Hf3fsLocalMetadataClient, Hf3fsLocalMetadataClient,
) )
rank = storage_config.tp_rank if storage_config is not None else 0 if storage_config is not None:
rank, is_mla_model = storage_config.tp_rank, storage_config.is_mla_model
else:
rank, is_mla_model = 0, False
mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
config_path = os.getenv(HiCacheHF3FS.default_env_var) config_path = os.getenv(HiCacheHF3FS.default_env_var)
if not config_path: if not config_path:
if is_mla_model:
raise ValueError(mla_unsupported_msg)
return HiCacheHF3FS( return HiCacheHF3FS(
rank=rank, rank=rank,
file_path=f"/data/hicache.{rank}.bin", file_path=f"/data/hicache.{rank}.bin",
...@@ -214,25 +233,27 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -214,25 +233,27 @@ class HiCacheHF3FS(HiCacheStorage):
raise ValueError(f"Missing required keys in config: {missing_keys}") raise ValueError(f"Missing required keys in config: {missing_keys}")
# Choose metadata client based on configuration # Choose metadata client based on configuration
is_mla_model = False if config.get("metadata_server_url"):
if "metadata_server_url" in config and config["metadata_server_url"]:
# Use global metadata client to connect to metadata server # Use global metadata client to connect to metadata server
metadata_server_url = config["metadata_server_url"] metadata_server_url = config["metadata_server_url"]
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url) metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
# Enable MLA optimization only when using the global metadata client
is_mla_model = storage_config.is_mla_model if storage_config else False
logger.info( logger.info(
f"Using global metadata client with server url: {metadata_server_url}" f"Using global metadata client with server url: {metadata_server_url}"
) )
else: else:
# Enable MLA optimization only when using the global metadata client
if is_mla_model:
raise ValueError(mla_unsupported_msg)
# Use local metadata client for single-machine deployment # Use local metadata client for single-machine deployment
metadata_client = Hf3fsLocalMetadataClient() metadata_client = Hf3fsLocalMetadataClient()
rank_for_path = 0 if is_mla_model else rank
return HiCacheHF3FS( return HiCacheHF3FS(
rank=rank, rank=rank,
# Let all ranks use the same file path for MLA model # Let all ranks use the same file path for MLA model
file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin", file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
file_size=int(config["file_size"]), file_size=int(config["file_size"]),
numjobs=int(config["numjobs"]), numjobs=int(config["numjobs"]),
bytes_per_page=bytes_per_page, bytes_per_page=bytes_per_page,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment