Unverified Commit 83871aa1, authored by hzh0425, committed by GitHub

feat(hicache): Supports 3fs-hicache compatibility with dp-attention (#9372)

parent b1b3f0b3
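
At a high level, the patch stops threading the rank through the HiCache controller and lets the storage backends resolve it themselves: when dp-attention is enabled, the HF3FS and file backends key their storage by the attention-TP rank instead of the tensor-model-parallel rank. A minimal sketch of that selection logic, using the imports shown in the diffs below (the standalone helper itself is illustrative, not part of the patch):

from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    is_dp_attention_enabled,
)


def resolve_storage_rank() -> int:
    # Mirrors the fallback added to HiCacheHF3FS.from_env_config: prefer the
    # attention-TP rank under dp-attention, otherwise use the plain TP rank.
    if is_dp_attention_enabled():
        return get_attention_tp_rank()
    return get_tensor_model_parallel_rank()
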
@@ -59,7 +59,7 @@ def test():
         raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")
     rank = 0
-    hicache_hf3fs = HiCacheHF3FS.from_env_config(rank, bytes_per_page, dtype)
+    hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype, rank)
     numel = 2 * tokens_per_page * layer_num * head_num * head_dim
     assert numel * dtype.itemsize == bytes_per_page
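
For context, the test derives bytes_per_page from the KV cache shape: one page holds 2 * tokens_per_page * layer_num * head_num * head_dim elements (the factor of 2 covers the K and V planes), so bytes_per_page must equal that element count times dtype.itemsize. A quick check with made-up dimensions (not the values the test actually uses):

import torch

# Illustrative dimensions only; the real test reads these from its config.
tokens_per_page, layer_num, head_num, head_dim = 64, 32, 8, 128
dtype = torch.bfloat16

numel = 2 * tokens_per_page * layer_num * head_num * head_dim  # K and V planes
bytes_per_page = numel * dtype.itemsize  # 8_388_608 bytes (8 MiB) for these values
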
@@ -269,7 +269,6 @@ class HiCacheController:
                     HiCacheHF3FS,
                 )
 
-                rank = get_tensor_model_parallel_rank()
                 if self.mem_pool_host.layout == "page_first":
                     bytes_per_page = (
                         mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
@@ -280,7 +279,7 @@ class HiCacheController:
                     )
                 dtype = mem_pool_host.dtype
                 self.storage_backend = HiCacheHF3FS.from_env_config(
-                    rank, bytes_per_page, dtype
+                    bytes_per_page, dtype
                 )
                 self.get_hash_str = get_hash_str
             else:
@@ -13,6 +13,11 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 
 
 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
@@ -103,8 +108,13 @@ class HiCacheFile(HiCacheStorage):
     def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
+        if is_dp_attention_enabled():
+            tp_rank = get_attention_tp_rank()
+            tp_size = get_attention_tp_size()
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            tp_size = get_tensor_model_parallel_world_size()
         self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
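
A practical consequence for the file backend: the per-rank suffix now follows the attention-TP group, so under pure data-parallel attention (attention-TP size 1) the suffix collapses to the empty string and all DP ranks share the same file names. A small sketch of the suffix rule from the hunk above, with illustrative values:

def tp_suffix(tp_rank: int, tp_size: int, is_mla: bool) -> str:
    # Same rule as HiCacheFile.__init__: suffix per attention-TP shard, but no
    # suffix when there is only one shard or the cache is MLA-style.
    return f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""


print(tp_suffix(0, 1, False))  # ""     -> dp-attention with attention-TP size 1
print(tp_suffix(2, 4, False))  # "_2_4" -> sharded across a 4-way attention-TP group
print(tp_suffix(2, 4, True))   # ""     -> MLA caches are not sharded by rank
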
@@ -11,6 +11,11 @@ from typing import Any, List, Optional, Tuple
 import torch
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    is_dp_attention_enabled,
+)
 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
@@ -167,13 +172,20 @@ class HiCacheHF3FS(HiCacheStorage):
     @staticmethod
     def from_env_config(
-        rank: int, bytes_per_page: int, dtype: torch.dtype
+        bytes_per_page: int, dtype: torch.dtype, rank: int = None
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )
 
+        if rank is None:
+            rank = (
+                get_attention_tp_rank()
+                if is_dp_attention_enabled()
+                else get_tensor_model_parallel_rank()
+            )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
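
With the reordered signature, code running inside a model-parallel worker can omit rank and let the backend choose between the attention-TP and TP rank, while standalone callers such as the test above still pass it explicitly. A usage sketch, assuming HiCacheHF3FS is importable from the hf3fs storage package shown above (the exact module path is not visible in this capture) and that the config path named by HiCacheHF3FS.default_env_var is either set or intentionally left to the defaults:

import torch

# Module path assumed from the sibling client_hf3fs import in the diff.
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

# Inside a TP / dp-attention worker: rank is resolved internally.
backend = HiCacheHF3FS.from_env_config(bytes_per_page=8 * 1024 * 1024, dtype=torch.bfloat16)

# In a standalone script or test: pass the rank explicitly, as the updated test does.
backend = HiCacheHF3FS.from_env_config(8 * 1024 * 1024, torch.bfloat16, rank=0)
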