Unverified Commit 83871aa1 authored by hzh0425, committed by GitHub

feat(hicache): Supports 3fs-hicache compatibility with dp-attention (#9372)

parent b1b3f0b3
......@@ -59,7 +59,7 @@ def test():
raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")
rank = 0
hicache_hf3fs = HiCacheHF3FS.from_env_config(rank, bytes_per_page, dtype)
hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype, rank)
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
assert numel * dtype.itemsize == bytes_per_page
......
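The test keeps passing `rank` explicitly, which still works because `rank` is now the trailing, optional parameter. The page-size arithmetic it asserts is worth spelling out; a runnable sketch with illustrative dimensions (the concrete values are assumptions, not the test's actual config):

```python
import torch

# Factor 2 covers the K and V caches; one page holds tokens_per_page tokens
# across every layer and head. The values below are illustrative only.
tokens_per_page, layer_num, head_num, head_dim = 64, 32, 8, 128
dtype = torch.bfloat16

numel = 2 * tokens_per_page * layer_num * head_num * head_dim
bytes_per_page = numel * dtype.itemsize  # bf16 -> 2 bytes per element
assert bytes_per_page == numel * 2

# Under the new signature, both call forms are valid:
#   HiCacheHF3FS.from_env_config(bytes_per_page, dtype)        # rank auto-resolved
#   HiCacheHF3FS.from_env_config(bytes_per_page, dtype, rank)  # explicit, as in this test
```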
......@@ -269,7 +269,6 @@ class HiCacheController:
                 HiCacheHF3FS,
             )
-            rank = get_tensor_model_parallel_rank()
             if self.mem_pool_host.layout == "page_first":
                 bytes_per_page = (
                     mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
......@@ -280,7 +279,7 @@ class HiCacheController:
                 )
             dtype = mem_pool_host.dtype
             self.storage_backend = HiCacheHF3FS.from_env_config(
-                rank, bytes_per_page, dtype
+                bytes_per_page, dtype
             )
             self.get_hash_str = get_hash_str
         else:
......
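With the local `rank = get_tensor_model_parallel_rank()` gone, the controller no longer decides which rank keys the hf3fs backend; `from_env_config` resolves it, using the attention-TP rank when dp-attention is on. The motivation: under dp-attention, attention is sharded over a smaller TP subgroup per DP replica, so keying storage by the global TP rank over-partitions the cache. A self-contained sketch of that subgroup math (the layout formula is an illustrative assumption, not code from this diff):

```python
def attention_tp_info(tp_rank: int, tp_size: int, dp_size: int) -> tuple:
    """Hypothetical helper: locate tp_rank inside its attention-TP subgroup
    when a TP group of tp_size ranks hosts dp_size attention replicas."""
    attn_tp_size = tp_size // dp_size
    attn_tp_rank = tp_rank % attn_tp_size
    return attn_tp_rank, attn_tp_size

# tp_size=8 with dp_size=2 -> two replicas, each an attention-TP group of 4.
for tp_rank in range(8):
    print(tp_rank, "->", attention_tp_info(tp_rank, tp_size=8, dp_size=2))
# Ranks 0-3 and 4-7 both map to attention-TP ranks 0-3, so matching shards
# in the two replicas now agree on the same storage rank.
```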
......@@ -13,6 +13,11 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)


 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
......@@ -103,8 +108,13 @@ class HiCacheFile(HiCacheStorage):
     def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
+        if is_dp_attention_enabled():
+            tp_rank = get_attention_tp_rank()
+            tp_size = get_attention_tp_size()
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            tp_size = get_tensor_model_parallel_world_size()
         self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
......
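The branch matters because `tp_suffix` becomes part of every cache file name. Before, two DP replicas holding the same attention shard computed different suffixes from their global TP ranks and thus could not share files; now they agree. A runnable sketch of the naming effect (the subgroup arithmetic is an illustrative assumption):

```python
def tp_suffix(tp_rank: int, tp_size: int, is_mla: bool = False) -> str:
    # Mirrors the expression in HiCacheFile.__init__ above.
    return f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""

# Plain TP, tp_size=8: eight distinct suffixes, one file set per worker.
print([tp_suffix(r, 8) for r in range(8)])
# ['_0_8', '_1_8', ..., '_7_8']

# dp-attention, dp_size=2: workers are indexed within their attention-TP
# group of 4, so both replicas land on the same four suffixes.
attn_tp_size = 8 // 2
print([tp_suffix(r % attn_tp_size, attn_tp_size) for r in range(8)])
# ['_0_4', '_1_4', '_2_4', '_3_4', '_0_4', '_1_4', '_2_4', '_3_4']
```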
......@@ -11,6 +11,11 @@ from typing import Any, List, Optional, Tuple
 import torch

 from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    is_dp_attention_enabled,
+)
 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
......@@ -167,13 +172,20 @@ class HiCacheHF3FS(HiCacheStorage):
     @staticmethod
     def from_env_config(
-        rank: int, bytes_per_page: int, dtype: torch.dtype
+        bytes_per_page: int, dtype: torch.dtype, rank: int = None
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )

+        if rank is None:
+            rank = (
+                get_attention_tp_rank()
+                if is_dp_attention_enabled()
+                else get_tensor_model_parallel_rank()
+            )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
......
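With the default in place, production callers can drop the argument entirely while tests still pin it. A usage sketch, assuming an initialized sglang worker process (the byte count is arbitrary, and the env var named by `HiCacheHF3FS.default_env_var` must point at a valid hf3fs config, as the code above requires):

```python
import torch

from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

# rank omitted: resolved to the attention-TP rank when dp-attention is
# enabled, otherwise to the tensor-model-parallel rank.
backend = HiCacheHF3FS.from_env_config(4 << 20, torch.bfloat16)

# Explicit rank still works and bypasses the lookup (useful in tests
# that run without a distributed group).
backend = HiCacheHF3FS.from_env_config(4 << 20, torch.bfloat16, rank=0)
```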