Unverified Commit 77098aea authored by Teng Ma's avatar Teng Ma Committed by GitHub
Browse files

[HiCache] Add tests for hicache storage mooncake backend (#10171)


Signed-off-by: default avatarShangming Cai <csmthu@gmail.com>
Co-authored-by: default avatarhzh0425 <hzh0425@apache.org>
Co-authored-by: default avatarShangming Cai <csmthu@gmail.com>
parent 5ccf0b03
...@@ -74,7 +74,7 @@ fi ...@@ -74,7 +74,7 @@ fi
$PIP_CMD list $PIP_CMD list
# Install additional dependencies # Install additional dependencies
$PIP_CMD install mooncake-transfer-engine==0.3.5 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX $PIP_CMD install mooncake-transfer-engine==0.3.6 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
if [ "$IS_BLACKWELL" != "1" ]; then if [ "$IS_BLACKWELL" != "1" ]; then
# For lmms_evals evaluating MMMU # For lmms_evals evaluating MMMU
......
"""
Benchmark tests for HiCache Storage with Mooncake backend.
Usage:
python3.10 -m pytest test/srt/hicache/test_hicache_storage_mooncake_backend.py -v
"""
import json
import os
import subprocess
import time
import unittest
from types import SimpleNamespace
import requests
from test_hicache_storage_file_backend import HiCacheStorageBaseMixin
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
CustomTestCase,
find_available_port,
)
class HiCacheStorageMooncakeBackendBaseMixin(HiCacheStorageBaseMixin):
"""Base mixin class with common setup and utilities"""
# Default port ranges for Mooncake services - can be overridden in subclasses
mooncake_master_port_base = 50051
mooncake_metadata_port_base = 8080
@classmethod
def setUpClass(cls):
"""Set up test environment and launch Mooncake services before server setup"""
# Find available ports for Mooncake services to avoid conflicts
cls.mooncake_master_port = find_available_port(
HiCacheStorageMooncakeBackendBaseMixin.mooncake_master_port_base
)
cls.mooncake_metadata_port = find_available_port(
HiCacheStorageMooncakeBackendBaseMixin.mooncake_metadata_port_base
)
# Start Mooncake services first
cls._start_mooncake_services()
# Call parent setup
super().setUpClass()
@classmethod
def tearDownClass(cls):
"""Clean up Mooncake services after server teardown"""
# Call parent teardown first
super().tearDownClass()
# Stop Mooncake services
cls._stop_mooncake_services()
@classmethod
def _start_mooncake_services(cls):
"""Start Mooncake metadata and master services with configurable ports and readiness detection"""
print("Starting Mooncake services...")
print(
f"Using master port: {cls.mooncake_master_port}, metadata port: {cls.mooncake_metadata_port}"
)
# Start metadata service with configurable port
try:
# Start metadata server with port configuration
cls.metadata_service_process = subprocess.Popen(
[
"python3",
"-m",
"mooncake.http_metadata_server",
"--port",
str(cls.mooncake_metadata_port),
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid, # Create new process group
)
print(
f"Mooncake metadata service started on port {cls.mooncake_metadata_port}"
)
except (FileNotFoundError, subprocess.SubprocessError) as e:
print(f"Warning: Could not start Mooncake metadata service: {e}")
cls.metadata_service_process = None
# Start master service with configurable port
try:
# Start master server with port configuration
cls.master_service_process = subprocess.Popen(
["mooncake_master", "--port", str(cls.mooncake_master_port)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid, # Create new process group
)
print(f"Mooncake master service started on port {cls.mooncake_master_port}")
except (FileNotFoundError, subprocess.SubprocessError) as e:
print(f"Warning: Could not start Mooncake master service: {e}")
cls.master_service_process = None
# Wait for services to be ready instead of fixed sleep
cls._wait_for_mooncake_services_ready()
@classmethod
def _wait_for_mooncake_services_ready(cls, timeout: int = 30) -> bool:
"""Wait for Mooncake services to be ready by checking their endpoints"""
print("Waiting for Mooncake services to be ready...")
start_time = time.time()
services_ready = False
while time.time() - start_time < timeout:
try:
# Check metadata service
metadata_ready = False
if (
cls.metadata_service_process
and cls.metadata_service_process.poll() is None
):
try:
# Try to connect to the metadata service
metadata_url = (
f"http://127.0.0.1:{cls.mooncake_metadata_port}/metadata"
)
response = requests.get(metadata_url, timeout=2)
if response.status_code == 200:
metadata_ready = True
print("Mooncake metadata service is ready")
except (requests.RequestException, ConnectionError):
# Service might not be fully started yet
pass
# Check master service (if it has a health endpoint)
master_ready = False
if (
cls.master_service_process
and cls.master_service_process.poll() is None
):
# For now, we'll assume master service is ready if process is running
# and it's been a few seconds since startup
if (
time.time() - start_time > 5
): # Give master service time to initialize
master_ready = True
print("Mooncake master service is ready")
# Both services should be ready
if metadata_ready and master_ready:
services_ready = True
print("All Mooncake services are ready")
break
except Exception as e:
print(f"Error checking service readiness: {e}")
time.sleep(2)
if not services_ready:
print(
"Warning: Mooncake services may not be fully ready, continuing anyway..."
)
return services_ready
@classmethod
def _stop_mooncake_services(cls):
"""Stop Mooncake services"""
print("Stopping Mooncake services...")
# Stop metadata service
if hasattr(cls, "metadata_service_process") and cls.metadata_service_process:
try:
os.killpg(os.getpgid(cls.metadata_service_process.pid), 9)
cls.metadata_service_process.wait(timeout=5)
print("Mooncake metadata service stopped")
except (ProcessLookupError, subprocess.TimeoutExpired, OSError) as e:
print(f"Warning: Could not stop Mooncake metadata service: {e}")
# Stop master service
if hasattr(cls, "master_service_process") and cls.master_service_process:
try:
os.killpg(os.getpgid(cls.master_service_process.pid), 9)
cls.master_service_process.wait(timeout=5)
print("Mooncake master service stopped")
except (ProcessLookupError, subprocess.TimeoutExpired, OSError) as e:
print(f"Warning: Could not stop Mooncake master service: {e}")
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args = {
"--tp-size": 1,
"--hicache-ratio": 2,
"--hicache-storage-backend": "mooncake",
}
# Set the environment variables for Mooncake using dynamic ports
env_vars = {
"MOONCAKE_MASTER": f"127.0.0.1:{cls.mooncake_master_port}",
"MOONCAKE_PROTOCOL": "rdma",
"MOONCAKE_DEVICE": "mlx5_roce0,mlx5_roce1",
"MOONCAKE_TE_META_DATA_SERVER": f"http://127.0.0.1:{cls.mooncake_metadata_port}/metadata",
"MOONCAKE_GLOBAL_SEGMENT_SIZE": "4294967296", # 4 GiB
}
return server_args, env_vars
'''
# Same as #10131, layer first layout test TODO(mateng): will make it work
class TestMooncakeBackendLayerFirstLayout(
HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
):
"""Layer first layout tests for HiCache-Mooncake backend"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args, env_vars = super()._get_additional_server_args_and_env()
server_args["--hicache-mem-layout"] = "layer_first"
server_args["--hicache-io-backend"] = "direct"
return server_args, env_vars
'''
class TestMooncakeBackendPageFirstLayout(
HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
):
"""Page first layout tests for HiCache-Mooncake backend"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args, env_vars = super()._get_additional_server_args_and_env()
server_args["--hicache-mem-layout"] = "page_first"
return server_args, env_vars
class TestMooncakeBackendMLAModel(
HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
):
"""MLA Model tests for HiCache-Mooncake backend"""
@classmethod
def _get_model_name(cls):
"""Use MLA model for testing"""
return DEFAULT_MLA_MODEL_NAME_FOR_TEST
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args, env_vars = super()._get_additional_server_args_and_env()
server_args["--hicache-mem-layout"] = "page_first"
server_args["--tp-size"] = 2
return server_args, env_vars
class TestMooncakeBackendAccuracy(
HiCacheStorageMooncakeBackendBaseMixin, CustomTestCase
):
"""Accuracy tests for HiCache-Mooncake backend"""
@classmethod
def _get_additional_server_args_and_env(cls):
"""Get additional server arguments specific to configuration - override in subclasses"""
server_args, env_vars = super()._get_additional_server_args_and_env()
server_args["--hicache-ratio"] = 1.5
server_args["--tp-size"] = 2
return server_args, env_vars
def test_eval_accuracy(self):
"""Test eval accuracy with cache persistence across cache flushes"""
print("\n=== Testing Eval Accuracy with Cache Persistence ===")
# First evaluation - populate cache
print("Phase 1: Running initial GSM8K evaluation to populate cache...")
args_initial = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=50,
max_new_tokens=512,
parallel=10,
host=f"http://{self.base_host}",
port=int(self.base_port),
)
metrics_initial = run_eval_few_shot_gsm8k(args_initial)
# Flush cache to force remote storage access
print("Phase 2: Flushing device cache...")
self.assertTrue(self.flush_cache(), "Cache flush should succeed")
time.sleep(2)
# Second evaluation - should use remote cache
print("Phase 3: Running second GSM8K evaluation using remote cache...")
metrics_cached = run_eval_few_shot_gsm8k(args_initial)
# Verify accuracy consistency
accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
print(f"Accuracy difference: {accuracy_diff:.4f}")
# Assertions
self.assertGreater(
metrics_initial["accuracy"], 0.6, "Initial accuracy should be reasonable"
)
self.assertGreater(
metrics_cached["accuracy"], 0.6, "Cached accuracy should be reasonable"
)
self.assertLess(
accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
)
if __name__ == "__main__":
unittest.main(verbosity=2)
...@@ -142,6 +142,7 @@ suites = { ...@@ -142,6 +142,7 @@ suites = {
"per-commit-8-gpu": [ "per-commit-8-gpu": [
# Disabled because it hangs on the CI. # Disabled because it hangs on the CI.
# TestFile("ep/test_moe_ep.py", 181), # TestFile("ep/test_moe_ep.py", 181),
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 800),
TestFile("lora/test_lora_llama4.py", 600), TestFile("lora/test_lora_llama4.py", 600),
TestFile("test_disaggregation.py", 499), TestFile("test_disaggregation.py", 499),
TestFile("test_disaggregation_different_tp.py", 155), TestFile("test_disaggregation_different_tp.py", 155),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment