".github/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "08bf7545070764d41637e90290ca9b92b392263e"
Unverified Commit 2cd2e27f authored by Vishwanath Venkatesan's avatar Vishwanath Venkatesan Committed by GitHub
Browse files

SGLang HiCache NIXL Connector (#8488)


Signed-off-by: default avatarVishwanath Venkatesan <vvenkatesan@nvidia.com>
Co-authored-by: default avatarMoein Khazraee <moein@nvidia.com>
Co-authored-by: default avatarZhiqiang Xie <xiezhq@stanford.edu>
parent 743638bc
......@@ -265,6 +265,11 @@ class HiCacheController:
if storage_backend == "file":
self.storage_backend = HiCacheFile()
self.get_hash_str = get_hash_str
elif storage_backend == "nixl":
from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
self.storage_backend = HiCacheNixl()
self.get_hash_str = get_hash_str
elif storage_backend == "mooncake":
self.storage_backend = MooncakeStore()
self.get_hash_str = get_hash_str_mooncake
......@@ -545,7 +550,11 @@ class HiCacheController:
def generic_page_transfer(self, operation, batch_size=8):
for i in range(0, len(operation.hash_value), batch_size):
page_hashes = operation.hash_value[i : i + batch_size]
page_data = self.storage_backend.batch_get(page_hashes)
# todo: zero copy
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
page_hashes
)
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
if page_data is None:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
......@@ -679,7 +688,7 @@ class HiCacheController:
for i in range(0, len(operation.hash_value), batch_size):
page_hashes = operation.hash_value[i : i + batch_size]
page_data = [
self.mem_pool_host.get_flat_data_pages(
self.mem_pool_host.get_flat_data_page(
operation.host_indices[j * self.page_size]
)
for j in range(i, i + len(page_hashes))
......
......@@ -123,13 +123,22 @@ class HiCacheFile(HiCacheStorage):
key = self._get_suffixed_key(key)
tensor_path = os.path.join(self.file_path, f"{key}.bin")
try:
# todo: fixing the target_location logic to enable in-place loading
loaded_tensor = torch.load(tensor_path)
if isinstance(loaded_tensor, torch.Tensor):
return loaded_tensor
if target_location is not None:
# Load directly into target_location's memory buffer
with open(tensor_path, "rb") as f:
target_location.set_(
torch.frombuffer(f.read(), dtype=target_location.dtype)
.reshape(target_location.shape)
.storage()
)
return target_location
else:
logger.error(f"Loaded data for key {key} is not a tensor.")
return None
loaded_tensor = torch.load(tensor_path)
if isinstance(loaded_tensor, torch.Tensor):
return loaded_tensor
else:
logger.error(f"Loaded data for key {key} is not a tensor.")
return None
except FileNotFoundError:
return None
......
......@@ -105,6 +105,14 @@ class HostKVCache(abc.ABC):
"""
raise NotImplementedError()
@abc.abstractmethod
def get_dummy_flat_data_page(self) -> torch.Tensor:
"""
Get a dummy flat data page from the host memory pool.
This is used for prefetching or initializing empty pages.
"""
raise NotImplementedError()
@abc.abstractmethod
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
"""
......@@ -256,6 +264,14 @@ class MHATokenToKVPoolHost(HostKVCache):
def get_flat_data_page(self, index) -> torch.Tensor:
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
def get_dummy_flat_data_page(self) -> torch.Tensor:
return torch.zeros(
(2, self.layer_num, self.page_size, self.head_num, self.head_dim),
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
).flatten()
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
2,
......@@ -355,6 +371,19 @@ class MLATokenToKVPoolHost(HostKVCache):
def get_flat_data_page(self, index) -> torch.Tensor:
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
def get_dummy_flat_data_page(self) -> torch.Tensor:
return torch.zeros(
(
self.layer_num,
self.page_size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
),
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
).flatten()
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
self.layer_num,
......
# NIXL Integration for HiCache
This directory contains the **NIXL (NVIDIA Inference Xfer Library)** integration for **HiCache**, enabling high-performance storage across multiple backends.
NIXL provides a unified API for accessing various storage plugins, including but not limited to:
- **Deepseek's 3FS APIs** for high-throughput file operations
- **GPU Direct Storage (GDS)** for direct data movement between storage and GPU memory, bypassing CPU memory copies
- **Amazon S3-compatible object storage** for key-value access patterns
Additional backend integrations are planned for future releases.
## NIXL Resources
- **Project Repository**: [NIXL on GitHub](https://github.com/ai-dynamo/nixl)
- **Documentation**: [NIXL Documentation](https://github.com/ai-dynamo/nixl/tree/main/docs)
## Overview
The NIXL integration consists of two main files:
- **`hicache_nixl.py`** - Main HiCache storage connector using NIXL
- **`nixl_utils.py`** - Utility classes for backend selection, registration, and file management
## Components
### HiCacheNixl
The main storage connector that provides:
- Single and batch tensor set/get operations
- Automatic backend selection (3FS > POSIX > GDS_MT > GDS > OBJ)
- High-performance file-based or object-based storage access using NIXL
### NixlUtils
Consolidated utility classes:
- **NixlBackendSelection** - Handles backend selection and creation
- **NixlRegistration** - Manages memory registration for tensors, files and objects
- **NixlFileManager** - Handles file system operations and NIXL tuple creation
## Running Unit Tests
### Prerequisites
- NIXL library installed and available (latest main required for supporting object query)
- PyTorch installed
- Python 3.8+
### Unit tests from Project root
Navigate to the project root directory (`/path/to/sglang`) and run:
#### Run all NIXL tests:
```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -o asyncio_mode=strict
```
#### Run with verbose output:
```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -o asyncio_mode=strict
```
Note: The `-v` flag provides more detailed output, showing each test case name and its result.
#### Run a specific test:
```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
```
### From Tests Directory
Navigate to the tests directory and run:
```bash
cd test/srt
PYTHONPATH=../.. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
```
Note: The `-o asyncio_mode=strict` flag is added to suppress warnings about asyncio configuration. This is not required for test functionality but provides cleaner output.
## Test Coverage
A test suite for this integration can be found at `test_hicache_nixl_storage.py`, which covers:
### HiCache Integration Tests (4 tests)
- Single tensor set/get operations
- Batch tensor set/get operations
- Mixed single and batch operations
- Data integrity for various tensor types
### File Management Tests (5 tests)
- Basic file operations
- NIXL tuple creation
- Error handling in file operations
### Registration Tests (2 tests)
- Tensor registration with memory type detection
- File registration using NIXL tuples
## Expected Output
When tests run successfully, you should see:
- NIXL agent initialization messages
- Backend selection messages (e.g., "Backend POSIX was instantiated")
- Test results with "ok" for passed tests
- Summary showing "Ran X tests in Y seconds" and "OK"
## Troubleshooting
### Import Errors
If you encounter `ModuleNotFoundError`, ensure:
- You're running from the correct directory
- `PYTHONPATH` is set correctly
- NIXL library is properly installed
### NIXL Errors
If NIXL operations fail:
- Check that NIXL is properly installed
- Verify that required plugins are available
- Ensure file permissions are correct for test directories
## File Structure
```
python/sglang/srt/mem_cache/nixl/
├── hicache_nixl.py # Main HiCache storage connector
├── nixl_utils.py # All NIXL utility classes
├── README.md # This file
└── tests/
└── test_nixl_unified.py # All tests in one file
```
## Dependencies
- **NIXL**: NVIDIA Inference Xfer Library (version 0.4 or later)
- Required plugins: POSIX (minimum), 3FS/GDS (optional for better performance)
- See [NIXL Installation Guide](https://github.com/ai-dynamo/nixl/blob/main/README.md#installation)
- **PyTorch**: For tensor operations (version 1.8 or later)
- **Python 3.8+**: For type hints and modern features
## Supported Features
### Memory Types
- **Tensor side**: multi-dimensional tensors of all numeric types (int32, int64, float32, float64) are supported.
- Tensors can be on CPU or GPU (as long as a GPU capable backend such as GDS_MT is available).
- Currently each tensor is mapped to a file or key, but it can be extended to support multiple keys per file or key.
- **Storage side**: file and object are supported through their relevant backends (e.g., 3FS or OBJ).
### Backend Priority
The NIXL backend selection follows this priority order:
1. **3FS** - Highest performance (if available)
- Best for high-throughput file operations using Deepseek 3FS APIs
2. **POSIX** - Standard file I/O (fallback)
- Universal compatibility
- Good for development and testing - Leverages both libaio/liburing
3. **GDS_MT** - Multi-threaded GDS (if available)
- Optimized for concurrent operations
- Supports GPU Direct Storage with multiple lightweight threads
4. **GDS** - GPU Direct Storage (if available)
- Direct GPU-storage data path
- Best for filesystems benefiting from batch operations and smaller IOs.
5. **OBJ** - Amazon S3 based Object Storage
- Key-value based storage
The system automatically selects the best available backend, with POSIX as the default fallback.
## Note
This is v0 of the NIXL connector. Future versions will focus on further performance optimizations such as memory pre-registration (pre-allocating and registering memory buffers to reduce registration overhead during transfers) and block merging (combining related blocks as offsets within the same file to reduce file operations and improve throughput). These optimizations require changes at a higher layer, as the current HiCache API doesn't expose information like block relationships or hash patterns that would enable these optimizations.
import hashlib
import logging
import os
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union
import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
try:
from nixl._api import nixl_agent, nixl_agent_config
except ImportError as e:
raise ImportError(
"Please install NIXL by following the instructions at "
"https://github.com/ai-dynamo/nixl/blob/main/README.md "
"to use HiCacheNixl storage backend."
) from e
logger = logging.getLogger(__name__)
class HiCacheNixl(HiCacheStorage):
    """HiCacheNixl provides high-performance storage using NIXL plugins.

    Tensors are transferred directly between host/device memory and the
    selected NIXL backend (file- or object-based) without intermediate
    serialization.
    """

    def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
        """Initialize NIXL storage connector.

        Args:
            file_path: Base directory used by file-based plugins. Ignored for
                object-storage plugins.
            plugin: NIXL plugin to use ("auto" selects the best available).
        """
        # Object plugins address data by key, so no file manager is needed.
        self.file_manager = (
            NixlFileManager(file_path)
            if plugin not in NixlBackendSelection.OBJ_PLUGINS
            else None
        )

        agent_config = nixl_agent_config(backends=[])
        # Unique agent name so multiple connectors can coexist in one process.
        self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
        self.agent = nixl_agent(self.agent_name, agent_config)

        self.backend_selector = NixlBackendSelection(plugin)
        if not self.backend_selector.create_backend(self.agent):
            raise RuntimeError("Failed to create NIXL backend")

        self.registration = NixlRegistration(self.agent)

    def _execute_transfer(
        self, tensors: List[torch.Tensor], keys: List[str], direction: str
    ) -> bool:
        """Move data between `tensors` and the storage identified by `keys`.

        Args:
            tensors: Tensors acting as transfer source (WRITE) or target (READ).
            keys: One file path or object key per tensor.
            direction: "READ" (storage -> tensor) or "WRITE" (tensor -> storage).

        Returns:
            True if the whole transfer completed, False on any failure.
        """
        if len(tensors) != len(keys):
            logger.error("Mismatch between number of tensors and files/objects")
            return False
        if not self.registration.register_buffers(tensors):
            logger.error("Failed to register tensors")
            return False

        # Get transfer tuples based on backend type
        tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
        if self.backend_selector.mem_type == "FILE":
            file_tuples = self.file_manager.files_to_nixl_tuples(keys)
            if not file_tuples or not self.registration.register_files(file_tuples):
                logger.error("Failed to prepare files for transfer")
                return False
            # (offset, length, fd): substitute each tensor's byte size for the
            # placeholder length produced by files_to_nixl_tuples.
            transfer_tuples = [
                (x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
            ]
        else:
            if not self.registration.register_objects(keys, tensors):
                logger.error("Failed to register objects")
                return False
            transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]

        try:
            # Get transfer descriptors
            if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
                file_descs := self.agent.get_xfer_descs(
                    transfer_tuples, self.backend_selector.mem_type
                )
            ) is None:
                logger.error("Failed to get transfer descriptors")
                return False

            # Initialize and execute transfer
            if (
                xfer_req := self.agent.initialize_xfer(
                    direction, tensor_descs, file_descs, self.agent_name
                )
            ) is None:
                logger.error("Failed to create transfer request")
                return False

            # Poll until the asynchronous transfer finishes or errors out.
            state = self.agent.transfer(xfer_req)
            while state != "DONE":
                state = self.agent.check_xfer_state(xfer_req)
                if state == "ERR":
                    logger.error("Transfer failed")
                    return False
                time.sleep(0.0001)  # Can be changed to os.sched_yield() or parametrized
            return True

        except Exception as e:
            logger.error(f"Failed to execute transfer: {e}")
            import traceback

            logger.error(f"Traceback: {traceback.format_exc()}")
            return False

    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
        """Store `values[i]` under `keys[i]`. Returns True on full success."""
        if not keys:
            return True
        if self.backend_selector.mem_type == "FILE":
            # Files must exist before NIXL can open and register them.
            file_paths = []
            for key in keys:
                tensor_path = self.file_manager.get_file_path(key)
                if not self.file_manager.create_file(tensor_path):
                    logger.error(f"Failed to create file {tensor_path}")
                    return False
                file_paths.append(tensor_path)
            return self._execute_transfer(values, file_paths, "WRITE")
        else:
            return self._execute_transfer(values, keys, "WRITE")

    def set(self, key: str, value: torch.Tensor) -> bool:
        """Store a single tensor under `key`."""
        return self.batch_set([key], [value])

    def get(
        self, key: str, dst_tensor: Optional[torch.Tensor] = None
    ) -> torch.Tensor | None:
        """Load the tensor stored under `key` into `dst_tensor` (in place).

        Returns `dst_tensor` on success, None on failure or if no destination
        tensor is supplied.
        """
        if dst_tensor is None:  # To be removed, being compatible with the current API
            return None
        result = self.batch_get([key], [dst_tensor])
        return result[0] if result else None

    def batch_get(
        self, keys: List[str], dst_tensors: List[torch.Tensor]
    ) -> List[Optional[torch.Tensor]]:
        """Load data for `keys` into pre-allocated `dst_tensors` (in place).

        Returns:
            `dst_tensors` if the transfer succeeded, else a list of None
            (one per key) — partial success is not reported.
        """
        if not keys:
            return []
        if self.backend_selector.mem_type == "FILE":
            file_paths = [self.file_manager.get_file_path(key) for key in keys]
            success = self._execute_transfer(dst_tensors, file_paths, "READ")
        else:
            success = self._execute_transfer(dst_tensors, keys, "READ")
        return dst_tensors if success else [None] * len(keys)

    def exists(self, key: str) -> bool:
        """Check whether `key` is present in the backing store via NIXL query."""
        tuples = self.registration.create_query_tuples(
            key,
            self.backend_selector.mem_type,
            self.file_manager if self.backend_selector.mem_type == "FILE" else None,
        )
        if not tuples:
            return False
        query_res = self.agent.query_memory(
            tuples,
            self.backend_selector.backend_name,
            mem_type=self.backend_selector.mem_type,
        )
        return query_res[0] is not None  # can be expanded to multiple keys
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
logger = logging.getLogger(__name__)
class NixlBackendSelection:
    """Handles NIXL backend selection and creation."""

    # Priority order for file-based plugins in case of auto selection
    FILE_PLUGINS = ["3FS", "POSIX", "GDS_MT", "GDS"]
    # Priority order for object-based plugins in case of auto selection (add more as needed)
    OBJ_PLUGINS = ["OBJ"]  # Based on Amazon S3 SDK

    def __init__(self, plugin: str = "auto"):
        """Initialize backend selection.

        Args:
            plugin: Plugin to use (default "auto" selects best available).
                Can be a file plugin (3FS, POSIX, GDS, GDS_MT) or
                an object plugin (OBJ).
        """
        self.plugin = plugin
        # Chosen plugin name; populated by create_backend().
        self.backend_name = None
        # "FILE" or "OBJ"; populated by create_backend().
        self.mem_type = None

    def set_bucket(self, bucket_name: str) -> None:
        """Set AWS bucket name in environment variable."""
        os.environ["AWS_DEFAULT_BUCKET"] = bucket_name
        logger.debug(f"Set AWS bucket name to: {bucket_name}")

    def create_backend(self, agent) -> bool:
        """Create the appropriate NIXL backend based on configuration.

        Args:
            agent: NIXL agent used to list plugins and create the backend.

        Returns:
            True on success, False if the plugin is unavailable or creation
            fails (errors are logged, not raised).
        """
        try:
            plugin_list = agent.get_plugin_list()
            logger.debug(f"Available NIXL plugins: {plugin_list}")

            # Handle explicit plugin selection or auto priority
            if self.plugin == "auto":
                # Try all file plugins first
                for plugin in self.FILE_PLUGINS:
                    if plugin in plugin_list:
                        self.backend_name = plugin
                        break
                # If no file plugin found, try object plugins
                if not self.backend_name:
                    for plugin in self.OBJ_PLUGINS:
                        if plugin in plugin_list:
                            self.backend_name = plugin
                            break
            else:
                # Use explicitly requested plugin
                self.backend_name = self.plugin

            # Covers both an unavailable explicit plugin and auto selection
            # finding nothing (backend_name still None).
            if self.backend_name not in plugin_list:
                logger.error(
                    f"Backend {self.backend_name} not available in plugins: {plugin_list}"
                )
                return False

            # Create backend and set memory type
            if self.backend_name in self.OBJ_PLUGINS:
                # Object storage requires a target bucket before creation.
                bucket = os.environ.get("AWS_DEFAULT_BUCKET")
                if not bucket:
                    logger.error(
                        "AWS_DEFAULT_BUCKET environment variable must be set for object storage"
                    )
                    return False
                agent.create_backend(self.backend_name, {"bucket": bucket})
            else:
                agent.create_backend(self.backend_name)

            self.mem_type = "OBJ" if self.backend_name in self.OBJ_PLUGINS else "FILE"
            logger.debug(
                f"Created NIXL backend: {self.backend_name} with memory type: {self.mem_type}"
            )
            return True

        except Exception as e:
            logger.error(f"Failed to create NIXL backend: {e}")
            return False
class NixlRegistration:
    """Handles NIXL memory registration for buffers, files, and objects."""

    def __init__(self, agent):
        # NIXL agent that performs the actual descriptor creation/registration.
        self.agent = agent

    def create_query_tuples(
        self, key: str, mem_type: str, file_manager=None
    ) -> List[Tuple]:
        """Create NIXL tuples for querying memory.

        Args:
            key: Key to query (file path for FILE or object key for OBJ)
            mem_type: Memory type ("FILE" or "OBJ")
            file_manager: Optional NixlFileManager for FILE memory type

        Returns:
            List of NIXL tuples for querying (empty list on bad arguments)
        """
        if mem_type == "FILE":
            if file_manager is None:
                logger.error("file_manager required for FILE memory type")
                return []
            return [(0, 0, 0, file_manager.get_file_path(key))]
        else:  # OBJ
            return [(0, 0, key)]

    def _register_memory(
        self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
    ) -> Optional[Any]:
        """Common registration logic for files, objects, and buffers.

        Args:
            items: List of tuples or tensors to register
            mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
            desc: Description for logging

        Returns:
            The registered-memory handle, or None on empty input or failure.
        """
        try:
            if not items:
                return None
            reg_descs = self.agent.get_reg_descs(items, mem_type)
            if reg_descs is None:
                logger.error("Failed to create registration descriptors")
                return None
            registered_memory = self.agent.register_memory(reg_descs)
            if registered_memory:
                return registered_memory
            else:
                logger.error("Failed to register with NIXL")
                return None
        except Exception as e:
            logger.error(f"Failed to register {desc}: {e}")
            return None

    def register_buffers(
        self, buffers: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Optional[Any]:
        """Register tensors/buffers with NIXL.

        Accepts a single tensor or a list; returns the registration handle
        or None on failure/empty input.
        """
        if isinstance(buffers, torch.Tensor):
            buffers = [buffers]
        if not buffers:
            return None
        # Determine memory type based on tensor device
        mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
        return self._register_memory(buffers, mem_type, "buffers")

    def register_files(self, tuples: List[tuple]) -> Optional[Any]:
        """Register files with NIXL using (0, 0, fd, file_path) tuples."""
        return self._register_memory(tuples, "FILE", "files")

    def register_objects(
        self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
    ) -> Optional[Any]:
        """Register objects with NIXL.

        Args:
            keys: Object keys to register.
            tensors: Optional tensors whose byte sizes are recorded per key;
                missing tensors register with size 0.
        """
        if not keys:
            return None
        # Create object tuples with proper sizes.
        # BUGFIX: compare against None explicitly — `if tensor` invokes
        # Tensor.__bool__, which raises RuntimeError for multi-element tensors.
        tuples = [
            (
                0,
                tensor.element_size() * tensor.numel() if tensor is not None else 0,
                key,
            )
            for key, tensor in zip(keys, tensors or [None] * len(keys))
        ]
        return self._register_memory(tuples, "OBJ", "objects")
class NixlFileManager:
    """Handles file system operations for NIXL."""

    def __init__(self, base_dir: str):
        """
        Initialize file manager.

        Args:
            base_dir: Base directory for storing tensor files. An empty string
                means keys are used as-is and no base directory is created.
        """
        self.base_dir = base_dir
        if base_dir == "":
            logger.debug("Initialized file manager without a base directory")
        else:
            os.makedirs(base_dir, exist_ok=True)
            logger.debug(f"Initialized file manager with base directory: {base_dir}")

    def get_file_path(self, key: str) -> str:
        """Get full file path for a given key."""
        return os.path.join(self.base_dir, key)

    def create_file(self, file_path: str) -> bool:
        """Create an empty file (and missing parent dirs) if it doesn't exist.

        Returns:
            True on success or if the file already exists, False on error.
        """
        try:
            # BUGFIX: a bare filename has no directory component;
            # os.makedirs("") raises, so only create parents when present.
            parent_dir = os.path.dirname(file_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            if not os.path.exists(file_path):
                with open(file_path, "wb") as f:
                    pass  # Create empty file
            return True
        except Exception as e:
            logger.error(f"Failed to create file {file_path}: {e}")
            return False

    def open_file(self, file_path: str) -> Optional[int]:
        """Open a file for read/write and return its file descriptor (None on error)."""
        try:
            fd = os.open(file_path, os.O_RDWR)
            return fd
        except Exception as e:
            logger.error(f"Failed to open file {file_path}: {e}")
            return None

    def close_file(self, fd: int) -> bool:
        """Close a file descriptor. Returns True on success."""
        try:
            os.close(fd)
            return True
        except Exception as e:
            logger.error(f"Failed to close file descriptor {fd}: {e}")
            return False

    def files_to_nixl_tuples(
        self, file_paths: List[str], open_file: bool = True
    ) -> List[Tuple[int, int, int, str]]:
        """Create NIXL tuples (offset, length, fd, file_path) for given files.

        With open_file=False the fd slot is 0 (query-only tuples). On any open
        failure, already-opened descriptors are closed and [] is returned.
        """
        if not open_file:
            return [(0, 0, 0, path) for path in file_paths]

        tuples = []
        for path in file_paths:
            if (fd := self.open_file(path)) is None:
                # Clean up on failure
                for t in tuples:
                    self.close_file(t[2])
                return []
            tuples.append((0, 0, fd, path))
        return tuples
#!/usr/bin/env python3
import os
import unittest
from typing import List, Optional
from unittest.mock import MagicMock
import torch
from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
class TestNixlUnified(unittest.TestCase):
    """Unified test suite for all NIXL components.

    Covers HiCache integration (set/get round-trips), NixlFileManager file
    operations, and NixlRegistration against a mocked NIXL agent.
    """

    def setUp(self):
        """Set up test environment."""
        # Create test directories
        self.test_dir = "/tmp/test_nixl_unified"
        os.makedirs(self.test_dir, exist_ok=True)
        # Mock NIXL agent for registration tests
        self.mock_agent = MagicMock()
        self.mock_agent.get_reg_descs.return_value = "mock_reg_descs"
        self.mock_agent.register_memory.return_value = "mock_registered_memory"
        # Create instances
        self.file_manager = NixlFileManager(self.test_dir)
        self.registration = NixlRegistration(self.mock_agent)
        try:
            self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX")
        except ImportError:
            # NOTE(review): HiCacheNixl is imported at module load time, which
            # already raises ImportError when NIXL is absent — confirm this
            # skip path is actually reachable.
            self.skipTest("NIXL not available, skipping NIXL storage tests")

    def tearDown(self):
        """Clean up test directories."""
        if os.path.exists(self.test_dir):
            import shutil

            shutil.rmtree(self.test_dir)

    def delete_test_file(self, file_path: str) -> bool:
        """Helper method to delete a test file.

        Args:
            file_path: Path to the file to delete

        Returns:
            bool: True if file was deleted or didn't exist, False on error
        """
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
            return True
        except Exception as e:
            return False

    def verify_tensors_equal(self, expected: torch.Tensor, actual: torch.Tensor):
        """Helper to verify tensor equality."""
        self.assertIsNotNone(actual, "Retrieved tensor is None")
        self.assertTrue(
            torch.allclose(expected, actual, atol=1e-6),
            f"Tensors not equal:\nExpected: {expected}\nActual: {actual}",
        )

    def verify_tensor_lists_equal(
        self, expected: List[torch.Tensor], actual: List[torch.Tensor]
    ):
        """Helper to verify lists of tensors are equal."""
        self.assertEqual(len(expected), len(actual), "Lists have different lengths")
        for exp, act in zip(expected, actual):
            self.verify_tensors_equal(exp, act)

    # ============================================================================
    # HiCache Integration Tests
    # ============================================================================
    def test_single_set_get(self):
        """Test single tensor set/get operations."""
        key = "test_key"
        value = torch.randn(10, 10, device="cpu")
        dst_tensor = torch.zeros_like(value, device="cpu")
        # Test set
        self.assertTrue(self.hicache.set(key, value))
        self.assertTrue(self.hicache.exists(key))
        # Test get
        retrieved = self.hicache.get(key, dst_tensor)
        self.verify_tensors_equal(value, retrieved)

    def test_batch_set_get(self):
        """Test batch tensor set/get operations."""
        keys = ["key1", "key2", "key3"]
        values = [
            torch.randn(5, 5, device="cpu"),
            torch.randn(3, 3, device="cpu"),
            torch.randn(7, 7, device="cpu"),
        ]
        dst_tensors = [torch.zeros_like(v, device="cpu") for v in values]
        # Test batch set
        self.assertTrue(self.hicache.batch_set(keys, values))
        self.assertTrue(all(self.hicache.exists(key) for key in keys))
        # Test batch get
        retrieved = self.hicache.batch_get(keys, dst_tensors)
        self.verify_tensor_lists_equal(values, retrieved)

    def test_mixed_operations(self):
        """Test mixing single and batch operations."""
        # Test interleaved set/get operations
        key1, key2 = "key1", "key2"
        value1 = torch.randn(4, 4, device="cpu")
        value2 = torch.randn(6, 6, device="cpu")
        dst1 = torch.zeros_like(value1)
        dst2 = torch.zeros_like(value2)
        # Single set/get
        self.assertTrue(self.hicache.set(key1, value1))
        retrieved1 = self.hicache.get(key1, dst1)
        self.verify_tensors_equal(value1, retrieved1)
        # Batch set/get
        self.assertTrue(self.hicache.batch_set([key2], [value2]))
        retrieved2 = self.hicache.batch_get([key2], [dst2])
        self.verify_tensors_equal(value2, retrieved2[0])

    def test_data_integrity(self):
        """Test data integrity across operations."""
        # Test with various tensor types and sizes
        test_cases = [
            ("float32", torch.randn(10, 10, dtype=torch.float32)),
            ("float64", torch.randn(5, 5, dtype=torch.float64)),
            ("int32", torch.randint(-100, 100, (8, 8), dtype=torch.int32)),
            ("int64", torch.randint(-100, 100, (6, 6), dtype=torch.int64)),
            ("bool", torch.randint(0, 2, (4, 4)).bool()),
        ]
        for name, tensor in test_cases:
            with self.subTest(tensor_type=name):
                key = f"test_{name}"
                dst_tensor = torch.zeros_like(tensor)
                # Set and immediately get
                self.assertTrue(self.hicache.set(key, tensor))
                retrieved1 = self.hicache.get(key, dst_tensor)
                self.verify_tensors_equal(tensor, retrieved1)
                # Get again to verify persistence
                dst_tensor.zero_()
                retrieved2 = self.hicache.get(key, dst_tensor)
                self.verify_tensors_equal(tensor, retrieved2)

    def test_basic_file_operations(self):
        """Test basic file operations."""
        test_file = os.path.join(self.test_dir, "test_file.bin")
        self.file_manager.create_file(test_file)
        self.assertTrue(os.path.exists(test_file))
        self.assertEqual(os.path.getsize(test_file), 0)  # Empty file
        # Test file deletion
        self.assertTrue(self.delete_test_file(test_file))
        self.assertFalse(os.path.exists(test_file))

    def test_create_nixl_tuples(self):
        """Test creation of NIXL tuples."""
        test_file = os.path.join(self.test_dir, "test_file.bin")
        self.file_manager.create_file(test_file)
        # Test tuple creation
        tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
        self.assertIsNotNone(tuples)
        self.assertTrue(len(tuples) > 0)

    def test_error_handling(self):
        """Test error handling in file operations."""
        # Test non-existent file
        self.assertTrue(
            self.delete_test_file("nonexistent_file.bin")
        )  # Returns True if file doesn't exist
        # Test invalid file path
        self.assertFalse(self.file_manager.create_file(""))  # Empty path should fail

    def test_register_buffers(self):
        """Test registration of memory buffers."""
        # Create test tensor
        tensor = torch.randn(10, 10)
        # Test buffer registration
        self.assertIsNotNone(self.registration.register_buffers(tensor))
        # Test batch registration
        tensors = [torch.randn(5, 5) for _ in range(3)]
        self.assertIsNotNone(self.registration.register_buffers(tensors))

    def test_register_files_with_tuples(self):
        """Test registration of files using NIXL tuples."""
        files = [os.path.join(self.test_dir, f"test_file_{i}.bin") for i in range(3)]
        for file in files:
            self.file_manager.create_file(file)
        # Create tuples and register
        tuples = self.file_manager.files_to_nixl_tuples(files, False)
        self.registration.register_files(tuples)
        # Verify tuples
        self.assertEqual(len(tuples), len(files))
        for t, f in zip(tuples, files):
            self.assertEqual(t[3], f)  # Check file path
# Allow running this test module directly, outside a pytest/unittest runner.
if __name__ == "__main__":
    unittest.main()
......@@ -1471,7 +1471,7 @@ class ServerArgs:
parser.add_argument(
"--hicache-storage-backend",
type=str,
choices=["file", "mooncake", "hf3fs"],
choices=["file", "mooncake", "hf3fs", "nixl"],
default=ServerArgs.hicache_storage_backend,
help="The storage backend for hierarchical KV cache.",
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment