Unverified commit 8df49455, authored by ziruiliu, committed by GitHub
Browse files

fix file and object naming scheme in HiCacheNixl to avoid data corruption (#10969)


Signed-off-by: Zirui Liu <ziliu@ddn.com>
parent ee3bd8a1
...@@ -161,7 +161,7 @@ class StorageBackendFactory: ...@@ -161,7 +161,7 @@ class StorageBackendFactory:
if backend_name == "file": if backend_name == "file":
return backend_class(storage_config) return backend_class(storage_config)
elif backend_name == "nixl": elif backend_name == "nixl":
return backend_class() return backend_class(storage_config)
elif backend_name == "mooncake": elif backend_name == "mooncake":
backend = backend_class(storage_config) backend = backend_class(storage_config)
return backend return backend
......
...@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
...@@ -26,7 +26,12 @@ logger = logging.getLogger(__name__) ...@@ -26,7 +26,12 @@ logger = logging.getLogger(__name__)
class HiCacheNixl(HiCacheStorage): class HiCacheNixl(HiCacheStorage):
"""HiCacheNixl provides high-performance storage using NIXL plugins.""" """HiCacheNixl provides high-performance storage using NIXL plugins."""
def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"): def __init__(
self,
storage_config: HiCacheStorageConfig,
file_path: str = "/tmp/hicache_storage",
plugin: str = "auto",
):
"""Initialize NIXL storage connector.""" """Initialize NIXL storage connector."""
# Might be better to be unified across HiCache backends and moved to HiCacheController # Might be better to be unified across HiCache backends and moved to HiCacheController
file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path) file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
...@@ -36,6 +41,19 @@ class HiCacheNixl(HiCacheStorage): ...@@ -36,6 +41,19 @@ class HiCacheNixl(HiCacheStorage):
else None else None
) )
# Initialize suffix based on storage config
tp_rank, tp_size, model_name, is_mla_model = (
storage_config.tp_rank,
storage_config.tp_size,
storage_config.model_name,
storage_config.is_mla_model,
)
model_name = "-".join(model_name.split("/")) if model_name else ""
if is_mla_model:
self.config_suffix = f"_{model_name}"
else:
self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
agent_config = nixl_agent_config(backends=[]) agent_config = nixl_agent_config(backends=[])
self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}" self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
self.agent = nixl_agent(self.agent_name, agent_config) self.agent = nixl_agent(self.agent_name, agent_config)
...@@ -46,6 +64,9 @@ class HiCacheNixl(HiCacheStorage): ...@@ -46,6 +64,9 @@ class HiCacheNixl(HiCacheStorage):
self.registration = NixlRegistration(self.agent) self.registration = NixlRegistration(self.agent)
# Return `key` with `self.config_suffix` appended.
# The suffix is built in `__init__` from the model name ("/" replaced by "-"),
# plus the TP rank and TP size when the model is not an MLA model. Per the
# commit message, suffixing file/object names this way is meant to avoid
# data corruption.
def _get_suffixed_key(self, key: str) -> str:
return key + self.config_suffix
def register_buffers( def register_buffers(
self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]] self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
) -> Optional[Any]: ) -> Optional[Any]:
...@@ -194,11 +215,14 @@ class HiCacheNixl(HiCacheStorage): ...@@ -194,11 +215,14 @@ class HiCacheNixl(HiCacheStorage):
else: else:
dest = target_locations dest = target_locations
# Add suffix to keys
suffixed_keys = [self._get_suffixed_key(key) for key in keys]
if self.backend_selector.mem_type == "FILE": if self.backend_selector.mem_type == "FILE":
file_paths = [self.file_manager.get_file_path(key) for key in keys] file_paths = [self.file_manager.get_file_path(key) for key in suffixed_keys]
success = self._execute_transfer(dest, file_paths, "READ") success = self._execute_transfer(dest, file_paths, "READ")
else: else:
success = self._execute_transfer(dest, keys, "READ") success = self._execute_transfer(dest, suffixed_keys, "READ")
return target_locations if success and not target_sizes else [None] * len(keys) return target_locations if success and not target_sizes else [None] * len(keys)
def set( def set(
...@@ -227,9 +251,12 @@ class HiCacheNixl(HiCacheStorage): ...@@ -227,9 +251,12 @@ class HiCacheNixl(HiCacheStorage):
if not values: if not values:
values = list(zip(target_locations, target_sizes)) values = list(zip(target_locations, target_sizes))
# Add suffix to keys
suffixed_keys = [self._get_suffixed_key(key) for key in keys]
if self.backend_selector.mem_type == "FILE": if self.backend_selector.mem_type == "FILE":
file_paths = [] file_paths = []
for key in keys: for key in suffixed_keys:
file_path = self.file_manager.get_file_path(key) file_path = self.file_manager.get_file_path(key)
# New file per set, to be updated when partial writes is added to HiCache # New file per set, to be updated when partial writes is added to HiCache
if not self.file_manager.create_file(file_path): if not self.file_manager.create_file(file_path):
...@@ -238,11 +265,14 @@ class HiCacheNixl(HiCacheStorage): ...@@ -238,11 +265,14 @@ class HiCacheNixl(HiCacheStorage):
file_paths.append(file_path) file_paths.append(file_path)
return self._execute_transfer(values, file_paths, "WRITE") return self._execute_transfer(values, file_paths, "WRITE")
else: # mem_type == "OBJ" else: # mem_type == "OBJ"
return self._execute_transfer(values, keys, "WRITE") return self._execute_transfer(values, suffixed_keys, "WRITE")
def exists(self, key: str) -> bool: def exists(self, key: str) -> bool:
# Add suffix to key
suffixed_key = self._get_suffixed_key(key)
tuples = self.registration.create_query_tuples( tuples = self.registration.create_query_tuples(
key, suffixed_key,
self.backend_selector.mem_type, self.backend_selector.mem_type,
self.file_manager if self.backend_selector.mem_type == "FILE" else None, self.file_manager if self.backend_selector.mem_type == "FILE" else None,
) )
......
...@@ -7,6 +7,7 @@ from unittest.mock import MagicMock ...@@ -7,6 +7,7 @@ from unittest.mock import MagicMock
import torch import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
from sglang.srt.mem_cache.storage.nixl.nixl_utils import ( from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
NixlFileManager, NixlFileManager,
...@@ -31,8 +32,22 @@ class TestNixlUnified(unittest.TestCase): ...@@ -31,8 +32,22 @@ class TestNixlUnified(unittest.TestCase):
# Create instances # Create instances
self.file_manager = NixlFileManager(self.test_dir) self.file_manager = NixlFileManager(self.test_dir)
self.registration = NixlRegistration(self.mock_agent) self.registration = NixlRegistration(self.mock_agent)
# Create storage config for testing
self.storage_config = HiCacheStorageConfig(
tp_rank=0,
tp_size=2,
is_mla_model=False,
is_page_first_layout=False,
model_name="test_model",
)
try: try:
self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX") self.hicache = HiCacheNixl(
storage_config=self.storage_config,
file_path=self.test_dir,
plugin="POSIX",
)
except ImportError: except ImportError:
self.skipTest("NIXL not available, skipping NIXL storage tests") self.skipTest("NIXL not available, skipping NIXL storage tests")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment