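"""MMMU benchmark tests for VLM models served by SGLang.

For each model, an OpenAI-compatible SGLang server is launched and lmms-eval
is run against it on the mmmu_val task; the measured accuracy is compared to
a per-model threshold. A second test shrinks the multimodal embedding cache
to force (and verify) cache evictions.

Usage:
    python3 test_vlm_models.py [--mem-fraction-static 0.8]
"""
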
import argparse
import glob
import json
import os
import random
import subprocess
import sys
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)

# VLM models under test, each paired with its minimum expected MMMU-val accuracy
MODELS = [
    SimpleNamespace(model="google/gemma-3-27b-it", mmmu_accuracy=0.45),
    SimpleNamespace(
        model="Qwen/Qwen2.5-VL-3B-Instruct",
        mmmu_accuracy=0.4,
    ),
    SimpleNamespace(model="openbmb/MiniCPM-V-2_6", mmmu_accuracy=0.4),
]


class TestVLMModels(CustomTestCase):
    parsed_args = None  # Class variable to store args

    @classmethod
    def setUpClass(cls):
        # Argument parsing happens in __main__; fall back to the default
        # mem-fraction when running under a generic unittest runner.
        if cls.parsed_args is None:
            cls.parsed_args = SimpleNamespace(mem_fraction_static=0.8)

        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH

        # Set OpenAI API key and base URL environment variables. Needed for lmms-eval to work.
        os.environ["OPENAI_API_KEY"] = cls.api_key
        os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"

    def _detect_eviction_in_logs(self, log_output):
        """Detect if eviction events occurred in the log output."""
        eviction_keywords = ["Cache eviction: evicted"]

        eviction_detected = False
        eviction_count = 0

        for line in log_output.split("\n"):
            if any(keyword in line for keyword in eviction_keywords):
                eviction_detected = True
                eviction_count += 1
                print(f"Eviction detected: {line.strip()}")

        return eviction_detected, eviction_count

    def run_mmmu_eval(
        self,
        model_version: str,
        output_path: str,
        *,
        env: dict | None = None,
    ):
        """
        Evaluate a VLM on the MMMU validation set with lmms‑eval.
        Only `model_version` (checkpoint) and `chat_template` vary;
        We are focusing only on the validation set due to resource constraints.
        """
        # -------- fixed settings --------
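        # The "openai_compatible" lmms-eval backend sends requests to the
        # OpenAI-style endpoint configured via OPENAI_API_BASE / OPENAI_API_KEY
        # in setUpClass, i.e. the SGLang server launched for this test.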
        model = "openai_compatible"
        tp = 1
        tasks = "mmmu_val"
        batch_size = 2
        log_suffix = "openai_compatible"
        os.makedirs(output_path, exist_ok=True)

        # -------- compose --model_args --------
        model_args = f'model_version="{model_version}",tp={tp}'

        # -------- build command list --------
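        # Roughly equivalent CLI invocation:
        #   python3 -m lmms_eval --model openai_compatible \
        #       --model_args 'model_version="<checkpoint>",tp=1' \
        #       --tasks mmmu_val --batch_size 2 --log_samples \
        #       --log_samples_suffix openai_compatible --output_path <output_path>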
        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            model,
            "--model_args",
            model_args,
            "--tasks",
            tasks,
            "--batch_size",
            str(batch_size),
            "--log_samples",
            "--log_samples_suffix",
            log_suffix,
            "--output_path",
            str(output_path),
        ]

        subprocess.run(
            cmd,
            check=True,
            timeout=3600,
            # Forward any caller-provided environment overrides to lmms-eval.
            env={**os.environ, **env} if env else None,
        )

    def _run_vlm_mmmu_test(
        self,
        model,
        output_path,
        test_name="",
        custom_env=None,
        log_level="info",
        capture_output=False,
    ):
        """
        Common method to run VLM MMMU benchmark test.

        Args:
            model: Model to test
            output_path: Path for output logs
            test_name: Optional test name for logging
            custom_env: Optional custom environment variables
            log_level: Log level for server (default: "info")
            capture_output: Whether to capture server stdout/stderr
        """
        print(f"\nTesting model: {model.model}{test_name}")

        process = None
        mmmu_accuracy = 0  # Initialize to handle potential exceptions
        server_output = ""

        try:
            # Prepare environment variables
            process_env = os.environ.copy()
            if custom_env:
                process_env.update(custom_env)

            # Prepare stdout/stderr redirection if needed
            stdout_file = None
            stderr_file = None
            if capture_output:
                stdout_file = open("/tmp/server_stdout.log", "w")
                stderr_file = open("/tmp/server_stderr.log", "w")
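                # The server's stdout/stderr are written into these files via
                # popen_launch_server; _read_output_from_files() reads them back
                # so eviction log lines can be inspected.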

            # Launch server for testing
            process = popen_launch_server(
                model.model,
                base_url=self.base_url,
                timeout=self.time_out,
                api_key=self.api_key,
                other_args=[
                    "--trust-remote-code",
                    "--cuda-graph-max-bs",
                    "32",
                    "--enable-multimodal",
                    "--mem-fraction-static",
                    str(self.parsed_args.mem_fraction_static),  # Use class variable
                    "--log-level",
                    log_level,
                ],
                env=process_env,
                return_stdout_stderr=(
                    (stdout_file, stderr_file) if capture_output else None
                ),
            )

            # Run evaluation
            self.run_mmmu_eval(model.model, output_path)

            # Get the most recent result file (earlier runs may share this directory)
            result_file_path = max(
                glob.glob(f"{output_path}/*.json"), key=os.path.getmtime
            )

            with open(result_file_path, "r") as f:
                result = json.load(f)
                print(f"Result{test_name}\n: {result}")

            # Process the result
            mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
            print(
                f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}"
            )

            # Capture server output if requested
            if capture_output and process:
                server_output = self._read_output_from_files()

            # Assert performance meets expected threshold
            self.assertGreaterEqual(
                mmmu_accuracy,
                model.mmmu_accuracy,
                f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}",
            )

            return server_output

        except Exception as e:
            print(f"Error testing {model.model}{test_name}: {e}")
            self.fail(f"Test failed for {model.model}{test_name}: {e}")

        finally:
            # Ensure process cleanup happens regardless of success/failure
            if process is not None and process.poll() is None:
                print(f"Cleaning up process {process.pid}")
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    print(f"Error killing process: {e}")

            # clean up temporary files
            if capture_output:
                if stdout_file:
                    stdout_file.close()
                if stderr_file:
                    stderr_file.close()
                for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]:
                    try:
                        if os.path.exists(filename):
                            os.remove(filename)
                    except Exception as e:
                        print(f"Error removing {filename}: {e}")

    def _read_output_from_files(self):
        output_lines = []

        log_files = [
            ("/tmp/server_stdout.log", "[STDOUT]"),
            ("/tmp/server_stderr.log", "[STDERR]"),
        ]
        for filename, tag in log_files:
            try:
                if os.path.exists(filename):
                    with open(filename, "r") as f:
                        for line in f:
                            output_lines.append(f"{tag} {line.rstrip()}")
            except Exception as e:
                print(f"Error reading {tag.lower()} file: {e}")

        return "\n".join(output_lines)

    def test_vlm_mmmu_benchmark(self):
        """Test VLM models against MMMU benchmark."""
        models_to_test = MODELS

        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model in models_to_test:
            self._run_vlm_mmmu_test(model, "./logs")

    def test_vlm_mmmu_benchmark_with_small_cache(self):
        """Test VLM models against MMMU benchmark with a small embedding cache to force eviction."""
        models_to_test = MODELS

        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model in models_to_test:
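            # Shrink the multimodal embedding cache to 5 MB so the MMMU run
            # overflows it and forces evictions.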
            custom_env = {"SGLANG_VLM_CACHE_SIZE_MB": "5"}

            # Run the test with output capture
            server_output = self._run_vlm_mmmu_test(
                model,
                "./logs_small_cache",
                test_name=" with small embedding cache (evict test)",
                custom_env=custom_env,
                log_level="debug",  # Enable debug logging for eviction detection
                capture_output=True,  # Capture server output
            )

            # Print server output for debugging
            print("Server output:\n", server_output)

            # Analyze server output for eviction events
            eviction_detected, eviction_count = self._detect_eviction_in_logs(
                server_output
            )

            # Assert that eviction was detected (since we're using small cache)
            self.assertTrue(
                eviction_detected,
                f"Expected eviction events to be detected with small cache (5MB), but none found. "
                f"Cache size may be too large for the workload or eviction logic may not be working. "
                f"Total log content length: {len(server_output)} characters",
            )

            print(
                f"Eviction detection summary: {eviction_count} eviction events detected"
            )

            # eviction_detected is guaranteed True by the assertion above.
            print("✅ Eviction logic successfully triggered and detected!")


if __name__ == "__main__":
    # Define and parse arguments here, before unittest.main
    parser = argparse.ArgumentParser(description="Test VLM models")
    parser.add_argument(
        "--mem-fraction-static",
        type=float,
        help="Static memory fraction for the model",
        default=0.8,
    )
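    # Example: python3 test_vlm_models.py --mem-fraction-static 0.7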

    # Parse this script's own arguments (unittest receives a stripped argv below)
    args = parser.parse_args()

    # Store the parsed args object on the class
    TestVLMModels.parsed_args = args

    # Run unittest with only the program name so it ignores the custom arguments
    unittest.main(argv=[sys.argv[0]])