Unverified Commit 7e880286 authored by Moein Khazraee, committed by GitHub

Add support for interface extensions and pre-registration to NIXL HiCache (#9211)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 446c8e4c
@@ -36,6 +36,21 @@ Consolidated utility classes:
- **NixlRegistration** - Manages memory registration for tensors, files and objects
- **NixlFileManager** - Handles file system operations and NIXL tuple creation
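These utilities can be exercised together on the file-backed path. A minimal sketch, assuming the NIXL plugin is available; the base directory and key are placeholders:

```python
from sglang.srt.mem_cache.storage.nixl.nixl_utils import NixlFileManager

# Hypothetical base directory and key, for illustration only.
fm = NixlFileManager("/tmp/hicache_storage")
path = fm.get_file_path("example_key")
if fm.create_file(path):
    # (offset, length, fd, file_path) tuples, as consumed by NIXL registration
    tuples = fm.files_to_nixl_tuples([path])
```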
## Using NIXL for the HiCache backend
When launching the SGLang server, pass `nixl` as the `--hicache-storage-backend` parameter, for instance:
```bash
python3 -m sglang.launch_server --model-path <model> --host <ip> --port <port> --page-size 64 --enable-hierarchical-cache --hicache-ratio 2 --hicache-size 64 --hicache-write-policy write_through --hicache-storage-backend nixl
```
To customize the base directory for files, you can set the following environment variable:
```bash
export SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR=/path/to/desired/dir
```
Selecting any storage backend, such as 3FS, requires that the corresponding library be available on the system; among the available plugins, the backend is chosen according to the priority order mentioned above.
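Under the hood, the backend resolves the base directory with an environment-variable override of the constructor default, roughly as follows (a sketch mirroring the constructor logic in the diff below):

```python
import os

# The environment variable, when set, takes precedence over the default path.
file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", "/tmp/hicache_storage")
```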
## Running Unit Tests
### Prerequisites
@@ -43,33 +58,26 @@ Consolidated utility classes:
- PyTorch installed
- Python 3.8+
### Unit tests from Project root
Navigate to the project root directory (`/path/to/sglang`) and run:
### Unit tests from current directory
From the current directory, run:
#### Run all NIXL tests:
```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -o asyncio_mode=strict
PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
```
#### Run with verbose output:
```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -o asyncio_mode=strict
PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -v -o asyncio_mode=strict
```
Note: The `-v` flag provides more detailed output, showing each test case name and its result.
#### Run a specific test:
```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
PYTHONPATH=. python -m pytest test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
```
### From Tests Directory
Navigate to the tests directory and run:
```bash
cd test/srt
PYTHONPATH=../.. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
```
Note: The `-o asyncio_mode=strict` flag is added to suppress warnings about asyncio configuration. This is not required for test functionality but provides cleaner output.
## Test Coverage
......
@@ -3,7 +3,7 @@ import logging
import os
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@@ -28,6 +28,8 @@ class HiCacheNixl(HiCacheStorage):
def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
"""Initialize NIXL storage connector."""
# Might be better to be unified across HiCache backends and moved to HiCacheController
file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
self.file_manager = (
NixlFileManager(file_path)
if plugin not in NixlBackendSelection.OBJ_PLUGINS
@@ -44,59 +46,109 @@ class HiCacheNixl(HiCacheStorage):
self.registration = NixlRegistration(self.agent)
def register_buffers(
self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
) -> Optional[Any]:
"""Register tensor(s) or target locations in host memory (list of addr,len tuples) with NIXL."""
if isinstance(buffers[0], tuple):
tuples = [(x[0], x[1], 0, "") for x in buffers]
return self.registration._register_memory(tuples, "DRAM")
else:
return self.registration._register_memory(buffers)
def register_files(self, file_paths: List[str]) -> Optional[Any]:
"""Register files with NIXL."""
tuples = self.file_manager.files_to_nixl_tuples(file_paths)
return self.registration._register_memory(tuples, "FILE")
def register_objects(
self, keys: List[str], sizes: Optional[List[int]] = None
) -> Optional[Any]:
"""Register objects with NIXL."""
if not keys:
return None
tuples = [(0, 0, key, "") for key in keys]
return self.registration._register_memory(tuples, "OBJ")
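Together, these helpers let a caller pre-register host buffers, files, and object keys once, rather than per transfer. A usage sketch, assuming the NIXL plugin is available and the listed file already exists; names and paths are illustrative:

```python
import torch

from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

hicache = HiCacheNixl("/tmp/hicache_storage")

# Pre-register a host tensor (a list of tensors also works).
t = torch.randn(4, 4)
hicache.register_buffers(t)

# Pre-register raw host memory given as (address, length) tuples.
hicache.register_buffers([(t.data_ptr(), t.numel() * t.element_size())])

# Pre-register backing files (must exist) and object keys by name.
hicache.register_files(["/tmp/hicache_storage/example.bin"])
hicache.register_objects(["key_a", "key_b"])
```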
def _execute_transfer(
self, tensors: List[torch.Tensor], keys: List[str], direction: str
self,
buffers: Optional[List[torch.Tensor | tuple]],
keys: List[str],
direction: str,
) -> bool:
if len(tensors) != len(keys):
logger.error("Mismatch between number of tensors and files/objects")
if len(buffers) != len(keys):
logger.error("Mismatch between number of tensors/buffers and files/objects")
return False
if not self.registration.register_buffers(tensors):
logger.error("Failed to register tensors")
return False
# Get transfer tuples based on backend type
tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
# Registering file and object keys per transfer, to be updated when
# pre-registration for files and objects is added to HiCache.
if self.backend_selector.mem_type == "FILE":
file_tuples = self.file_manager.files_to_nixl_tuples(keys)
if not file_tuples or not self.registration.register_files(file_tuples):
tuples = self.file_manager.files_to_nixl_tuples(keys)
if not tuples or not self.registration._register_memory(tuples, "FILE"):
logger.error("Failed to prepare files for transfer")
return False
transfer_tuples = [
(x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
]
else:
if not self.registration.register_objects(keys, tensors):
else: # mem_type == "OBJ"
tuples = [(0, 0, key, "") for key in keys]
if not tuples or not self.registration._register_memory(tuples, "OBJ"):
logger.error("Failed to register objects")
return False
transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
# Prepare transfer descriptors
if isinstance(buffers[0], torch.Tensor):
tensor_sizes = [
tensor.element_size() * tensor.numel() for tensor in buffers
]
storage_tuples = [(x[0], s, x[2]) for x, s in zip(tuples, tensor_sizes)]
host_descs = self.agent.get_xfer_descs(buffers)
elif isinstance(buffers[0], tuple):
storage_tuples = [(x[0], y[1], x[2]) for x, y in zip(tuples, buffers)]
host_descs = self.agent.get_xfer_descs(
[(x[0], x[1], 0) for x in buffers], "DRAM"
)
else:
logger.error("Unsupported buffer type in transfer request")
return False
storage_descs = self.agent.get_xfer_descs(
storage_tuples, self.backend_selector.mem_type
)
if (host_descs is None) or (storage_descs is None):
logger.error("Failed to get transfer descriptors")
return False
# Initialize transfer; assumes by default that buffers were pre-registered
try:
# Get transfer descriptors
if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
file_descs := self.agent.get_xfer_descs(
transfer_tuples, self.backend_selector.mem_type
)
) is None:
logger.error("Failed to get transfer descriptors")
xfer_req = self.agent.initialize_xfer(
direction, host_descs, storage_descs, self.agent_name
)
except Exception:
# Check if it was due to missing pre-registration
if not self.register_buffers(buffers):
logger.error("Failed to register tensors/buffers")
return False
# Initialize and execute transfer
if (
xfer_req := self.agent.initialize_xfer(
direction, tensor_descs, file_descs, self.agent_name
try:
xfer_req = self.agent.initialize_xfer(
direction, host_descs, storage_descs, self.agent_name
)
) is None:
logger.error("Failed to create transfer request")
except Exception as e:
logger.error(f"Failed to create transfer request: {e}")
return False
# Execute transfer and wait for its completion
try:
state = self.agent.transfer(xfer_req)
while state != "DONE":
state = self.agent.check_xfer_state(xfer_req)
if state == "ERR":
self.agent.release_xfer_handle(xfer_req)
logger.error("Transfer failed")
return False
time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
self.agent.release_xfer_handle(xfer_req)
return True
except Exception as e:
@@ -106,45 +158,87 @@ class HiCacheNixl(HiCacheStorage):
logger.error(f"Traceback: {traceback.format_exc()}")
return False
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
if not keys:
return True
if self.backend_selector.mem_type == "FILE":
file_paths = []
for key in keys:
tensor_path = self.file_manager.get_file_path(key)
if not self.file_manager.create_file(tensor_path):
logger.error(f"Failed to create file {tensor_path}")
return False
file_paths.append(tensor_path)
return self._execute_transfer(values, file_paths, "WRITE")
else:
return self._execute_transfer(values, keys, "WRITE")
def set(self, key: str, value: torch.Tensor) -> bool:
return self.batch_set([key], [value])
def get(
self, key: str, dst_tensor: Optional[torch.Tensor] = None
self,
key: str,
target_location: Optional[torch.Tensor | int] = None,
target_sizes: Optional[int] = None,
) -> torch.Tensor | None:
if dst_tensor is None: # To be removed, being compatible with the current API
# To be removed; kept for compatibility with the current API
if target_location is None:
return None
result = self.batch_get([key], [dst_tensor])
if target_sizes:
result = self.batch_get([key], [target_location], [target_sizes])
else:
result = self.batch_get([key], [target_location])
return result[0] if result else None
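The extended `get` keeps the original tensor-destination form and adds an address/length form; in the latter case the data is written in place at the given address and the call returns `None`. A sketch mirroring the unit tests further below; `hicache` and `dst_tensor` are assumed to exist:

```python
# Destination given as a tensor: returns the filled tensor on success.
retrieved = hicache.get("some_key", dst_tensor)

# Destination given as (address, size): fills memory in place, returns None.
hicache.get(
    "some_key",
    dst_tensor.data_ptr(),
    dst_tensor.numel() * dst_tensor.element_size(),
)
```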
def batch_get(
self, keys: List[str], dst_tensors: List[torch.Tensor]
) -> List[Optional[torch.Tensor]]:
self,
keys: List[str],
target_locations: Optional[List[torch.Tensor | int]] = None,
target_sizes: Optional[List[int]] = None,
) -> List[torch.Tensor | None]:
if not keys:
return []
# To be removed; kept for compatibility with the current API
if not target_locations:
return [None] * len(keys)
if target_sizes and (len(target_sizes) != len(target_locations)):
logger.error("Mismatch between number of target_locations and target_sizes")
return [None] * len(keys)
if target_sizes:
dest = list(zip(target_locations, target_sizes))
else:
dest = target_locations
if self.backend_selector.mem_type == "FILE":
file_paths = [self.file_manager.get_file_path(key) for key in keys]
success = self._execute_transfer(dst_tensors, file_paths, "READ")
success = self._execute_transfer(dest, file_paths, "READ")
else:
success = self._execute_transfer(dst_tensors, keys, "READ")
return dst_tensors if success else [None] * len(keys)
success = self._execute_transfer(dest, keys, "READ")
return target_locations if success and not target_sizes else [None] * len(keys)
def set(
self,
key: str,
value: Optional[torch.Tensor] = None,
target_location: Optional[int] = None,
target_sizes: Optional[int] = None,
) -> bool:
if target_location and target_sizes:
return self.batch_set([key], None, [target_location], [target_sizes])
else:
return self.batch_set([key], [value])
def batch_set(
self,
keys: List[str],
values: Optional[List[torch.Tensor]] = None,
target_locations: Optional[List[int]] = None,
target_sizes: Optional[List[int]] = None,
) -> bool:
if not keys or (not values and (not target_locations or not target_sizes)):
logger.error("Keys or values were not passed")
return False
if not values:
values = list(zip(target_locations, target_sizes))
if self.backend_selector.mem_type == "FILE":
file_paths = []
for key in keys:
file_path = self.file_manager.get_file_path(key)
# New file per set, to be updated when partial writes are added to HiCache
if not self.file_manager.create_file(file_path):
logger.error(f"Failed to create file {file_path}")
return False
file_paths.append(file_path)
return self._execute_transfer(values, file_paths, "WRITE")
else: # mem_type == "OBJ"
return self._execute_transfer(values, keys, "WRITE")
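Correspondingly, `set` and `batch_set` accept either tensors or parallel lists of source addresses and sizes. A sketch mirroring the unit tests below; `hicache`, `value`, and `values` are assumed to be a constructed backend and host tensors:

```python
# Tensor form: write the tensor's contents under the key.
hicache.set("key1", value)

# Address/length form: write the given number of bytes from the given address.
hicache.set("key2", None, value.data_ptr(), value.numel() * value.element_size())

# Batch variant with parallel lists of addresses and lengths.
addrs = [v.data_ptr() for v in values]
lens = [v.numel() * v.element_size() for v in values]
hicache.batch_set(["key3", "key4", "key5"], None, addrs, lens)
```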
def exists(self, key: str) -> bool:
tuples = self.registration.create_query_tuples(
......
@@ -109,66 +109,35 @@ class NixlRegistration:
return [(0, 0, key)]
def _register_memory(
self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
self,
items: Union[List[tuple], torch.Tensor, List[torch.Tensor]],
mem_type: Optional[str] = None,
) -> Optional[Any]:
"""Common registration logic for files, objects, and buffers.
Args:
items: List of tuples or tensors to register
mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
desc: Description for logging
mem_type: Memory type ("FILE", "OBJ") or None for tensor or list of tensors
"""
try:
if not items:
return None
reg_descs = self.agent.get_reg_descs(items, mem_type)
if reg_descs is None:
logger.error("Failed to create registration descriptors")
return None
registered_memory = self.agent.register_memory(reg_descs)
if registered_memory:
return registered_memory
else:
logger.error("Failed to register with NIXL")
return None
except Exception as e:
logger.error(f"Failed to register {desc}: {e}")
if isinstance(items, list) and not items:
return None
def register_buffers(
self, buffers: Union[torch.Tensor, List[torch.Tensor]]
) -> Optional[Any]:
"""Register tensors/buffers with NIXL."""
if isinstance(buffers, torch.Tensor):
buffers = [buffers]
if not buffers:
reg_descs = self.agent.get_reg_descs(items, mem_type)
if reg_descs is None:
logger.error("Failed to create registration descriptors")
return None
# Determine memory type based on tensor device
mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
return self._register_memory(buffers, mem_type, "buffers")
def register_files(self, tuples: List[tuple]) -> Optional[Any]:
"""Register files with NIXL using (0, 0, fd, file_path) tuples."""
return self._register_memory(tuples, "FILE", "files")
def register_objects(
self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
) -> Optional[Any]:
"""Register objects with NIXL."""
if not keys:
try:
registered_memory = self.agent.register_memory(reg_descs)
return registered_memory # Could be None in case of error
except Exception as e:
if not mem_type:
logger.error(f"Failed to register Tensors with NIXL: {e}")
else:
logger.error(
f"Failed to register memory of type {mem_type} with NIXL: {e}"
)
return None
# Create object tuples with proper sizes
tuples = [
(0, tensor.element_size() * tensor.numel() if tensor else 0, key)
for key, tensor in zip(keys, tensors or [None] * len(keys))
]
return self._register_memory(tuples, "OBJ", "objects")
class NixlFileManager:
"""Handles file system operations for NIXL."""
@@ -221,12 +190,9 @@ class NixlFileManager:
return False
def files_to_nixl_tuples(
self, file_paths: List[str], open_file: bool = True
self, file_paths: List[str]
) -> List[Tuple[int, int, int, str]]:
"""Create NIXL tuples (offset, length, fd, file_path) for given files."""
if not open_file:
return [(0, 0, 0, path) for path in file_paths]
tuples = []
for path in file_paths:
if (fd := self.open_file(path)) is None:
......
@@ -7,8 +7,11 @@ from unittest.mock import MagicMock
import torch
from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
NixlFileManager,
NixlRegistration,
)
class TestNixlUnified(unittest.TestCase):
@@ -88,8 +91,27 @@ class TestNixlUnified(unittest.TestCase):
# Test get
retrieved = self.hicache.get(key, dst_tensor)
self.verify_tensors_equal(value, dst_tensor)
self.verify_tensors_equal(value, retrieved)
# Same test in (addr, len) mode with another key and destination tensor
key2 = "test_key2"
dst_tensor2 = torch.zeros_like(value, device="cpu")
src_addr, src_len = value.data_ptr(), value.numel() * value.element_size()
dst_addr, dst_len = (
dst_tensor2.data_ptr(),
dst_tensor2.numel() * dst_tensor2.element_size(),
)
# Test set
self.assertTrue(self.hicache.set(key2, None, src_addr, src_len))
self.assertTrue(self.hicache.exists(key2))
# Test get
retrieved2 = self.hicache.get(key2, dst_addr, dst_len)
self.assertIsNone(retrieved2)
self.verify_tensors_equal(value, dst_tensor2)
def test_batch_set_get(self):
"""Test batch tensor set/get operations."""
keys = ["key1", "key2", "key3"]
@@ -108,6 +130,23 @@ class TestNixlUnified(unittest.TestCase):
retrieved = self.hicache.batch_get(keys, dst_tensors)
self.verify_tensor_lists_equal(values, retrieved)
# Same test in (addr, len) mode with new keys and destination tensors
keys2 = ["key4", "key5", "key6"]
dst_tensors2 = [torch.zeros_like(v, device="cpu") for v in values]
src_addrs = [v.data_ptr() for v in values]
src_lens = [v.numel() * v.element_size() for v in values]
dst_addrs = [dt.data_ptr() for dt in dst_tensors2]
dst_lens = [dt.numel() * dt.element_size() for dt in dst_tensors2]
# Test batch set
self.assertTrue(self.hicache.batch_set(keys2, None, src_addrs, src_lens))
self.assertTrue(all(self.hicache.exists(key) for key in keys2))
# Test batch get
retrieved2 = self.hicache.batch_get(keys2, dst_addrs, dst_lens)
self.assertTrue(all(ret is None for ret in retrieved2))
self.verify_tensor_lists_equal(values, dst_tensors2)
def test_mixed_operations(self):
"""Test mixing single and batch operations."""
# Test interleaved set/get operations
@@ -170,7 +209,7 @@ class TestNixlUnified(unittest.TestCase):
self.file_manager.create_file(test_file)
# Test tuple creation
tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
tuples = self.file_manager.files_to_nixl_tuples([test_file])
self.assertIsNotNone(tuples)
self.assertTrue(len(tuples) > 0)
@@ -190,11 +229,11 @@ class TestNixlUnified(unittest.TestCase):
tensor = torch.randn(10, 10)
# Test buffer registration
self.assertIsNotNone(self.registration.register_buffers(tensor))
self.assertIsNotNone(self.hicache.register_buffers(tensor))
# Test batch registration
tensors = [torch.randn(5, 5) for _ in range(3)]
self.assertIsNotNone(self.registration.register_buffers(tensors))
self.assertIsNotNone(self.hicache.register_buffers(tensors))
def test_register_files_with_tuples(self):
"""Test registration of files using NIXL tuples."""
@@ -203,8 +242,8 @@ class TestNixlUnified(unittest.TestCase):
self.file_manager.create_file(file)
# Create tuples and register
tuples = self.file_manager.files_to_nixl_tuples(files, False)
self.registration.register_files(tuples)
tuples = self.file_manager.files_to_nixl_tuples(files)
self.hicache.register_files(files)
# Verify tuples
self.assertEqual(len(tuples), len(files))
......