Unverified Commit 7e880286 authored by Moein Khazraee's avatar Moein Khazraee Committed by GitHub
Browse files

Add support for extensions of interface and pre-registrations to NIXL HiCache (#9211)


Co-authored-by: default avatarZhiqiang Xie <xiezhq@stanford.edu>
parent 446c8e4c
...@@ -36,6 +36,21 @@ Consolidated utility classes: ...@@ -36,6 +36,21 @@ Consolidated utility classes:
- **NixlRegistration** - Manages memory registration for tensors, files and objects - **NixlRegistration** - Manages memory registration for tensors, files and objects
- **NixlFileManager** - Handles file system operations and NIXL tuple creation - **NixlFileManager** - Handles file system operations and NIXL tuple creation
## Using NIXL for HiCache backend
When running the SGLang server, indicate `nixl` for `hicache-storage-backend` parameter, for instance:
```bash
python3 -m sglang.launch_server --model-path <model> --host <ip> --port <port> --page-size 64 --enable-hierarchical-cache --hicache-ratio 2 --hicache-size 64 --hicache-write-policy write_through --hicache-storage-backend nixl
```
To customize the base directory for files, you can set the following environment variable:
```bash
export SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR=/path/to/desired/dir
```
Selection of any storage backend like 3FS requires availability of that library on the system, and the backend is selected based on the priority mentioned above.
## Running Unit Tests ## Running Unit Tests
### Prerequisites ### Prerequisites
...@@ -43,33 +58,26 @@ Consolidated utility classes: ...@@ -43,33 +58,26 @@ Consolidated utility classes:
- PyTorch installed - PyTorch installed
- Python 3.8+ - Python 3.8+
### Unit tests from Project root ### Unit tests from current directory
Navigate to the project root directory (`/path/to/sglang`) and run: From the current directory run:
#### Run all NIXL tests: #### Run all NIXL tests:
```bash ```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -o asyncio_mode=strict PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
``` ```
#### Run with verbose output: #### Run with verbose output:
```bash ```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -o asyncio_mode=strict PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -v -o asyncio_mode=strict
``` ```
Note: The `-v` flag provides more detailed output, showing each test case name and its result. Note: The `-v` flag provides more detailed output, showing each test case name and its result.
#### Run a specific test: #### Run a specific test:
```bash ```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
``` ```
### From Tests Directory
Navigate to the tests directory and run:
```bash
cd test/srt
PYTHONPATH=../.. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
```
Note: The `-o asyncio_mode=strict` flag is added to suppress warnings about asyncio configuration. This is not required for test functionality but provides cleaner output. Note: The `-o asyncio_mode=strict` flag is added to suppress warnings about asyncio configuration. This is not required for test functionality but provides cleaner output.
## Test Coverage ## Test Coverage
......
...@@ -3,7 +3,7 @@ import logging ...@@ -3,7 +3,7 @@ import logging
import os import os
import time import time
import uuid import uuid
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
...@@ -28,6 +28,8 @@ class HiCacheNixl(HiCacheStorage): ...@@ -28,6 +28,8 @@ class HiCacheNixl(HiCacheStorage):
def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"): def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
"""Initialize NIXL storage connector.""" """Initialize NIXL storage connector."""
# Might be better to be unified across HiCache backends and moved to HiCacheController
file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
self.file_manager = ( self.file_manager = (
NixlFileManager(file_path) NixlFileManager(file_path)
if plugin not in NixlBackendSelection.OBJ_PLUGINS if plugin not in NixlBackendSelection.OBJ_PLUGINS
...@@ -44,59 +46,109 @@ class HiCacheNixl(HiCacheStorage): ...@@ -44,59 +46,109 @@ class HiCacheNixl(HiCacheStorage):
self.registration = NixlRegistration(self.agent) self.registration = NixlRegistration(self.agent)
def register_buffers(
self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
) -> Optional[Any]:
"""Register tensor(s) or target locations in host memory (list of addr,len tuples) with NIXL."""
if isinstance(buffers[0], tuple):
tuples = [(x[0], x[1], 0, "") for x in buffers]
return self.registration._register_memory(tuples, "DRAM")
else:
return self.registration._register_memory(buffers)
def register_files(
self, file_paths: List[str], open_file: Optional[bool] = True
) -> Optional[Any]:
"""Register files with NIXL."""
tuples = self.file_manager.files_to_nixl_tuples(file_paths)
return self.registration._register_memory(tuples, "FILE")
def register_objects(
self, keys: List[str], sizes: Optional[List[int]] = None
) -> Optional[Any]:
"""Register objects with NIXL."""
if not keys:
return None
tuples = [(0, 0, key, "") for key in keys]
return self.registration._register_memory(tuples, "OBJ")
def _execute_transfer( def _execute_transfer(
self, tensors: List[torch.Tensor], keys: List[str], direction: str self,
buffers: Optional[List[torch.Tensor | tuple]],
keys: List[str],
direction: str,
) -> bool: ) -> bool:
if len(tensors) != len(keys): if len(buffers) != len(keys):
logger.error("Mismatch between number of tensors and files/objects") logger.error("Mismatch between number of tensors/buffers and files/objects")
return False return False
if not self.registration.register_buffers(tensors): # Registering file and object keys per transfer, to be updated when
logger.error("Failed to register tensors") # pre-registration for file and object is added to HiCache.
return False
# Get transfer tuples based on backend type
tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
if self.backend_selector.mem_type == "FILE": if self.backend_selector.mem_type == "FILE":
file_tuples = self.file_manager.files_to_nixl_tuples(keys) tuples = self.file_manager.files_to_nixl_tuples(keys)
if not file_tuples or not self.registration.register_files(file_tuples): if not tuples or not self.registration._register_memory(tuples, "FILE"):
logger.error("Failed to prepare files for transfer") logger.error("Failed to prepare files for transfer")
return False return False
transfer_tuples = [ else: # mem_type == "OBJ"
(x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes) tuples = [(0, 0, key, "") for key in keys]
] if not tuples or not self.registration._register_memory(tuples, "OBJ"):
else:
if not self.registration.register_objects(keys, tensors):
logger.error("Failed to register objects") logger.error("Failed to register objects")
return False return False
transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
# Prepare transfer descriptors
if isinstance(buffers[0], torch.Tensor):
tensor_sizes = [
tensor.element_size() * tensor.numel() for tensor in buffers
]
storage_tuples = [(x[0], s, x[2]) for x, s in zip(tuples, tensor_sizes)]
host_descs = self.agent.get_xfer_descs(buffers)
elif isinstance(buffers[0], tuple):
storage_tuples = [(x[0], y[1], x[2]) for x, y in zip(tuples, buffers)]
host_descs = self.agent.get_xfer_descs(
[(x[0], x[1], 0) for x in buffers], "DRAM"
)
else:
return False
storage_descs = self.agent.get_xfer_descs(
storage_tuples, self.backend_selector.mem_type
)
if (host_descs is None) or (storage_descs is None):
logger.error("Failed to get transfer descriptors")
return False
# Initialize transfer, default assumption that tensor was registered
try: try:
# Get transfer descriptors xfer_req = self.agent.initialize_xfer(
if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or ( direction, host_descs, storage_descs, self.agent_name
file_descs := self.agent.get_xfer_descs( )
transfer_tuples, self.backend_selector.mem_type except Exception:
) # Check if it was due to missing pre-registration
) is None: if not self.register_buffers(buffers):
logger.error("Failed to get transfer descriptors") logger.error("Failed to register tensors/buffers")
return False return False
# Initialize and execute transfer try:
if ( xfer_req = self.agent.initialize_xfer(
xfer_req := self.agent.initialize_xfer( direction, host_descs, storage_descs, self.agent_name
direction, tensor_descs, file_descs, self.agent_name
) )
) is None: except Exception as e:
logger.error("Failed to create transfer request") logger.error(f"Failed to create transfer request: {e}")
return False return False
# Execute transfer and wait for its completion
try:
state = self.agent.transfer(xfer_req) state = self.agent.transfer(xfer_req)
while state != "DONE": while state != "DONE":
state = self.agent.check_xfer_state(xfer_req) state = self.agent.check_xfer_state(xfer_req)
if state == "ERR": if state == "ERR":
self.agent.release_xfer_handle(xfer_req)
logger.error("Transfer failed") logger.error("Transfer failed")
return False return False
time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
self.agent.release_xfer_handle(xfer_req)
return True return True
except Exception as e: except Exception as e:
...@@ -106,45 +158,87 @@ class HiCacheNixl(HiCacheStorage): ...@@ -106,45 +158,87 @@ class HiCacheNixl(HiCacheStorage):
logger.error(f"Traceback: {traceback.format_exc()}") logger.error(f"Traceback: {traceback.format_exc()}")
return False return False
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
if not keys:
return True
if self.backend_selector.mem_type == "FILE":
file_paths = []
for key in keys:
tensor_path = self.file_manager.get_file_path(key)
if not self.file_manager.create_file(tensor_path):
logger.error(f"Failed to create file {tensor_path}")
return False
file_paths.append(tensor_path)
return self._execute_transfer(values, file_paths, "WRITE")
else:
return self._execute_transfer(values, keys, "WRITE")
def set(self, key: str, value: torch.Tensor) -> bool:
return self.batch_set([key], [value])
def get( def get(
self, key: str, dst_tensor: Optional[torch.Tensor] = None self,
key: str,
target_location: Optional[torch.Tensor | int] = None,
target_sizes: Optional[int] = None,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
if dst_tensor is None: # To be removed, being compatible with the current API # To be removed, being compatible with the current API
if target_location is None:
return None return None
result = self.batch_get([key], [dst_tensor]) if target_sizes:
result = self.batch_get([key], [target_location], [target_sizes])
else:
result = self.batch_get([key], [target_location])
return result[0] if result else None return result[0] if result else None
def batch_get( def batch_get(
self, keys: List[str], dst_tensors: List[torch.Tensor] self,
) -> List[Optional[torch.Tensor]]: keys: List[str],
target_locations: Optional[List[torch.Tensor | int]] = None,
target_sizes: Optional[List[int]] = None,
) -> List[torch.Tensor | None]:
if not keys: if not keys:
return [] return []
# To be removed, being compatible with the current API
if not target_locations:
return [None] * len(keys)
if target_sizes and (len(target_sizes) != len(target_locations)):
logger.error("Mismatch between number of target_locations and target_sizes")
return [None] * len(keys)
if target_sizes:
dest = list(zip(target_locations, target_sizes))
else:
dest = target_locations
if self.backend_selector.mem_type == "FILE": if self.backend_selector.mem_type == "FILE":
file_paths = [self.file_manager.get_file_path(key) for key in keys] file_paths = [self.file_manager.get_file_path(key) for key in keys]
success = self._execute_transfer(dst_tensors, file_paths, "READ") success = self._execute_transfer(dest, file_paths, "READ")
else: else:
success = self._execute_transfer(dst_tensors, keys, "READ") success = self._execute_transfer(dest, keys, "READ")
return dst_tensors if success else [None] * len(keys) return target_locations if success and not target_sizes else [None] * len(keys)
def set(
self,
key: str,
value: Optional[torch.Tensor] = None,
target_location: Optional[int] = None,
target_sizes: Optional[int] = None,
) -> bool:
if target_location and target_sizes:
return self.batch_set([key], None, [target_location], [target_sizes])
else:
return self.batch_set([key], [value])
def batch_set(
self,
keys: List[str],
values: Optional[List[torch.Tensor]] = None,
target_locations: Optional[List[int]] = None,
target_sizes: Optional[List[int]] = None,
) -> bool:
if not keys or (not values and (not target_locations or not target_sizes)):
logger.error("Keys or values were not passed")
return False
if not values:
values = list(zip(target_locations, target_sizes))
if self.backend_selector.mem_type == "FILE":
file_paths = []
for key in keys:
file_path = self.file_manager.get_file_path(key)
# New file per set, to be updated when partial writes is added to HiCache
if not self.file_manager.create_file(file_path):
logger.error(f"Failed to create file {file_path}")
return False
file_paths.append(file_path)
return self._execute_transfer(values, file_paths, "WRITE")
else: # mem_type == "OBJ"
return self._execute_transfer(values, keys, "WRITE")
def exists(self, key: str) -> bool: def exists(self, key: str) -> bool:
tuples = self.registration.create_query_tuples( tuples = self.registration.create_query_tuples(
......
...@@ -109,66 +109,35 @@ class NixlRegistration: ...@@ -109,66 +109,35 @@ class NixlRegistration:
return [(0, 0, key)] return [(0, 0, key)]
def _register_memory( def _register_memory(
self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str self,
items: Union[List[tuple], torch.Tensor, List[torch.Tensor]],
mem_type: Optional[str] = None,
) -> Optional[Any]: ) -> Optional[Any]:
"""Common registration logic for files, objects, and buffers. """Common registration logic for files, objects, and buffers.
Args: Args:
items: List of tuples or tensors to register items: List of tuples or tensors to register
mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM") mem_type: Memory type ("FILE", "OBJ") or None for tensor or list of tensors
desc: Description for logging
""" """
try: if isinstance(items, list) and not items:
if not items:
return None
reg_descs = self.agent.get_reg_descs(items, mem_type)
if reg_descs is None:
logger.error("Failed to create registration descriptors")
return None
registered_memory = self.agent.register_memory(reg_descs)
if registered_memory:
return registered_memory
else:
logger.error("Failed to register with NIXL")
return None
except Exception as e:
logger.error(f"Failed to register {desc}: {e}")
return None return None
def register_buffers( reg_descs = self.agent.get_reg_descs(items, mem_type)
self, buffers: Union[torch.Tensor, List[torch.Tensor]] if reg_descs is None:
) -> Optional[Any]: logger.error("Failed to create registration descriptors")
"""Register tensors/buffers with NIXL."""
if isinstance(buffers, torch.Tensor):
buffers = [buffers]
if not buffers:
return None return None
# Determine memory type based on tensor device try:
mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM" registered_memory = self.agent.register_memory(reg_descs)
return self._register_memory(buffers, mem_type, "buffers") return registered_memory # Could be None in case of error
except Exception as e:
def register_files(self, tuples: List[tuple]) -> Optional[Any]: if not mem_type:
"""Register files with NIXL using (0, 0, fd, file_path) tuples.""" logger.error(f"Failed to register Tensors with NIXL: {e}")
return self._register_memory(tuples, "FILE", "files") else:
logger.error(
def register_objects( f"Failed to register memory of type {mem_type} with NIXL: {e}"
self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None )
) -> Optional[Any]:
"""Register objects with NIXL."""
if not keys:
return None return None
# Create object tuples with proper sizes
tuples = [
(0, tensor.element_size() * tensor.numel() if tensor else 0, key)
for key, tensor in zip(keys, tensors or [None] * len(keys))
]
return self._register_memory(tuples, "OBJ", "objects")
class NixlFileManager: class NixlFileManager:
"""Handles file system operations for NIXL.""" """Handles file system operations for NIXL."""
...@@ -221,12 +190,9 @@ class NixlFileManager: ...@@ -221,12 +190,9 @@ class NixlFileManager:
return False return False
def files_to_nixl_tuples( def files_to_nixl_tuples(
self, file_paths: List[str], open_file: bool = True self, file_paths: List[str]
) -> List[Tuple[int, int, int, str]]: ) -> List[Tuple[int, int, int, str]]:
"""Create NIXL tuples (offset, length, fd, file_path) for given files.""" """Create NIXL tuples (offset, length, fd, file_path) for given files."""
if not open_file:
return [(0, 0, 0, path) for path in file_paths]
tuples = [] tuples = []
for path in file_paths: for path in file_paths:
if (fd := self.open_file(path)) is None: if (fd := self.open_file(path)) is None:
......
...@@ -7,8 +7,11 @@ from unittest.mock import MagicMock ...@@ -7,8 +7,11 @@ from unittest.mock import MagicMock
import torch import torch
from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
NixlFileManager,
NixlRegistration,
)
class TestNixlUnified(unittest.TestCase): class TestNixlUnified(unittest.TestCase):
...@@ -88,8 +91,27 @@ class TestNixlUnified(unittest.TestCase): ...@@ -88,8 +91,27 @@ class TestNixlUnified(unittest.TestCase):
# Test get # Test get
retrieved = self.hicache.get(key, dst_tensor) retrieved = self.hicache.get(key, dst_tensor)
self.verify_tensors_equal(value, dst_tensor)
self.verify_tensors_equal(value, retrieved) self.verify_tensors_equal(value, retrieved)
# Same test in addr,len mode with another key and dst_tensor
key2 = "test_key2"
dst_tensor2 = torch.zeros_like(value, device="cpu")
src_addr, src_len = value.data_ptr(), value.numel() * value.element_size()
dst_addr, dst_len = (
dst_tensor2.data_ptr(),
dst_tensor2.numel() * dst_tensor2.element_size(),
)
# Test set
self.assertTrue(self.hicache.set(key, None, src_addr, src_len))
self.assertTrue(self.hicache.exists(key))
# Test get
retrieved2 = self.hicache.get(key, dst_addr, dst_len)
self.assertTrue(retrieved2 == None)
self.verify_tensors_equal(value, dst_tensor2)
def test_batch_set_get(self): def test_batch_set_get(self):
"""Test batch tensor set/get operations.""" """Test batch tensor set/get operations."""
keys = ["key1", "key2", "key3"] keys = ["key1", "key2", "key3"]
...@@ -108,6 +130,23 @@ class TestNixlUnified(unittest.TestCase): ...@@ -108,6 +130,23 @@ class TestNixlUnified(unittest.TestCase):
retrieved = self.hicache.batch_get(keys, dst_tensors) retrieved = self.hicache.batch_get(keys, dst_tensors)
self.verify_tensor_lists_equal(values, retrieved) self.verify_tensor_lists_equal(values, retrieved)
# Same test in addr,len mode with another key and dst_tensor
keys2 = ["key4", "key5", "key6"]
dst_tensors2 = [torch.zeros_like(v, device="cpu") for v in values]
src_addrs = [v.data_ptr() for v in values]
src_lens = [v.numel() * v.element_size() for v in values]
dst_addrs = [dt.data_ptr() for dt in dst_tensors2]
dst_lens = [dt.numel() * dt.element_size() for dt in dst_tensors2]
# Test batch set
self.assertTrue(self.hicache.batch_set(keys2, None, src_addrs, src_lens))
self.assertTrue(all(self.hicache.exists(key) for key in keys2))
# Test batch get
retrieved2 = self.hicache.batch_get(keys, dst_addrs, dst_lens)
self.assertTrue(all(ret == None for ret in retrieved2))
self.verify_tensor_lists_equal(values, dst_tensors2)
def test_mixed_operations(self): def test_mixed_operations(self):
"""Test mixing single and batch operations.""" """Test mixing single and batch operations."""
# Test interleaved set/get operations # Test interleaved set/get operations
...@@ -170,7 +209,7 @@ class TestNixlUnified(unittest.TestCase): ...@@ -170,7 +209,7 @@ class TestNixlUnified(unittest.TestCase):
self.file_manager.create_file(test_file) self.file_manager.create_file(test_file)
# Test tuple creation # Test tuple creation
tuples = self.file_manager.files_to_nixl_tuples([test_file], False) tuples = self.file_manager.files_to_nixl_tuples([test_file])
self.assertIsNotNone(tuples) self.assertIsNotNone(tuples)
self.assertTrue(len(tuples) > 0) self.assertTrue(len(tuples) > 0)
...@@ -190,11 +229,11 @@ class TestNixlUnified(unittest.TestCase): ...@@ -190,11 +229,11 @@ class TestNixlUnified(unittest.TestCase):
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
# Test buffer registration # Test buffer registration
self.assertIsNotNone(self.registration.register_buffers(tensor)) self.assertIsNotNone(self.hicache.register_buffers(tensor))
# Test batch registration # Test batch registration
tensors = [torch.randn(5, 5) for _ in range(3)] tensors = [torch.randn(5, 5) for _ in range(3)]
self.assertIsNotNone(self.registration.register_buffers(tensors)) self.assertIsNotNone(self.hicache.register_buffers(tensors))
def test_register_files_with_tuples(self): def test_register_files_with_tuples(self):
"""Test registration of files using NIXL tuples.""" """Test registration of files using NIXL tuples."""
...@@ -203,8 +242,8 @@ class TestNixlUnified(unittest.TestCase): ...@@ -203,8 +242,8 @@ class TestNixlUnified(unittest.TestCase):
self.file_manager.create_file(file) self.file_manager.create_file(file)
# Create tuples and register # Create tuples and register
tuples = self.file_manager.files_to_nixl_tuples(files, False) tuples = self.file_manager.files_to_nixl_tuples(files)
self.registration.register_files(tuples) self.hicache.register_files(tuples)
# Verify tuples # Verify tuples
self.assertEqual(len(tuples), len(files)) self.assertEqual(len(tuples), len(files))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment