Unverified Commit ec99668a authored by hzh0425's avatar hzh0425 Committed by GitHub
Browse files

[Hicache]: Add E2E CI For 3FS-KVStore (#10131)

parent 78f13981
import logging
import os
import threading
from abc import ABC, abstractmethod
from typing import List
import torch
class Hf3fsClient(ABC):
"""Abstract interface for HF3FS clients."""
@abstractmethod
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
"""Initialize the HF3FS client.
Args:
path: File path for storage
size: Total size of storage file
bytes_per_page: Bytes per page
entries: Number of entries for batch operations
"""
pass
@abstractmethod
def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
"""Batch read from storage."""
pass
@abstractmethod
def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
"""Batch write to storage."""
pass
@abstractmethod
def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
"""Validate batch operation parameters."""
pass
@abstractmethod
def get_size(self) -> int:
"""Get total storage size."""
pass
@abstractmethod
def close(self) -> None:
"""Close the client and cleanup resources."""
pass
@abstractmethod
def flush(self) -> None:
"""Flush data to disk."""
pass
logger = logging.getLogger(__name__)
class Hf3fsMockClient(Hf3fsClient):
"""Mock implementation of Hf3fsClient for CI testing purposes."""
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
"""Initialize mock HF3FS client."""
self.path = path
self.size = size
self.bytes_per_page = bytes_per_page
self.entries = entries
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self.path), exist_ok=True)
# Create and initialize the file
self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
os.ftruncate(self.file, size)
logger.info(
f"Hf3fsMockClient initialized: path={path}, size={size}, "
f"bytes_per_page={bytes_per_page}, entries={entries}"
)
def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
"""Batch read from mock storage."""
self.check(offsets, tensors)
results = []
for offset, tensor in zip(offsets, tensors):
size = tensor.numel() * tensor.itemsize
try:
os.lseek(self.file, offset, os.SEEK_SET)
bytes_read = os.read(self.file, size)
if len(bytes_read) == size:
# Convert bytes to tensor and copy to target
bytes_tensor = torch.frombuffer(bytes_read, dtype=torch.uint8)
typed_tensor = bytes_tensor.view(tensor.dtype).view(tensor.shape)
tensor.copy_(typed_tensor)
results.append(size)
else:
logger.warning(
f"Short read: expected {size}, got {len(bytes_read)}"
)
results.append(len(bytes_read))
except Exception as e:
logger.error(f"Error reading from offset {offset}: {e}")
results.append(0)
return results
def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
"""Batch write to mock storage."""
self.check(offsets, tensors)
results = []
for offset, tensor in zip(offsets, tensors):
size = tensor.numel() * tensor.itemsize
try:
# Convert tensor to bytes and write directly to file
tensor_bytes = tensor.contiguous().view(torch.uint8).flatten()
data = tensor_bytes.numpy().tobytes()
os.lseek(self.file, offset, os.SEEK_SET)
bytes_written = os.write(self.file, data)
if bytes_written == size:
results.append(size)
else:
logger.warning(f"Short write: expected {size}, got {bytes_written}")
results.append(bytes_written)
except Exception as e:
logger.error(f"Error writing to offset {offset}: {e}")
results.append(0)
return results
def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
"""Validate batch operation parameters."""
pass
def get_size(self) -> int:
"""Get total storage size."""
return self.size
def close(self) -> None:
"""Close the mock client and cleanup resources."""
try:
if hasattr(self, "file") and self.file >= 0:
os.close(self.file)
self.file = -1 # Mark as closed
logger.info(f"MockHf3fsClient closed: {self.path}")
except Exception as e:
logger.error(f"Error closing MockHf3fsClient: {e}")
def flush(self) -> None:
"""Flush data to disk."""
try:
os.fsync(self.file)
except Exception as e:
logger.error(f"Error flushing MockHf3fsClient: {e}")
...@@ -9,6 +9,8 @@ from typing import List ...@@ -9,6 +9,8 @@ from typing import List
import torch import torch
from torch.utils.cpp_extension import load from torch.utils.cpp_extension import load
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
root = Path(__file__).parent.resolve() root = Path(__file__).parent.resolve()
hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"]) hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
...@@ -51,7 +53,9 @@ def wsynchronized(): ...@@ -51,7 +53,9 @@ def wsynchronized():
return _decorator return _decorator
class Hf3fsClient: class Hf3fsUsrBioClient(Hf3fsClient):
"""HF3FS client implementation using usrbio."""
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int): def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
if not HF3FS_AVAILABLE: if not HF3FS_AVAILABLE:
raise ImportError( raise ImportError(
......
...@@ -13,7 +13,7 @@ from typing import Any, List, Optional, Tuple ...@@ -13,7 +13,7 @@ from typing import Any, List, Optional, Tuple
import torch import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
from sglang.srt.metrics.collector import StorageMetrics from sglang.srt.metrics.collector import StorageMetrics
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -114,6 +114,33 @@ def synchronized(): ...@@ -114,6 +114,33 @@ def synchronized():
return _decorator return _decorator
def create_hf3fs_client(
path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False
) -> Hf3fsClient:
"""Factory function to create appropriate HF3FS client.
Args:
path: File path for storage
size: Total size of storage file
bytes_per_page: Bytes per page
entries: Number of entries for batch operations
use_mock: Whether to use mock client instead of real usrbio client
Returns:
"""
if use_mock:
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient
logger.info(f"[Rank Using Hf3fsMockClient for testing")
return Hf3fsMockClient(path, size, bytes_per_page, entries)
else:
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import (
Hf3fsUsrBioClient,
)
return Hf3fsUsrBioClient(path, size, bytes_per_page, entries)
class HiCacheHF3FS(HiCacheStorage): class HiCacheHF3FS(HiCacheStorage):
"""HiCache backend that stores KV cache pages in HF3FS files.""" """HiCache backend that stores KV cache pages in HF3FS files."""
...@@ -131,6 +158,7 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -131,6 +158,7 @@ class HiCacheHF3FS(HiCacheStorage):
metadata_client: Hf3fsMetadataInterface, metadata_client: Hf3fsMetadataInterface,
is_mla_model: bool = False, is_mla_model: bool = False,
is_page_first_layout: bool = False, is_page_first_layout: bool = False,
use_mock_client: bool = False,
): ):
self.rank = rank self.rank = rank
self.file_path = file_path self.file_path = file_path
...@@ -159,8 +187,12 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -159,8 +187,12 @@ class HiCacheHF3FS(HiCacheStorage):
self.ac = AtomicCounter(self.numjobs) self.ac = AtomicCounter(self.numjobs)
self.clients = [ self.clients = [
Hf3fsClient( create_hf3fs_client(
self.file_path, self.file_size, self.bytes_per_page, self.entries self.file_path,
self.file_size,
self.bytes_per_page,
self.entries,
use_mock_client,
) )
for _ in range(numjobs) for _ in range(numjobs)
] ]
...@@ -202,14 +234,24 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -202,14 +234,24 @@ class HiCacheHF3FS(HiCacheStorage):
Hf3fsLocalMetadataClient, Hf3fsLocalMetadataClient,
) )
use_mock_client = False
if storage_config is not None: if storage_config is not None:
rank, is_mla_model, is_page_first_layout = ( rank, is_mla_model, is_page_first_layout = (
storage_config.tp_rank, storage_config.tp_rank,
storage_config.is_mla_model, storage_config.is_mla_model,
storage_config.is_page_first_layout, storage_config.is_page_first_layout,
) )
if storage_config.extra_config is not None:
use_mock_client = storage_config.extra_config.get(
"use_mock_hf3fs_client", False
)
else: else:
rank, is_mla_model, is_page_first_layout = 0, False, False rank, is_mla_model, is_page_first_layout = (
0,
False,
False,
)
mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md" mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
...@@ -228,6 +270,7 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -228,6 +270,7 @@ class HiCacheHF3FS(HiCacheStorage):
dtype=dtype, dtype=dtype,
metadata_client=Hf3fsLocalMetadataClient(), metadata_client=Hf3fsLocalMetadataClient(),
is_page_first_layout=is_page_first_layout, is_page_first_layout=is_page_first_layout,
use_mock_client=use_mock_client,
) )
try: try:
...@@ -277,6 +320,7 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -277,6 +320,7 @@ class HiCacheHF3FS(HiCacheStorage):
metadata_client=metadata_client, metadata_client=metadata_client,
is_mla_model=is_mla_model, is_mla_model=is_mla_model,
is_page_first_layout=is_page_first_layout, is_page_first_layout=is_page_first_layout,
use_mock_client=use_mock_client,
) )
def get( def get(
......
"""
Benchmark tests for HiCache Storage with 3FS backend.
Usage:
python3 -m pytest test/srt/hicache/test_hicache_storage_3fs_backend.py -v
"""
import json
import os
import time
import unittest
from types import SimpleNamespace
from test_hicache_storage_file_backend import HiCacheStorageBaseMixin
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import CustomTestCase
class HiCacheStorage3FSBackendBaseMixin(HiCacheStorageBaseMixin):
"""Base mixin class with common setup and utilities"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
# Create a temporary JSON config file for HF3FS
hf3fs_config = {
"file_path_prefix": os.path.join(cls.temp_dir, "hicache"),
"file_size": 1024 * 1024 * 1024 * 2,
"numjobs": 2,
"entries": 8,
"use_mock_hf3fs_client": True,
}
# Write config to temporary file
config_file = os.path.join(cls.temp_dir, "hf3fs_config.json")
with open(config_file, "w") as f:
json.dump(hf3fs_config, f, indent=2)
server_args = {
"--tp-size": 1,
"--hicache-ratio": 1.2,
"--hicache-storage-backend": "hf3fs",
"--hicache-storage-backend-extra-config": json.dumps(hf3fs_config),
}
# Set the environment variable to point to our config file
env_vars = {
"SGLANG_HICACHE_HF3FS_CONFIG_PATH": config_file,
}
return server_args, env_vars
class TestHf3fsBackendLayerFirstLayout(
HiCacheStorage3FSBackendBaseMixin, CustomTestCase
):
"""Layer first layout tests for HiCache-Hf3fs backend"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args, env_vars = super()._get_additional_server_args_and_env()
server_args["--hicache-mem-layout"] = "layer_first"
server_args["--hicache-io-backend"] = "direct"
return server_args, env_vars
class TestHf3fsBackendPageFirstLayout(
HiCacheStorage3FSBackendBaseMixin, CustomTestCase
):
"""Page first layout tests for HiCache-Hf3fs backend"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args, env_vars = super()._get_additional_server_args_and_env()
server_args["--hicache-mem-layout"] = "page_first"
return server_args, env_vars
class TestHf3fsBackendAccuracy(HiCacheStorage3FSBackendBaseMixin, CustomTestCase):
"""Accuracy tests for HiCache-Hf3fs backend"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args, env_vars = super()._get_additional_server_args_and_env()
server_args["--hicache-ratio"] = 1.5
server_args["--tp-size"] = 2
return server_args, env_vars
def test_eval_accuracy(self):
"""Test eval accuracy with cache persistence across cache flushes"""
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
# First evaluation - populate cache
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
args_initial = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=50,
max_new_tokens=512,
parallel=10,
host=f"http://{self.base_host}",
port=int(self.base_port),
)
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
# Flush cache to force remote storage access
print("Phase 2: Flushing device cache...")
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
time.sleep(2)
# Second evaluation - should use remote cache
print("Phase 3: Running second GSM8K evaluation using remote cache...")
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
# Verify accuracy consistency
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
print(f"Accuracy difference: {accuracy_diff:.4f}")
# Assertions
self.assertGreater(
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
)
self.assertGreater(
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
)
self.assertLess(
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
)
if __name__ == "__main__":
unittest.main(verbosity=2)
"""
Benchmark tests for HiCache Storage functionality.
Usage:
python3 -m pytest test/srt/hicache/test_hicache_storage_benchmark.py -v
"""
import time
import unittest
from types import SimpleNamespace
from typing import Dict
import requests
from test_hicache_storage_e2e import HiCacheStorageBaseTest
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import is_in_ci, write_github_step_summary
class TestHiCacheStorageBenchmark(HiCacheStorageBaseTest):
"""Benchmark tests for HiCache Storage functionality"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args = {"--tp-size": 2, "--hicache-ratio": 1.5}
return server_args, {}
def flush_cache(self) -> bool:
"""Flush device cache to force remote storage access"""
try:
response = requests.post(f"{self.base_url}/flush_cache", timeout=10)
return response.status_code == 200
except requests.RequestException:
return False
# === Accuracy Tests ===
def test_eval_accuracy_with_cache_persistence(self):
"""Test eval accuracy with cache persistence across cache flushes"""
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
# First evaluation - populate cache
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
args_initial = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=400,
max_new_tokens=512,
parallel=32,
host=f"http://{self.base_host}",
port=int(self.base_port),
)
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
print(f"Evaluation metrics: {metrics_initial}")
self.assertGreater(metrics_initial["accuracy"], 0.60)
# Flush cache to force remote storage access
print("Phase 2: Flushing device cache...")
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
time.sleep(2)
# Second evaluation - should use remote cache
print("Phase 3: Running second GSM8K evaluation using remote cache...")
start_time = time.time()
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
cached_time = time.time() - start_time
print(f"Cached evaluation completed in {cached_time:.2f}s")
print(f"Cached accuracy: {metrics_cached['accuracy']:.3f}")
print(f"Cached throughput: {metrics_cached['output_throughput']:.2f} token/s")
# Verify accuracy consistency
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
print(f"Accuracy difference: {accuracy_diff:.4f}")
# Assertions
self.assertGreater(
metrics_initial["accuracy"], 0.5, "Initial accuracy should be reasonable"
)
self.assertGreater(
metrics_cached["accuracy"], 0.5, "Cached accuracy should be reasonable"
)
self.assertLess(
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
)
# Performance should be similar or better with cache
throughput_ratio = (
metrics_cached["output_throughput"] / metrics_initial["output_throughput"]
)
print(f"Throughput ratio (cached/initial): {throughput_ratio:.2f}")
if is_in_ci():
write_github_step_summary(
f"### HiCache Storage Accuracy Test\n"
f"Initial accuracy: {metrics_initial['accuracy']:.3f}\n"
f"Cached accuracy: {metrics_cached['accuracy']:.3f}\n"
f"Accuracy difference: {accuracy_diff:.4f}\n"
f"Throughput ratio: {throughput_ratio:.2f}\n"
)
# === Performance Benchmark Tests ===
def test_throughput_benchmark_with_hicache(self):
"""Benchmark throughput performance with HiCache enabled"""
print("\n=== Benchmarking Throughput with HiCache ===")
# throughput test
res1 = self._run_throughput_benchmark(
test_name="hicache_offline_throughput",
num_prompts=200,
request_rate=10,
additional_args=[],
)
# Flush cache to force remote storage access
print("Phase 2: Flushing device cache...")
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
time.sleep(2)
# Second benchmark, should use remote cache
res2 = self._run_throughput_benchmark(
test_name="hicache_online_throughput",
num_prompts=400,
request_rate=10,
additional_args=[],
)
if is_in_ci():
write_github_step_summary(
f"### HiCache Storage FileBackend Benchmark Test\n"
f"First time throughput: {res1['input_throughput']:.2f} token/s\n"
f"Second time throughput: {res2['input_throughput']:.2f} token/s\n"
f"First time TTFT: {res1['mean_ttft_ms']:.2f} ms\n"
f"Second time TTFT: {res2['mean_ttft_ms']:.2f} ms\n"
)
def _run_throughput_benchmark(
self,
test_name: str,
num_prompts: int,
request_rate: float,
dataset_name: str = "random",
additional_args: list = None,
) -> Dict:
"""Helper method to run throughput benchmarks"""
if additional_args is None:
additional_args = []
print(f"Running {test_name} benchmark...")
start_time = time.time()
try:
# Use the existing server instead of launching a new one
from sglang.bench_serving import run_benchmark
from sglang.test.test_utils import get_benchmark_args
args = get_benchmark_args(
base_url=self.base_url,
dataset_name=dataset_name,
tokenizer=self.model,
num_prompts=num_prompts,
request_rate=request_rate,
random_input_len=1024,
random_output_len=64,
)
# Run benchmark
result = run_benchmark(args)
elapsed_time = time.time() - start_time
print(f"{test_name} completed in {elapsed_time:.2f}s")
print(
f"Output throughput: {result.get('output_throughput', 0.0):.2f} token/s"
)
return result
except Exception as e:
print(f"Benchmark {test_name} failed: {e}")
# Fallback to avoid hard failure; return minimal metrics
return {
"output_throughput": 0.0,
"input_throughput": 0.0,
"mean_ttft_ms": float("inf"),
"mean_latency_ms": float("inf"),
"p99_ttft_ms": float("inf"),
}
if __name__ == "__main__":
unittest.main(verbosity=2)
...@@ -9,6 +9,7 @@ import random ...@@ -9,6 +9,7 @@ import random
import tempfile import tempfile
import time import time
import unittest import unittest
from types import SimpleNamespace
from typing import Dict from typing import Dict
from urllib.parse import urlparse from urllib.parse import urlparse
...@@ -16,6 +17,7 @@ import requests ...@@ -16,6 +17,7 @@ import requests
from sglang.bench_serving import get_tokenizer from sglang.bench_serving import get_tokenizer
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST, DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
...@@ -26,8 +28,8 @@ from sglang.test.test_utils import ( ...@@ -26,8 +28,8 @@ from sglang.test.test_utils import (
) )
class HiCacheStorageBaseTest(CustomTestCase): class HiCacheStorageBaseMixin:
"""Base test class with common setup and utilities""" """Base mixin class with common setup and utilities"""
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -166,11 +168,7 @@ class HiCacheStorageBaseTest(CustomTestCase): ...@@ -166,11 +168,7 @@ class HiCacheStorageBaseTest(CustomTestCase):
return False return False
def gen_prompt(self, token_num: int) -> str: def gen_prompt(self, token_num: int) -> str:
"""Generate a random prompt of specified token length using tokenizer vocabulary. """Generate a random prompt of specified token length using tokenizer vocabulary."""
This function mimics the implementation from bench_serving.py to create
realistic prompts for testing cache behavior.
"""
all_available_tokens = list(self.tokenizer.get_vocab().values()) all_available_tokens = list(self.tokenizer.get_vocab().values())
selected_tokens = random.choices(all_available_tokens, k=token_num) selected_tokens = random.choices(all_available_tokens, k=token_num)
return self.tokenizer.decode(selected_tokens) return self.tokenizer.decode(selected_tokens)
...@@ -201,10 +199,9 @@ class HiCacheStorageBaseTest(CustomTestCase): ...@@ -201,10 +199,9 @@ class HiCacheStorageBaseTest(CustomTestCase):
# Second request with extended prompt - should hit remote cache # Second request with extended prompt - should hit remote cache
print("Step 2: Testing cache hit from remote storage...") print("Step 2: Testing cache hit from remote storage...")
extended_prompt = base_prompt + "\n\n" + self.gen_prompt(64)
start_time = time.time() start_time = time.time()
response2 = self.send_request(extended_prompt, max_tokens=150) response2 = self.send_request(base_prompt, max_tokens=150)
retrieval_time = time.time() - start_time retrieval_time = time.time() - start_time
cached_tokens = self.get_cached_tokens(response2) cached_tokens = self.get_cached_tokens(response2)
...@@ -213,12 +210,12 @@ class HiCacheStorageBaseTest(CustomTestCase): ...@@ -213,12 +210,12 @@ class HiCacheStorageBaseTest(CustomTestCase):
) )
# Assert cached tokens indicate a remote hit # Assert cached tokens indicate a remote hit
self.assertEqual( self.assertGreater(
cached_tokens, 768, "Expected significant cached tokens for remote hit" cached_tokens, 700, "Expected significant cached tokens for remote hit"
) )
class TestHiCacheStorageTP(HiCacheStorageBaseTest): class TestHiCacheStorageTP(HiCacheStorageBaseMixin, CustomTestCase):
"""Multi-TP tests for HiCache Storage functionality""" """Multi-TP tests for HiCache Storage functionality"""
@classmethod @classmethod
...@@ -228,7 +225,7 @@ class TestHiCacheStorageTP(HiCacheStorageBaseTest): ...@@ -228,7 +225,7 @@ class TestHiCacheStorageTP(HiCacheStorageBaseTest):
return server_args, {} return server_args, {}
class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseTest): class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseMixin, CustomTestCase):
"""Layer first direct tests for HiCache Storage functionality""" """Layer first direct tests for HiCache Storage functionality"""
@classmethod @classmethod
...@@ -241,7 +238,7 @@ class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseTest): ...@@ -241,7 +238,7 @@ class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseTest):
return server_args, {} return server_args, {}
class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseTest): class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseMixin, CustomTestCase):
"""Page first layout tests for HiCache Storage functionality""" """Page first layout tests for HiCache Storage functionality"""
@classmethod @classmethod
...@@ -251,7 +248,7 @@ class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseTest): ...@@ -251,7 +248,7 @@ class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseTest):
return server_args, {} return server_args, {}
class TestHiCacheStorageMLA(HiCacheStorageBaseTest): class TestHiCacheStorageMLA(HiCacheStorageBaseMixin, CustomTestCase):
"""MLA Model tests for HiCache Storage functionality""" """MLA Model tests for HiCache Storage functionality"""
@classmethod @classmethod
...@@ -266,6 +263,57 @@ class TestHiCacheStorageMLA(HiCacheStorageBaseTest): ...@@ -266,6 +263,57 @@ class TestHiCacheStorageMLA(HiCacheStorageBaseTest):
return server_args, {} return server_args, {}
class TestHiCacheStorageAccuracy(HiCacheStorageBaseMixin, CustomTestCase):
"""Accuracy tests for HiCache Storage functionality"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args = {"--tp-size": 2, "--hicache-ratio": 1.5}
return server_args, {}
def test_eval_accuracy(self):
"""Test eval accuracy with cache persistence across cache flushes"""
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
# First evaluation - populate cache
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
args_initial = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=50,
max_new_tokens=512,
parallel=10,
host=f"http://{self.base_host}",
port=int(self.base_port),
)
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
# Flush cache to force remote storage access
print("Phase 2: Flushing device cache...")
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
time.sleep(2)
# Second evaluation - should use remote cache
print("Phase 3: Running second GSM8K evaluation using remote cache...")
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
# Verify accuracy consistency
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
print(f"Accuracy difference: {accuracy_diff:.4f}")
# Assertions
self.assertGreater(
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
)
self.assertGreater(
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
)
self.assertLess(
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
)
# TODO: Add other backends tests(3fs/mooncake) # TODO: Add other backends tests(3fs/mooncake)
# class TestHiCacheStorageMooncakeBackend(HiCacheStorageBaseTest): # class TestHiCacheStorageMooncakeBackend(HiCacheStorageBaseTest):
# """Mooncake backend tests for HiCache Storage functionality""" # """Mooncake backend tests for HiCache Storage functionality"""
......
...@@ -125,8 +125,8 @@ suites = { ...@@ -125,8 +125,8 @@ suites = {
TestFile("test_dp_attention.py", 277), TestFile("test_dp_attention.py", 277),
TestFile("test_patch_torch.py", 19), TestFile("test_patch_torch.py", 19),
TestFile("test_release_memory_occupation.py", 127), TestFile("test_release_memory_occupation.py", 127),
TestFile("hicache/test_hicache_storage_e2e.py", 400), TestFile("hicache/test_hicache_storage_file_backend.py", 400),
TestFile("hicache/test_hicache_storage_benchmark.py", 400), TestFile("hicache/test_hicache_storage_3fs_backend.py", 400),
], ],
"per-commit-4-gpu": [ "per-commit-4-gpu": [
TestFile("test_gpt_oss_4gpu.py", 600), TestFile("test_gpt_oss_4gpu.py", 600),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment