"src/diffusers/pipelines/flux/pipeline_flux_img2img.py" did not exist on "9b8c8605d14b4543db8e0169d4aac97ad95a2a27"
test_vlm_models.py 10.9 KB
Newer Older
import argparse
import glob
import json
import os
import random
import subprocess
import sys
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)

# VLM models for testing
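# Each entry records the minimum MMMU-val accuracy the model is expected to reach.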
MODELS = [
    SimpleNamespace(model="google/gemma-3-27b-it", mmmu_accuracy=0.45),
    SimpleNamespace(
        model="Qwen/Qwen2.5-VL-3B-Instruct",
        mmmu_accuracy=0.4,
    ),
    SimpleNamespace(model="openbmb/MiniCPM-V-2_6", mmmu_accuracy=0.4),
]

# Set default mem_fraction_static to 0.8
DEFAULT_MEM_FRACTION_STATIC = 0.8


class TestVLMModels(CustomTestCase):
    parsed_args = None  # Class variable to store args

    @classmethod
    def setUpClass(cls):
        # Arguments are parsed in __main__ (when run directly) and stored on the class.
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH

        if cls.parsed_args is None:
            cls.parsed_args = SimpleNamespace(
                mem_fraction_static=DEFAULT_MEM_FRACTION_STATIC
            )

        # Set OpenAI API key and base URL environment variables. Needed for lmms-eval to work.
        os.environ["OPENAI_API_KEY"] = cls.api_key
        os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"

    def _detect_eviction_in_logs(self, log_output):
        """Detect if eviction events occurred in the log output."""
        eviction_keywords = ["Cache eviction: evicted"]
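        # The server is expected to log a line containing this prefix when the
        # multimodal embedding cache evicts entries; only the prefix is matched,
        # so the rest of the line (counts, sizes) can vary.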

        eviction_detected = False
        eviction_count = 0

        for line in log_output.split("\n"):
            if any(keyword in line for keyword in eviction_keywords):
                eviction_detected = True
                eviction_count += 1
                print(f"Eviction detected: {line.strip()}")

        return eviction_detected, eviction_count

    def run_mmmu_eval(
        self,
        model_version: str,
        output_path: str,
        *,
        env: dict | None = None,
    ):
        """
        Evaluate a VLM on the MMMU validation set with lmms‑eval.
        Only `model_version` (checkpoint) and `chat_template` vary;
        We are focusing only on the validation set due to resource constraints.
        """
        # -------- fixed settings --------
        model = "openai_compatible"
        tp = 1
        tasks = "mmmu_val"
        batch_size = 32
        log_suffix = "openai_compatible"
        os.makedirs(output_path, exist_ok=True)

        # -------- compose --model_args --------
        model_args = f'model_version="{model_version}",' f"tp={tp}"
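        # Resulting string, e.g.: model_version="Qwen/Qwen2.5-VL-3B-Instruct",tp=1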

        # -------- build command list --------
        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            model,
            "--model_args",
            model_args,
            "--tasks",
            tasks,
            "--batch_size",
            str(batch_size),
            "--log_samples",
            "--log_samples_suffix",
            log_suffix,
            "--output_path",
            str(output_path),
        ]
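        # Roughly equivalent to running (checkpoint and paths illustrative):
        #   python3 -m lmms_eval --model openai_compatible \
        #       --model_args 'model_version="<checkpoint>",tp=1' \
        #       --tasks mmmu_val --batch_size 32 --log_samples \
        #       --log_samples_suffix openai_compatible --output_path <output_path>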

        subprocess.run(
            cmd,
            check=True,
            timeout=3600,
            env=env,  # None inherits the parent environment
        )

    def _run_vlm_mmmu_test(
        self,
        model,
        output_path,
        test_name="",
        custom_env=None,
        log_level="info",
        capture_output=False,
    ):
        """
        Common method to run VLM MMMU benchmark test.

        Args:
            model: Model to test
            output_path: Path for output logs
            test_name: Optional test name for logging
            custom_env: Optional custom environment variables
            log_level: Log level for server (default: "info")
            capture_output: Whether to capture server stdout/stderr
        """
        print(f"\nTesting model: {model.model}{test_name}")

        process = None
        mmmu_accuracy = 0  # Initialize to handle potential exceptions
        server_output = ""

        try:
            # Prepare environment variables
            process_env = os.environ.copy()
            if custom_env:
                process_env.update(custom_env)

            # Prepare stdout/stderr redirection if needed
            stdout_file = None
            stderr_file = None
            if capture_output:
                stdout_file = open("/tmp/server_stdout.log", "w")
                stderr_file = open("/tmp/server_stderr.log", "w")

            # Launch server for testing
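            # The server is started with multimodal support enabled; the static
            # memory fraction comes from --mem-fraction-static (default 0.8).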
            process = popen_launch_server(
                model.model,
                base_url=self.base_url,
                timeout=self.time_out,
                api_key=self.api_key,
                other_args=[
                    "--trust-remote-code",
                    "--cuda-graph-max-bs",
                    "32",
                    "--enable-multimodal",
                    "--mem-fraction-static",
                    str(self.parsed_args.mem_fraction_static),  # Use class variable
                    "--log-level",
                    log_level,
                ],
                env=process_env,
                return_stdout_stderr=(
                    (stdout_file, stderr_file) if capture_output else None
                ),
            )

            # Run evaluation
            self.run_mmmu_eval(model.model, output_path)

            # Get the result file
            # Search recursively for JSON result files (lmms-eval v0.4.1+ creates subdirectories)
            result_files = glob.glob(f"{output_path}/**/*.json", recursive=True)
            if not result_files:
                result_files = glob.glob(f"{output_path}/*.json")

            if not result_files:
                raise FileNotFoundError(f"No JSON result files found in {output_path}")

            result_file_path = result_files[0]

            with open(result_file_path, "r") as f:
                result = json.load(f)
                print(f"Result{test_name}\n: {result}")

            # Process the result
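            # The aggregate score is read from the lmms-eval results JSON, which is
            # expected to look roughly like:
            #   {"results": {"mmmu_val": {"mmmu_acc,none": 0.43, ...}, ...}}
            # (the accuracy value is illustrative).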
            mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
            print(
                f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}"
            )

            # Capture server output if requested
            if capture_output and process:
                server_output = self._read_output_from_files()

            # Assert performance meets expected threshold
            self.assertGreaterEqual(
                mmmu_accuracy,
                model.mmmu_accuracy,
                f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}",
            )

            return server_output

        except Exception as e:
            print(f"Error testing {model.model}{test_name}: {e}")
            self.fail(f"Test failed for {model.model}{test_name}: {e}")

        finally:
            # Ensure process cleanup happens regardless of success/failure
            if process is not None and process.poll() is None:
                print(f"Cleaning up process {process.pid}")
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    print(f"Error killing process: {e}")

            # clean up temporary files
            if capture_output:
                if stdout_file:
                    stdout_file.close()
                if stderr_file:
                    stderr_file.close()
                for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]:
                    try:
                        if os.path.exists(filename):
                            os.remove(filename)
                    except Exception as e:
                        print(f"Error removing {filename}: {e}")

    def _read_output_from_files(self):
        output_lines = []

        log_files = [
            ("/tmp/server_stdout.log", "[STDOUT]"),
            ("/tmp/server_stderr.log", "[STDERR]"),
        ]
        for filename, tag in log_files:
            try:
                if os.path.exists(filename):
                    with open(filename, "r") as f:
                        for line in f:
                            output_lines.append(f"{tag} {line.rstrip()}")
            except Exception as e:
                print(f"Error reading {tag.lower()} file: {e}")

        return "\n".join(output_lines)

    def test_vlm_mmmu_benchmark(self):
        """Test VLM models against MMMU benchmark."""
        models_to_test = MODELS

        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model in models_to_test:
            self._run_vlm_mmmu_test(model, "./logs")

    def test_vlm_mmmu_benchmark_with_small_cache(self):
        """Test VLM models against MMMU benchmark with a small embedding cache to force eviction."""
        models_to_test = MODELS

        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model in models_to_test:
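            # Shrink the multimodal embedding cache to 5 MB so that MMMU image
            # embeddings overflow it and the eviction path is exercised.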
            custom_env = {"SGLANG_VLM_CACHE_SIZE_MB": "5"}

            # Run the test with output capture
            server_output = self._run_vlm_mmmu_test(
                model,
                "./logs_small_cache",
                test_name=" with small embedding cache (evict test)",
                custom_env=custom_env,
                log_level="debug",  # Enable debug logging for eviction detection
                capture_output=True,  # Capture server output
            )

            # Print server output for debugging
            print("Server output:\n", server_output)

            # Analyze server output for eviction events
            eviction_detected, eviction_count = self._detect_eviction_in_logs(
                server_output
            )

            # Assert that eviction was detected (since we're using small cache)
            self.assertTrue(
                eviction_detected,
                f"Expected eviction events to be detected with small cache (5MB), but none found. "
                f"Cache size may be too large for the workload or eviction logic may not be working. "
                f"Total log content length: {len(server_output)} characters",
            )

            print(
                f"Eviction detection summary: {eviction_count} eviction events detected"
            )

            # If we reach this point, the assertion above passed, i.e. eviction was detected.
            if eviction_detected:
                print("✅ Eviction logic successfully triggered and detected!")


if __name__ == "__main__":
    # Define and parse arguments here, before unittest.main
    parser = argparse.ArgumentParser(description="Test VLM models")
    parser.add_argument(
        "--mem-fraction-static",
        type=float,
        help="Static memory fraction for the model",
        default=DEFAULT_MEM_FRACTION_STATIC,
    )
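    # Example invocation (value illustrative):
    #   python3 test_vlm_models.py --mem-fraction-static 0.6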

    # Parse the custom arguments before handing control to unittest
    args = parser.parse_args()

    # Store the parsed args object on the class
    TestVLMModels.parsed_args = args

    # Run unittest with only the program name so it does not see the custom arguments
    unittest.main(argv=[sys.argv[0]])