import argparse
import glob
import json
import os
import random
import subprocess
import sys
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)

# VLM models for testing
MODELS = [
    SimpleNamespace(model="google/gemma-3-27b-it", mmmu_accuracy=0.45),
    SimpleNamespace(
        model="Qwen/Qwen2.5-VL-3B-Instruct",
        mmmu_accuracy=0.4,
    ),
    SimpleNamespace(model="openbmb/MiniCPM-V-2_6", mmmu_accuracy=0.4),
]
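# Each mmmu_accuracy above is the minimum MMMU validation accuracy the model is
# expected to reach; test_vlm_mmmu_benchmark asserts the measured score against it.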


class TestVLMModels(CustomTestCase):
    parsed_args = None  # Class variable to store args

    @classmethod
    def setUpClass(cls):
        # Argument parsing happens in __main__; parsed_args is set on the class there.
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH

        # Set OpenAI API key and base URL environment variables. Needed for lmms-eval to work.
        os.environ["OPENAI_API_KEY"] = cls.api_key
        os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"

    def run_mmmu_eval(
        self,
        model_version: str,
        output_path: str,
        *,
        env: dict | None = None,
    ):
        """
        Evaluate a VLM on the MMMU validation set with lmms‑eval.
        Only `model_version` (checkpoint) and `chat_template` vary;
        We are focusing only on the validation set due to resource constraints.
        """
        # -------- fixed settings --------
        model = "openai_compatible"
        tp = 1
        tasks = "mmmu_val"
        batch_size = 2
        log_suffix = "openai_compatible"
        os.makedirs(output_path, exist_ok=True)

        # -------- compose --model_args --------
        model_args = f'model_version="{model_version}",tp={tp}'

        # -------- build command list --------
        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            model,
            "--model_args",
            model_args,
            "--tasks",
            tasks,
            "--batch_size",
            str(batch_size),
            "--log_samples",
            "--log_samples_suffix",
            log_suffix,
            "--output_path",
            str(output_path),
        ]
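        # For reference, the command built above is roughly equivalent to running:
        #   python3 -m lmms_eval --model openai_compatible \
        #     --model_args 'model_version="<checkpoint>",tp=1' \
        #     --tasks mmmu_val --batch_size 2 --log_samples \
        #     --log_samples_suffix openai_compatible --output_path <output_path>
        # where <checkpoint> and <output_path> are the arguments of this method.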

        subprocess.run(
            cmd,
            check=True,
            timeout=3600,
            env=env,  # forward the optional environment override to lmms-eval
        )

    def test_vlm_mmmu_benchmark(self):
        """Test VLM models against MMMU benchmark."""
        models_to_test = MODELS

        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model in models_to_test:
            print(f"\nTesting model: {model.model}")

            process = None
            mmmu_accuracy = 0  # Initialize to handle potential exceptions

            try:
                # Launch server for testing
                process = popen_launch_server(
                    model.model,
                    base_url=self.base_url,
                    timeout=self.time_out,
                    api_key=self.api_key,
                    other_args=[
                        "--trust-remote-code",
                        "--cuda-graph-max-bs",
                        "32",
                        "--enable-multimodal",
                        "--mem-fraction-static",
                        str(self.parsed_args.mem_fraction_static),  # Use class variable
                    ],
                )

                # Run evaluation
                self.run_mmmu_eval(model.model, "./logs")

                # Get the result file
                # Pick the most recently written result file so that sequential
                # model runs do not read a stale result.
                result_file_path = max(
                    glob.glob("./logs/*.json"), key=os.path.getmtime
                )

                with open(result_file_path, "r") as f:
                    result = json.load(f)
                    print(f"Result \n: {result}")
                # Process the result
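                # The lmms-eval results JSON is assumed to look roughly like
                # {"results": {"mmmu_val": {"mmmu_acc,none": <float>, ...}, ...}};
                # the lookup below relies on those keys.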
                mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
                print(f"Model {model.model} achieved accuracy: {mmmu_accuracy:.4f}")

                # Assert performance meets expected threshold
                self.assertGreaterEqual(
                    mmmu_accuracy,
                    model.mmmu_accuracy,
                    f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f})",
                )

            except Exception as e:
                print(f"Error testing {model.model}: {e}")
                self.fail(f"Test failed for {model.model}: {e}")

            finally:
                # Ensure process cleanup happens regardless of success/failure
                if process is not None and process.poll() is None:
                    print(f"Cleaning up process {process.pid}")
                    try:
                        kill_process_tree(process.pid)
                    except Exception as e:
                        print(f"Error killing process: {e}")


if __name__ == "__main__":
    # Define and parse arguments here, before unittest.main
    parser = argparse.ArgumentParser(description="Test VLM models")
    parser.add_argument(
        "--mem-fraction-static",
        type=float,
        help="Static memory fraction for the model",
        default=0.8,
    )

    # Parse the custom command-line arguments
    args = parser.parse_args()

    # Store the parsed args object on the class
    TestVLMModels.parsed_args = args

    # Run unittest without the custom args (they are not valid unittest options)
    unittest.main(argv=[sys.argv[0]])