Unverified Commit 106c2b31 authored by hzh0425's avatar hzh0425 Committed by GitHub
Browse files

feat(hicache): Add generic hicache ci e2e test and benchmark test (#9846)


Co-authored-by: default avatarZhiqiang Xie <xiezhq@stanford.edu>
parent c6756949
...@@ -136,13 +136,18 @@ class HiCacheFile(HiCacheStorage): ...@@ -136,13 +136,18 @@ class HiCacheFile(HiCacheStorage):
): ):
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path) self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
tp_rank, tp_size, model_name = ( tp_rank, tp_size, model_name, is_mla_model = (
storage_config.tp_rank, storage_config.tp_rank,
storage_config.tp_size, storage_config.tp_size,
storage_config.model_name, storage_config.model_name,
storage_config.is_mla_model,
) )
model_name = "-".join(model_name.split("/")) if model_name else "" model_name = "-".join(model_name.split("/")) if model_name else ""
if is_mla_model:
self.config_suffix = f"_{model_name}"
else:
self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}" self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
if not os.path.exists(self.file_path) and tp_rank == 0: if not os.path.exists(self.file_path) and tp_rank == 0:
os.makedirs(self.file_path) os.makedirs(self.file_path)
logger.info(f"Created HiCacheFile storage directory at {self.file_path}") logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
......
"""
Benchmark tests for HiCache Storage functionality.
Usage:
python3 -m pytest test/srt/hicache/test_hicache_storage_benchmark.py -v
"""
import time
import unittest
from types import SimpleNamespace
from typing import Dict
import requests
from test_hicache_storage_e2e import HiCacheStorageBaseTest
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import is_in_ci, write_github_step_summary
class TestHiCacheStorageBenchmark(HiCacheStorageBaseTest):
"""Benchmark tests for HiCache Storage functionality"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args = {"--tp-size": 2, "--hicache-ratio": 1.5}
return server_args, {}
def flush_cache(self) -> bool:
"""Flush device cache to force remote storage access"""
try:
response = requests.post(f"{self.base_url}/flush_cache", timeout=10)
return response.status_code == 200
except requests.RequestException:
return False
# === Accuracy Tests ===
def test_eval_accuracy_with_cache_persistence(self):
"""Test eval accuracy with cache persistence across cache flushes"""
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
# First evaluation - populate cache
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
args_initial = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=400,
max_new_tokens=512,
parallel=32,
host=f"http://{self.base_host}",
port=int(self.base_port),
)
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
print(f"Evaluation metrics: {metrics_initial}")
self.assertGreater(metrics_initial["accuracy"], 0.60)
# Flush cache to force remote storage access
print("Phase 2: Flushing device cache...")
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
time.sleep(2)
# Second evaluation - should use remote cache
print("Phase 3: Running second GSM8K evaluation using remote cache...")
start_time = time.time()
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
cached_time = time.time() - start_time
print(f"Cached evaluation completed in {cached_time:.2f}s")
print(f"Cached accuracy: {metrics_cached['accuracy']:.3f}")
print(f"Cached throughput: {metrics_cached['output_throughput']:.2f} token/s")
# Verify accuracy consistency
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
print(f"Accuracy difference: {accuracy_diff:.4f}")
# Assertions
self.assertGreater(
metrics_initial["accuracy"], 0.5, "Initial accuracy should be reasonable"
)
self.assertGreater(
metrics_cached["accuracy"], 0.5, "Cached accuracy should be reasonable"
)
self.assertLess(
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
)
# Performance should be similar or better with cache
throughput_ratio = (
metrics_cached["output_throughput"] / metrics_initial["output_throughput"]
)
print(f"Throughput ratio (cached/initial): {throughput_ratio:.2f}")
if is_in_ci():
write_github_step_summary(
f"### HiCache Storage Accuracy Test\n"
f"Initial accuracy: {metrics_initial['accuracy']:.3f}\n"
f"Cached accuracy: {metrics_cached['accuracy']:.3f}\n"
f"Accuracy difference: {accuracy_diff:.4f}\n"
f"Throughput ratio: {throughput_ratio:.2f}\n"
)
# === Performance Benchmark Tests ===
def test_throughput_benchmark_with_hicache(self):
"""Benchmark throughput performance with HiCache enabled"""
print("\n=== Benchmarking Throughput with HiCache ===")
# throughput test
res1 = self._run_throughput_benchmark(
test_name="hicache_offline_throughput",
num_prompts=200,
request_rate=10,
additional_args=[],
)
# Flush cache to force remote storage access
print("Phase 2: Flushing device cache...")
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
time.sleep(2)
# Second benchmark, should use remote cache
res2 = self._run_throughput_benchmark(
test_name="hicache_online_throughput",
num_prompts=400,
request_rate=10,
additional_args=[],
)
if is_in_ci():
write_github_step_summary(
f"### HiCache Storage FileBackend Benchmark Test\n"
f"First time throughput: {res1['input_throughput']:.2f} token/s\n"
f"Second time throughput: {res2['input_throughput']:.2f} token/s\n"
f"First time TTFT: {res1['mean_ttft_ms']:.2f} ms\n"
f"Second time TTFT: {res2['mean_ttft_ms']:.2f} ms\n"
)
def _run_throughput_benchmark(
self,
test_name: str,
num_prompts: int,
request_rate: float,
dataset_name: str = "random",
additional_args: list = None,
) -> Dict:
"""Helper method to run throughput benchmarks"""
if additional_args is None:
additional_args = []
print(f"Running {test_name} benchmark...")
start_time = time.time()
try:
# Use the existing server instead of launching a new one
from sglang.bench_serving import run_benchmark
from sglang.test.test_utils import get_benchmark_args
args = get_benchmark_args(
base_url=self.base_url,
dataset_name=dataset_name,
tokenizer=self.model,
num_prompts=num_prompts,
request_rate=request_rate,
random_input_len=1024,
random_output_len=64,
)
# Run benchmark
result = run_benchmark(args)
elapsed_time = time.time() - start_time
print(f"{test_name} completed in {elapsed_time:.2f}s")
print(
f"Output throughput: {result.get('output_throughput', 0.0):.2f} token/s"
)
return result
except Exception as e:
print(f"Benchmark {test_name} failed: {e}")
# Fallback to avoid hard failure; return minimal metrics
return {
"output_throughput": 0.0,
"input_throughput": 0.0,
"mean_ttft_ms": float("inf"),
"mean_latency_ms": float("inf"),
"p99_ttft_ms": float("inf"),
}
if __name__ == "__main__":
unittest.main(verbosity=2)
"""
E2E tests for HiCache Storage functionality.
Usage:
python3 -m pytest test/srt/hicache/test_hicache_storage_e2e.py -v
"""
import os
import random
import tempfile
import time
import unittest
from typing import Dict
from urllib.parse import urlparse
import requests
from sglang.bench_serving import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class HiCacheStorageBaseTest(CustomTestCase):
"""Base test class with common setup and utilities"""
@classmethod
def setUpClass(cls):
"""Set up test environment and launch server once for all tests"""
cls.temp_dir = tempfile.mkdtemp()
cls.model = cls._get_model_name()
cls.base_url = DEFAULT_URL_FOR_TEST
parsed_url = urlparse(cls.base_url)
cls.base_host = parsed_url.hostname
cls.base_port = str(parsed_url.port)
# Prepare tokenizer for prompt generation
cls.tokenizer = get_tokenizer(cls.model)
# Launch server with HiCache enabled and cache report
cls.process = cls._launch_server_with_hicache()
cls._wait_for_server_ready()
print(f"Test server launched successfully at {cls.base_url}")
print(f"Cache directory: {cls.temp_dir}")
@classmethod
def tearDownClass(cls):
"""Clean up test environment"""
kill_process_tree(cls.process.pid)
import shutil
shutil.rmtree(cls.temp_dir, ignore_errors=True)
@classmethod
def _get_model_name(cls):
"""Get model name for the test configuration - override in subclasses"""
return DEFAULT_MODEL_NAME_FOR_TEST
@classmethod
def _get_base_server_args(cls):
"""Get base server arguments - can be extended in subclasses"""
return {
"--enable-hierarchical-cache": True,
"--mem-fraction-static": 0.6,
"--hicache-ratio": 1.2,
"--page-size": 64,
"--enable-cache-report": True,
"--hicache-storage-prefetch-policy": "wait_complete",
"--hicache-storage-backend": "file",
}
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
return {}, {"SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir}
@classmethod
def _launch_server_with_hicache(cls):
"""Launch server with HiCache enabled"""
additional_server_args, env_vars = cls._get_additional_server_args_and_env()
server_args = cls._get_base_server_args()
if additional_server_args:
server_args.update(additional_server_args)
final_server_args = []
for k, v in server_args.items():
if isinstance(v, bool):
final_server_args.append(str(k))
else:
final_server_args.append(str(k))
final_server_args.append(str(v))
print(f"final_server_args: {final_server_args}")
env_vars = {
**os.environ,
**env_vars,
}
return popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=final_server_args,
env=env_vars,
)
@classmethod
def _wait_for_server_ready(cls, timeout: int = 60) -> bool:
"""Wait for server to be ready"""
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(f"{cls.base_url}/health", timeout=5)
if response.status_code == 200:
return True
except requests.RequestException:
pass
time.sleep(2)
raise TimeoutError("Server failed to start within timeout")
def send_request(
self, prompt: str, max_tokens: int = 100, temperature: float = 0.0
) -> Dict:
"""Send a generate request and return response"""
response = requests.post(
f"{self.base_url}/generate",
json={
"text": prompt,
"sampling_params": {
"temperature": temperature,
"max_new_tokens": max_tokens,
"ignore_eos": True,
},
},
timeout=60,
)
self.assertEqual(
response.status_code,
200,
f"Request failed: {response.status_code} - {response.text}",
)
return response.json()
def get_cached_tokens(self, response_json: Dict) -> int:
"""Extract cached tokens count from /generate response"""
meta = response_json.get("meta_info", {})
return int(meta.get("cached_tokens", 0))
def flush_cache(self) -> bool:
"""Flush device cache to force remote storage access"""
try:
response = requests.post(f"{self.base_url}/flush_cache", timeout=10)
return response.status_code == 200
except requests.RequestException:
return False
def gen_prompt(self, token_num: int) -> str:
"""Generate a random prompt of specified token length using tokenizer vocabulary.
This function mimics the implementation from bench_serving.py to create
realistic prompts for testing cache behavior.
"""
all_available_tokens = list(self.tokenizer.get_vocab().values())
selected_tokens = random.choices(all_available_tokens, k=token_num)
return self.tokenizer.decode(selected_tokens)
def trigger_offloading_and_flush(self):
"""Helper method to trigger offloading and flush cache"""
# Trigger offloading
self.send_request(self.gen_prompt(1), max_tokens=150)
# Flush device cache to force remote storage access
time.sleep(2)
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
def test_basic_backup_and_prefetch(self):
"""Test storage and retrieval of large context through remote cache"""
print("\n=== Testing Large Context Cache Storage & Retrieval ===")
# Generate substantial context that will be cached
base_prompt = self.gen_prompt(768)
# First request - populate cache
print("Step 1: Populating cache with large context...")
response1 = self.send_request(base_prompt, max_tokens=150)
self.assertIsNotNone(response1)
# Flush device cache to force remote storage access
self.trigger_offloading_and_flush()
# Second request with extended prompt - should hit remote cache
print("Step 2: Testing cache hit from remote storage...")
extended_prompt = base_prompt + "\n\n" + self.gen_prompt(64)
start_time = time.time()
response2 = self.send_request(extended_prompt, max_tokens=150)
retrieval_time = time.time() - start_time
cached_tokens = self.get_cached_tokens(response2)
print(
f"Remote cache retrieval time: {retrieval_time:.3f}s, cached_tokens={cached_tokens}"
)
# Assert cached tokens indicate a remote hit
self.assertEqual(
cached_tokens, 768, "Expected significant cached tokens for remote hit"
)
class TestHiCacheStorageTP(HiCacheStorageBaseTest):
"""Multi-TP tests for HiCache Storage functionality"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args = {"--tp-size": 2}
return server_args, {}
class TestHiCacheStorageLayerFirstDirectIO(HiCacheStorageBaseTest):
"""Layer first direct tests for HiCache Storage functionality"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args = {
"--hicache-mem-layout": "layer_first",
"--hicache-io-backend": "direct",
}
return server_args, {}
class TestHiCacheStoragePageFirstLayout(HiCacheStorageBaseTest):
"""Page first layout tests for HiCache Storage functionality"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args = {"--hicache-mem-layout": "page_first"}
return server_args, {}
class TestHiCacheStorageMLA(HiCacheStorageBaseTest):
"""MLA Model tests for HiCache Storage functionality"""
@classmethod
def _get_model_name(cls):
"""Use MLA model for testing"""
return DEFAULT_MLA_MODEL_NAME_FOR_TEST
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args = {"--tp-size": 2}
return server_args, {}
# TODO: Add other backends tests(3fs/mooncake)
# class TestHiCacheStorageMooncakeBackend(HiCacheStorageBaseTest):
# """Mooncake backend tests for HiCache Storage functionality"""
# @classmethod
# def _get_additional_server_args_and_env(cls):
# """Get additional server arguments specific to configuration - override in subclasses"""
# server_args = ["--hicache-storage-backend", "mooncake"]
# env = {
# "MOONCAKE_TE_META_DATA_SERVER": "http://127.0.0.1:8080/metadata",
# "MOONCAKE_MASTER": "127.0.0.1:50051"
# xxxxx
# }
# return server_args, {}
if __name__ == "__main__":
unittest.main(verbosity=2)
...@@ -123,6 +123,8 @@ suites = { ...@@ -123,6 +123,8 @@ suites = {
TestFile("test_dp_attention.py", 277), TestFile("test_dp_attention.py", 277),
TestFile("test_patch_torch.py", 19), TestFile("test_patch_torch.py", 19),
TestFile("test_release_memory_occupation.py", 127), TestFile("test_release_memory_occupation.py", 127),
TestFile("hicache/test_hicache_storage_e2e.py", 400),
TestFile("hicache/test_hicache_storage_benchmark.py", 400),
], ],
"per-commit-4-gpu": [ "per-commit-4-gpu": [
TestFile("test_gpt_oss_4gpu.py", 600), TestFile("test_gpt_oss_4gpu.py", 600),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment