"""
Benchmark tests for HiCache Storage functionality.
Usage:
    python3 -m pytest test/srt/hicache/test_hicache_storage_benchmark.py -v
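    # or a single test (standard pytest -k filter):
    python3 -m pytest test/srt/hicache/test_hicache_storage_benchmark.py -v -k test_throughput_benchmark_with_hicache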
"""

import time
import unittest
from types import SimpleNamespace
from typing import Dict, List, Optional

import requests
from test_hicache_storage_e2e import HiCacheStorageBaseTest

from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import is_in_ci, write_github_step_summary


class TestHiCacheStorageBenchmark(HiCacheStorageBaseTest):
    """Benchmark tests for HiCache Storage functionality"""

    @classmethod
    def _get_additional_server_args_and_env(cls):
        """Get additional server arguments specific to configuration - override in subclasses"""
        server_args = {"--tp-size": 2, "--hicache-ratio": 1.5}
        return server_args, {}

    def flush_cache(self) -> bool:
        """Flush device cache to force remote storage access"""
        try:
            response = requests.post(f"{self.base_url}/flush_cache", timeout=10)
            return response.status_code == 200
        except requests.RequestException:
            return False

    # === Accuracy Tests ===
    def test_eval_accuracy_with_cache_persistence(self):
        """Test eval accuracy with cache persistence across cache flushes"""
        print("\n=== Testing Eval Accuracy with Cache Persistence ===")

        # First evaluation - populate cache
        print("Phase 1: Running initial GSM8K evaluation to populate cache...")
        args_initial = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=400,
            max_new_tokens=512,
            parallel=32,
            host=f"http://{self.base_host}",
            port=int(self.base_port),
        )
        metrics_initial = run_eval_few_shot_gsm8k(args_initial)
        print(f"Evaluation metrics: {metrics_initial}")
        self.assertGreater(metrics_initial["accuracy"], 0.60)

        # Flush cache to force remote storage access
        print("Phase 2: Flushing device cache...")
        self.assertTrue(self.flush_cache(), "Cache flush should succeed")
        time.sleep(2)

        # Second evaluation - should use remote cache
        print("Phase 3: Running second GSM8K evaluation using remote cache...")

        start_time = time.time()
        metrics_cached = run_eval_few_shot_gsm8k(args_initial)
        cached_time = time.time() - start_time

        print(f"Cached evaluation completed in {cached_time:.2f}s")
        print(f"Cached accuracy: {metrics_cached['accuracy']:.3f}")
        print(f"Cached throughput: {metrics_cached['output_throughput']:.2f} token/s")

        # Verify accuracy consistency
        accuracy_diff = abs(metrics_initial["accuracy"] - metrics_cached["accuracy"])
        print(f"Accuracy difference: {accuracy_diff:.4f}")

        # Assertions (initial accuracy is already checked more strictly above)
        self.assertGreater(
            metrics_cached["accuracy"], 0.5, "Cached accuracy should be reasonable"
        )
        self.assertLess(
            accuracy_diff, 0.05, "Accuracy should be consistent between cache states"
        )

        # With a warm remote cache, throughput should be similar or better; the
        # ratio is logged rather than asserted.
        throughput_ratio = (
            metrics_cached["output_throughput"] / metrics_initial["output_throughput"]
        )
        print(f"Throughput ratio (cached/initial): {throughput_ratio:.2f}")

        if is_in_ci():
            write_github_step_summary(
                f"### HiCache Storage Accuracy Test\n"
                f"Initial accuracy: {metrics_initial['accuracy']:.3f}\n"
                f"Cached accuracy: {metrics_cached['accuracy']:.3f}\n"
                f"Accuracy difference: {accuracy_diff:.4f}\n"
                f"Throughput ratio: {throughput_ratio:.2f}\n"
            )

    # === Performance Benchmark Tests ===

    def test_throughput_benchmark_with_hicache(self):
        """Benchmark throughput performance with HiCache enabled"""
        print("\n=== Benchmarking Throughput with HiCache ===")

        # Phase 1: initial benchmark populates the device cache and remote storage
        res1 = self._run_throughput_benchmark(
            test_name="hicache_offline_throughput",
            num_prompts=200,
            request_rate=10,
            additional_args=[],
        )

        # Flush cache to force remote storage access
        print("Phase 2: Flushing device cache...")
        self.assertTrue(self.flush_cache(), "Cache flush should succeed")
        time.sleep(2)

        # Phase 3: second benchmark; prefixes should now be served from remote storage
        res2 = self._run_throughput_benchmark(
            test_name="hicache_online_throughput",
            num_prompts=400,
            request_rate=10,
            additional_args=[],
        )

        if is_in_ci():
            write_github_step_summary(
                f"### HiCache Storage FileBackend Benchmark Test\n"
                f"First time throughput: {res1['input_throughput']:.2f} token/s\n"
                f"Second time throughput: {res2['input_throughput']:.2f} token/s\n"
                f"First time TTFT: {res1['mean_ttft_ms']:.2f} ms\n"
                f"Second time TTFT: {res2['mean_ttft_ms']:.2f} ms\n"
            )

    def _run_throughput_benchmark(
        self,
        test_name: str,
        num_prompts: int,
        request_rate: float,
        dataset_name: str = "random",
        additional_args: Optional[List[str]] = None,
    ) -> Dict:
        """Helper method to run a throughput benchmark against the running server."""
        if additional_args is None:
            additional_args = []  # currently unused; reserved for extra benchmark flags

        print(f"Running {test_name} benchmark...")
        start_time = time.time()

        try:
            # Use the existing server instead of launching a new one
            from sglang.bench_serving import run_benchmark
            from sglang.test.test_utils import get_benchmark_args

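            # Long random inputs (1024 tokens) with short outputs (64 tokens) keep
            # the run prefill-heavy, the phase where HiCache prefix reuse helps most.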
            args = get_benchmark_args(
                base_url=self.base_url,
                dataset_name=dataset_name,
                tokenizer=self.model,
                num_prompts=num_prompts,
                request_rate=request_rate,
                random_input_len=1024,
                random_output_len=64,
            )

            # Run benchmark
            result = run_benchmark(args)

            elapsed_time = time.time() - start_time
            print(f"{test_name} completed in {elapsed_time:.2f}s")
            print(
                f"Output throughput: {result.get('output_throughput', 0.0):.2f} token/s"
            )

            return result

        except Exception as e:
            print(f"Benchmark {test_name} failed: {e}")
            # Fallback to avoid hard failure; return minimal metrics
            return {
                "output_throughput": 0.0,
                "input_throughput": 0.0,
                "mean_ttft_ms": float("inf"),
                "mean_latency_ms": float("inf"),
                "p99_ttft_ms": float("inf"),
            }


if __name__ == "__main__":
    unittest.main(verbosity=2)