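"""
Accuracy tests for vision-language models in SGLang.

The tests compare SGLang's multimodal outputs against the Hugging Face
reference implementations and check end-to-end image understanding.
"""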

import json
import unittest
from io import BytesIO
from typing import List, Optional

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    Gemma3ForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)

from sglang import Engine
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_mm_inputs
from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs


# Test the logits output between HF and SGLang
class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
    """Shared fixtures and helpers for comparing HF and SGLang VLM outputs.

    ``model_path``, ``chat_template`` and ``processor`` are initialised to
    empty-string placeholders here; the subclasses in this file overwrite
    them in their own ``setUpClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Select the torch device and download the example image once per class.

        Raises:
            requests.RequestException: if the image download times out or
                returns an HTTP error status.
        """
        cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cls.model_path = ""
        cls.chat_template = ""
        cls.processor = ""
        # A timeout keeps a stalled download from hanging the test run, and
        # raise_for_status() surfaces HTTP errors here instead of handing an
        # error page to PIL.
        response = requests.get(cls.image_url, timeout=60)
        response.raise_for_status()
        cls.main_image = Image.open(BytesIO(response.content))

    def compare_outputs(
        self,
        sglang_output: torch.Tensor,
        hf_output: torch.Tensor,
        rtol: float = 1e-7,
        atol: float = 0.0,
        top_k: int = 10,
    ):
        """Print a diagnostic report and assert that the two tensors are close.

        Args:
            sglang_output: Tensor produced by SGLang.
            hf_output: Reference tensor produced by the Hugging Face model.
            rtol: Relative tolerance forwarded to ``np.testing.assert_allclose``.
            atol: Absolute tolerance forwarded to ``np.testing.assert_allclose``.
            top_k: Maximum number of largest differences to list; it is
                clamped to the number of elements in the tensors.

        Raises:
            AssertionError: if the shapes differ or the values are not equal
                within the given tolerances.
        """
        # Compare in float32 on the CPU so the report works regardless of the
        # dtype or device each side was produced with.
        hf = hf_output.detach().float().cpu()
        sg = sglang_output.detach().float().cpu()

        # Basic shape and dtype comparison
        print("\n=== Basic Properties ===")
        print(f"Shapes match: {hf.shape == sg.shape}")
        print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
        # Report the dtypes of the inputs; after .float() both are float32.
        print(f"HF dtype: {hf_output.dtype}, SGLang dtype: {sglang_output.dtype}")

        if hf.shape != sg.shape:
            # Fail before any element-wise arithmetic, which would otherwise
            # raise an unrelated error or silently broadcast.
            raise AssertionError(
                f"Shape mismatch: HF {tuple(hf.shape)} vs SGLang {tuple(sg.shape)}"
            )

        hf_np = hf.numpy()
        sg_np = sg.numpy()

        diffs = torch.abs(hf - sg)
        if diffs.numel() == 0:
            # Reductions such as max() are undefined on empty tensors.
            print("\nBoth outputs are empty; skipping element-wise statistics.")
            np.testing.assert_allclose(hf_np, sg_np, rtol=rtol, atol=atol)
            return

        # Statistical metrics
        mse = torch.mean(diffs**2)
        print("\n=== Statistical Metrics ===")
        print(f"Mean absolute difference: {torch.mean(diffs).item():.6f}")
        print(f"Max absolute difference: {torch.max(diffs).item():.6f}")
        print(f"Mean squared error: {mse.item():.6f}")
        print(f"Root mean squared error: {torch.sqrt(mse).item():.6f}")

        # Cosine similarity across the last (feature) dimension
        if hf.dim() > 0:
            cos_sim = F.cosine_similarity(hf, sg, dim=-1)
            print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
            print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")

        # Find largest absolute differences
        print("\n=== Largest Absolute Differences ===")
        # torch.topk raises when k exceeds the number of elements.
        k = max(0, min(top_k, diffs.numel()))
        top_values, top_flat_indices = torch.topk(diffs.flatten(), k)

        # Convert flat indices to multidimensional indices
        top_indices = np.unravel_index(top_flat_indices.numpy(), diffs.shape)

        print(f"\nTop {k} largest absolute differences:")
        print(
            "Index".ljust(30)
            + "Difference".ljust(15)
            + "HF Value".ljust(15)
            + "SGLang Value"
        )
        print("-" * 75)

        for rank in range(k):
            # Plain ints keep the printed index compact whatever repr NumPy
            # uses for its scalar types.
            idx = tuple(int(dim[rank]) for dim in top_indices)
            diff_val = top_values[rank].item()
            hf_val = hf[idx].item()
            sg_val = sg[idx].item()

            idx_str = str(idx)
            print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")

        np.testing.assert_allclose(hf_np, sg_np, rtol=rtol, atol=atol)

    def get_completion_request(self) -> ChatCompletionRequest:
        """Build a chat request with the example image and a text question.

        Returns:
            A ``ChatCompletionRequest`` holding one user message made of an
            ``image_url`` part followed by a ``text`` part.
        """
        # Serialise with json.dumps rather than string interpolation so that
        # quotes or backslashes in the model path or URL cannot corrupt the
        # JSON document.
        payload = {
            "model": self.model_path,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": self.image_url},
                        },
                        {
                            "type": "text",
                            "text": "What's in this picture?",
                        },
                    ],
                }
            ],
        }
        return ChatCompletionRequest.model_validate_json(json.dumps(payload))

    def get_processor_output(self, req: Optional[ChatCompletionRequest] = None):
        """Run ``self.processor`` on the rendered prompt and the example image.

        Args:
            req: Request to render; when ``None`` the default request from
                ``get_completion_request`` is used.

        Returns:
            The processor output moved to ``self.device``.
        """
        if req is None:
            req = self.get_completion_request()

        conv = generate_chat_conv(req, template_name=self.chat_template)
        text = conv.get_prompt()

        # Process inputs using processor
        # FIXME: the formal arguments may differ
        inputs = self.processor(
            text=[text],
            images=[self.main_image],
            return_tensors="pt",
        ).to(self.device)

        return inputs

    def get_sglang_model(self):
        """Create a ``ModelRunner`` for ``self.model_path`` and return its model.

        The runner is configured with ``tp_size=1``, ``pp_size=1``,
        ``gpu_id=0`` and CUDA graphs disabled, and is kept on
        ``self.model_runner``.
        """
        self.model_runner = ModelRunner(
            model_config=ModelConfig(self.model_path, model_override_args="{}"),
            mem_fraction_static=0.8,
            gpu_id=0,
            tp_rank=0,
            tp_size=1,
            pp_rank=0,
            pp_size=1,
            nccl_port=12435,
            server_args=ServerArgs(
                model_path=self.model_path,
                disable_cuda_graph=True,
            ),
        )
        return self.model_runner.model


class TestMiniCPMVLogits(VisionLLMLogitsBase):
    """Compares MiniCPM-V multimodal embeddings from SGLang with the HF model."""

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer, processor and bfloat16 HF model once per class."""
        super().setUpClass()
        checkpoint = "openbmb/MiniCPM-V-2_6"
        cls.model_path = checkpoint
        cls.tokenizer = AutoTokenizer.from_pretrained(
            checkpoint, trust_remote_code=True
        )
        cls.processor = AutoProcessor.from_pretrained(
            checkpoint, trust_remote_code=True
        )
        cls.chat_template = "minicpmv"

        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        reference_model = AutoModel.from_pretrained(
            checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True
        )
        cls.hf_model = reference_model.eval().to(cls.device)

    async def test_vlm_embedding_output(self):
        """
        Embed the same processor output with HF and with SGLang, then check
        that the two embedding tensors agree.
        """
        inputs = self.get_processor_output()

        with torch.no_grad():
            # HF reference embedding
            hf_kwargs = dict(
                input_ids=inputs.input_ids,
                image_bound=inputs.image_bound,
                pixel_values=inputs.pixel_values,
                tgt_sizes=inputs.tgt_sizes,
            )
            hf_output, _ = self.hf_model.get_vllm_embedding(hf_kwargs)
            hf_output = hf_output.squeeze(0)

            # SGLang embedding
            model = self.get_sglang_model()
            input_ids = inputs["input_ids"].to(self.device).flatten()

            # Flatten the nested pixel values / target sizes into flat lists
            # with one entry per inner element.
            pixel_values_flat: List[torch.Tensor] = []
            tgt_sizes_flat: List[torch.Tensor] = []
            for pixels, sizes in zip(inputs["pixel_values"], inputs["tgt_sizes"]):
                if len(pixels) != len(sizes):
                    raise ValueError(
                        "Inconsistent N lengths, found: "
                        f"{len(pixels)} vs {len(sizes)}"
                    )
                pixel_values_flat.extend(pixels)
                tgt_sizes_flat.extend(sizes)

            # The tokenizer's unk token id is used both as the item's pad
            # value and as the image placeholder token.
            unk_id = self.processor.tokenizer.unk_token_id
            image_item = MultimodalDataItem(
                pixel_values=pixel_values_flat,
                tgt_size=tgt_sizes_flat,
                modality=Modality.IMAGE,
                pad_value=unk_id,
            )
            sglang_output = embed_mm_inputs(
                mm_inputs=MultimodalInputs(mm_items=[image_item]),
                input_ids=input_ids,
                input_embedding=model.get_input_embeddings(),
                image_data_embedding_func=model.get_image_feature,
                placeholder_tokens={Modality.IMAGE: unk_id},
            )

        self.compare_outputs(sglang_output, hf_output)


class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
    """End-to-end checks that Qwen2.5-VL served by SGLang describes the image."""

    @classmethod
    def setUpClass(cls):
        """Load the Qwen2.5-VL processor and the HF vision module once per class."""
        super().setUpClass()
        cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
        cls.chat_template = "qwen2-vl"
        cls.processor = AutoProcessor.from_pretrained(
            cls.model_path, trust_remote_code=True, use_fast=True
        )
        hf_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            cls.model_path, torch_dtype=torch.bfloat16
        )
        # Only the vision module is kept and moved to the test device.
        cls.visual = hf_model.eval().visual.to(cls.device)

    def setUp(self):
        """Start a fresh SGLang engine for each test."""
        engine_kwargs = {
            "model_path": self.model_path,
            "chat_template": self.chat_template,
            "device": self.device.type,
            "mem_fraction_static": 0.8,
        }
        self.engine = Engine(**engine_kwargs)

    def tearDown(self):
        """Shut the engine down after each test."""
        self.engine.shutdown()

    async def test_qwen_vl_understands_image(self):
        """Send the prompt plus the raw image and expect "taxi" in the reply."""
        prompt = generate_chat_conv(
            self.get_completion_request(), template_name=self.chat_template
        ).get_prompt()
        output = await self.engine.async_generate(
            prompt=prompt,
            image_data=[self.main_image],
            sampling_params={"temperature": 0.0},
        )
        self.assertIn("taxi", output["text"].lower())

    async def test_qwen_vl_understands_precomputed_features(self):
        """Send HF-computed vision features and expect "taxi" in the reply."""
        processed = self.get_processor_output(req=self.get_completion_request())
        grid_thw = processed["image_grid_thw"]
        with torch.inference_mode():
            features = self.visual(processed["pixel_values"], grid_thw)
        token_ids = processed["input_ids"][0].detach().cpu().tolist()
        image_item = {
            "modality": "IMAGE",
            "image_grid_thws": grid_thw,
            "precomputed_features": features,
        }
        output = await self.engine.async_generate(
            input_ids=token_ids,
            image_data=[image_item],
            sampling_params={"temperature": 0.0},
        )
        self.assertIn("taxi", output["text"].lower())


class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
    """End-to-end checks that Gemma 3 served by SGLang describes the image."""

    @classmethod
    def setUpClass(cls):
        """Load the Gemma 3 processor, vision tower and projector once per class."""
        super().setUpClass()
        cls.model_path = "google/gemma-3-4b-it"
        cls.chat_template = "gemma-it"
        cls.processor = AutoProcessor.from_pretrained(
            cls.model_path, trust_remote_code=True, use_fast=True
        )
        hf_model = Gemma3ForConditionalGeneration.from_pretrained(
            cls.model_path, torch_dtype=torch.bfloat16
        )
        # Keep only the two vision-side modules, in eval mode on the test device.
        cls.vision_tower = hf_model.vision_tower.eval().to(cls.device)
        cls.mm_projector = hf_model.multi_modal_projector.eval().to(cls.device)

    @classmethod
    def visual(cls, pixel_values):
        """Run the vision tower, then project its last hidden state."""
        hidden_state = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
        return cls.mm_projector(hidden_state)

    def setUp(self):
        """Start a fresh multimodal SGLang engine for each test."""
        engine_kwargs = {
            "model_path": self.model_path,
            "chat_template": self.chat_template,
            "device": self.device.type,
            "mem_fraction_static": 0.5,
            "enable_multimodal": True,
        }
        self.engine = Engine(**engine_kwargs)

    def tearDown(self):
        """Shut the engine down after each test."""
        self.engine.shutdown()

    async def test_gemma_understands_image(self):
        """Send the prompt plus the raw image and expect "taxi" in the reply."""
        prompt = generate_chat_conv(
            self.get_completion_request(), template_name=self.chat_template
        ).get_prompt()
        output = await self.engine.async_generate(
            prompt=prompt,
            image_data=[self.main_image],
            sampling_params={"temperature": 0.0},
        )
        self.assertIn("taxi", output["text"].lower())

    async def test_gemma_understands_precomputed_features(self):
        """Send HF-computed vision features and expect "taxi" in the reply."""
        processed = self.get_processor_output(req=self.get_completion_request())
        with torch.inference_mode():
            features = self.visual(processed["pixel_values"])
        token_ids = processed["input_ids"][0].detach().cpu().tolist()
        image_item = {
            "modality": "IMAGE",
            "precomputed_features": features,
        }
        output = await self.engine.async_generate(
            input_ids=token_ids,
            image_data=[image_item],
            sampling_params={"temperature": 0.0},
        )
        self.assertIn("taxi", output["text"].lower())


# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()