Unverified Commit c04c17ed authored by hzh0425, committed by GitHub

refactor(hicache): Introduce generic HiCacheStorageConfig for improved configuration management (#9555)
Co-authored-by: Teng Ma <805522925@qq.com>
parent 16a6d21b
@@ -57,9 +57,7 @@ def test():
         )
     except Exception as e:
         raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")
-    rank = 0
-    hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype, rank)
+    hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype)
     numel = 2 * tokens_per_page * layer_num * head_num * head_dim
     assert numel * dtype.itemsize == bytes_per_page
......
@@ -22,11 +22,21 @@ from typing import TYPE_CHECKING, List, Optional

 import torch

+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

-from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool

 logger = logging.getLogger(__name__)
@@ -231,6 +241,8 @@ class HiCacheController:
         io_backend: str = "",
         storage_backend: Optional[str] = None,
         prefetch_threshold: int = 256,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -248,20 +260,22 @@ class HiCacheController:
         self.get_hash_str = get_hash_str

-        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
-        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+        self.storage_config = self._generate_storage_config(
+            model_name, storage_backend_extra_config
+        )
         # In MLA backend, only one rank needs to backup the KV cache
         self.backup_skip = (
-            is_mla_backend
+            self.storage_config.is_mla_model
             # todo: for load balancing, decide which rank to backup the KV cache by hash value
-            and get_tensor_model_parallel_rank() != 0
+            and self.storage_config.tp_rank != 0
             # todo: support other storage backends
             and self.storage_backend_type in ["file", "mooncake"]
         )

         if storage_backend == "file":
             from sglang.srt.mem_cache.hicache_storage import HiCacheFile

-            self.storage_backend = HiCacheFile(is_mla_backend=is_mla_backend)
+            self.storage_backend = HiCacheFile(self.storage_config)
         elif storage_backend == "nixl":
             from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
@@ -271,7 +285,7 @@ class HiCacheController:
                 MooncakeStore,
             )

-            self.storage_backend = MooncakeStore(is_mla_backend=is_mla_backend)
+            self.storage_backend = MooncakeStore(self.storage_config)
             self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
             assert self.mem_pool_host.layout == "page_first"
         elif storage_backend == "hf3fs":
@@ -289,7 +303,7 @@ class HiCacheController:
             )
             dtype = mem_pool_host.dtype
             self.storage_backend = HiCacheHF3FS.from_env_config(
-                bytes_per_page, dtype
+                bytes_per_page, dtype, self.storage_config
             )
         else:
             raise NotImplementedError(
@@ -370,6 +384,40 @@ class HiCacheController:
         self.prefetch_thread.start()
         self.backup_thread.start()

+    def _generate_storage_config(
+        self,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
+    ):
+        if is_dp_attention_enabled():
+            self.tp_rank = get_attention_tp_rank()
+            self.tp_size = get_attention_tp_size()
+        else:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+
+        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
+        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+
+        # Parse extra config JSON if provided
+        extra_config = None
+        if storage_backend_extra_config:
+            try:
+                import json
+
+                extra_config = json.loads(storage_backend_extra_config)
+            except Exception as e:
+                logger.error(f"Invalid backend extra config JSON: {e}")
+
+        return HiCacheStorageConfig(
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
+            is_mla_model=is_mla_backend,
+            model_name=model_name,
+            extra_config=extra_config,
+        )
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
......
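For illustration, a minimal standalone sketch (assuming sglang is installed) of the object that _generate_storage_config ends up building for a single-rank, non-MLA run. The tp_rank/tp_size values, the model name, and the JSON key are placeholders, since a standalone script has no distributed runtime to query:

import json

from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig

# The controller receives the extra config as a JSON string and parses it into
# a plain dict; invalid JSON is logged and dropped (extra_config stays None).
storage_backend_extra_config = '{"some_backend_option": "value"}'  # hypothetical key

extra_config = None
if storage_backend_extra_config:
    try:
        extra_config = json.loads(storage_backend_extra_config)
    except Exception as e:
        print(f"Invalid backend extra config JSON: {e}")

config = HiCacheStorageConfig(
    tp_rank=0,  # placeholder; the controller reads this from the TP / DP-attention runtime
    tp_size=1,  # placeholder
    is_mla_model=False,
    model_name="example-model",  # placeholder for server_args.served_model_name
    extra_config=extra_config,
)
print(config)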
@@ -627,6 +627,8 @@ class Scheduler(
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
+                model_name=server_args.served_model_name,
+                storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
......
@@ -2,6 +2,7 @@ import hashlib
 import logging
 import os
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Any, List, Optional

 import torch
@@ -9,17 +10,6 @@ import torch

 logger = logging.getLogger(__name__)

-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    is_dp_attention_enabled,
-)

 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     hasher = hashlib.sha256()
@@ -32,6 +22,15 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     return hasher.hexdigest()


+@dataclass
+class HiCacheStorageConfig:
+    tp_rank: int
+    tp_size: int
+    is_mla_model: bool
+    model_name: Optional[str]
+    extra_config: Optional[dict] = None
+
+
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
@@ -117,18 +116,17 @@ class HiCacheStorage(ABC):

 class HiCacheFile(HiCacheStorage):

-    def __init__(self, file_path: str = "/tmp/hicache", is_mla_backend: bool = False):
+    def __init__(
+        self, storage_config: HiCacheStorageConfig, file_path: str = "/tmp/hicache"
+    ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-        if is_dp_attention_enabled():
-            tp_rank = get_attention_tp_rank()
-            tp_size = get_attention_tp_size()
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            tp_size = get_tensor_model_parallel_world_size()
-        self.tp_suffix = (
-            f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla_backend else ""
-        )
+        tp_rank, tp_size, is_mla = (
+            storage_config.tp_rank,
+            storage_config.tp_size,
+            storage_config.is_mla_model,
+        )
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
......
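As a quick illustration of the new constructor contract, the sketch below (assuming sglang is installed) mirrors the suffix rule from HiCacheFile.__init__ above: non-MLA models with tp_size > 1 get a per-rank key suffix, while MLA models, which only back up the KV cache from one rank, do not. The config values are made up for the example:

from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig


def tp_suffix(cfg: HiCacheStorageConfig) -> str:
    # Same expression as in HiCacheFile.__init__.
    return f"_{cfg.tp_rank}_{cfg.tp_size}" if cfg.tp_size > 1 and not cfg.is_mla_model else ""


mha_cfg = HiCacheStorageConfig(tp_rank=1, tp_size=4, is_mla_model=False, model_name="example-model")
mla_cfg = HiCacheStorageConfig(tp_rank=1, tp_size=4, is_mla_model=True, model_name="example-model")
print(tp_suffix(mha_cfg))  # "_1_4"
print(tp_suffix(mla_cfg))  # ""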
@@ -39,6 +39,8 @@ class HiRadixCache(RadixCache):
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
         hicache_storage_prefetch_policy: Optional[str] = "best_effort",
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):

         if hicache_io_backend == "direct":
@@ -87,6 +89,8 @@ class HiRadixCache(RadixCache):
             io_backend=hicache_io_backend,
             storage_backend=hicache_storage_backend,
             prefetch_threshold=self.prefetch_threshold,
+            model_name=model_name,
+            storage_backend_extra_config=storage_backend_extra_config,
         )
         # record the nodes with ongoing write through
......
@@ -11,12 +11,7 @@ from typing import Any, List, Optional, Tuple

 import torch

-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    is_dp_attention_enabled,
-)
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient

 logger = logging.getLogger(__name__)
@@ -172,19 +167,16 @@ class HiCacheHF3FS(HiCacheStorage):
     @staticmethod
     def from_env_config(
-        bytes_per_page: int, dtype: torch.dtype, rank: int = None
+        bytes_per_page: int,
+        dtype: torch.dtype,
+        storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )

-        if rank is None:
-            rank = (
-                get_attention_tp_rank()
-                if is_dp_attention_enabled()
-                else get_tensor_model_parallel_rank()
-            )
+        rank = storage_config.tp_rank if storage_config is not None else 0

         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
......
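The rank selection above is now a one-liner over the config object instead of a query against the distributed runtime. A minimal sketch (assuming sglang is installed) of that fallback behavior, with made-up config values:

from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig


def pick_rank(storage_config: HiCacheStorageConfig = None) -> int:
    # Same expression as in HiCacheHF3FS.from_env_config above.
    return storage_config.tp_rank if storage_config is not None else 0


assert pick_rank() == 0  # no config: standalone/test usage defaults to rank 0
assert pick_rank(HiCacheStorageConfig(tp_rank=3, tp_size=8, is_mla_model=False, model_name=None)) == 3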
@@ -10,7 +10,7 @@ import numpy as np
 import torch

 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig

 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
 DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
@@ -84,15 +84,7 @@ class MooncakeStoreConfig:

 class MooncakeStore(HiCacheStorage):
-    def __init__(self, is_mla_backend: bool = False):
-        """
-        Initialize MooncakeStore.
-
-        Args:
-            is_mla_backend: If the backend is MLA
-        """
-        self.is_mla_backend = is_mla_backend
-
+    def __init__(self, storage_config: HiCacheStorageConfig = None):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -123,6 +115,13 @@ class MooncakeStore(HiCacheStorage):
             self.warmup()
             logger.info("Mooncake store warmup successfully.")

+            if storage_config is not None:
+                self.is_mla_backend = storage_config.is_mla_model
+                self.local_rank = storage_config.tp_rank
+            else:
+                self.is_mla_backend = False
+                self.local_rank = 0
+
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
             raise
@@ -130,8 +129,6 @@ class MooncakeStore(HiCacheStorage):
             logger.error("An error occurred while loading the configuration: %s", exc)
             raise

-        self.local_rank = get_tensor_model_parallel_rank()
-
     def warmup(self):
         warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
         warmup_value = bytes(4 * 1024)  # 4 KB
......
@@ -216,6 +216,7 @@ class ServerArgs:
     hicache_mem_layout: str = "layer_first"
     hicache_storage_backend: Optional[str] = None
     hicache_storage_prefetch_policy: str = "best_effort"
+    hicache_storage_backend_extra_config: Optional[str] = None

     # Double Sparsity
     enable_double_sparsity: bool = False
@@ -1641,6 +1642,12 @@ class ServerArgs:
             default=ServerArgs.hicache_storage_prefetch_policy,
             help="Control when prefetching from the storage backend should stop.",
         )
+        parser.add_argument(
+            "--hicache-storage-backend-extra-config",
+            type=str,
+            default=ServerArgs.hicache_storage_backend_extra_config,
+            help="A dictionary in JSON string format containing extra configuration for the storage backend.",
+        )

         # Double Sparsity
         parser.add_argument(
......
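To show how the new flag reaches the storage backend, here is a self-contained sketch of the argparse round trip; it mirrors the parser.add_argument block above, and the JSON key is illustrative only (valid keys depend on the chosen backend):

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hicache-storage-backend-extra-config",
    type=str,
    default=None,
    help="A dictionary in JSON string format containing extra configuration for the storage backend.",
)
args = parser.parse_args(
    ["--hicache-storage-backend-extra-config", '{"some_backend_option": "value"}']
)
extra_config = json.loads(args.hicache_storage_backend_extra_config)
print(extra_config["some_backend_option"])  # value

At launch time this is the string passed via --hicache-storage-backend-extra-config, which the scheduler forwards to HiRadixCache and on to HiCacheController, where it is parsed into HiCacheStorageConfig.extra_config.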