"""
Accuracy tests comparing the multimodal (vision-language) embedding outputs of
HuggingFace reference models against SGLang.
"""

import unittest
from io import BytesIO
6
from typing import List, Optional
7
8
9
10
11
12

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
13
from transformers import AutoModel, AutoProcessor, AutoTokenizer
14
15
16

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
17
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
18
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
Mick's avatar
Mick committed
19
20
21
22
23
from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
)
24
from sglang.srt.model_executor.model_runner import ModelRunner
25
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
26
27
28
29
30
31
32
33
34
from sglang.srt.server_args import ServerArgs


# Test the logits output between HF and SGLang
class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
    """Shared fixtures and helpers for HF-vs-SGLang VLM accuracy tests.

    Concrete subclasses call ``super().setUpClass()`` and then set
    ``model_path``, ``chat_template`` and ``processor`` (plus whatever HF
    model / tokenizer objects their tests use). This base class defines no
    ``test_*`` methods of its own.
    """

    @classmethod
    def setUpClass(cls):
        """Download the shared example image and initialise placeholder attributes.

        Raises:
            requests.HTTPError: if the image download returns an error status.
        """
        cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Placeholders; concrete subclasses override these after calling
        # super().setUpClass().
        cls.model_path = ""
        cls.chat_template = ""
        cls.processor = ""

        # Fail fast with an HTTP error (instead of an opaque PIL decode error on
        # an error page) and never hang forever on a stalled connection.
        response = requests.get(cls.image_url, timeout=60)
        response.raise_for_status()
        cls.main_image = Image.open(BytesIO(response.content))

    def compare_outputs(
        self,
        sglang_output: torch.Tensor,
        hf_output: torch.Tensor,
        *,
        rtol: float = 1e-7,
        atol: float = 0.0,
    ):
        """Print diagnostics comparing two tensors, then assert they are close.

        Args:
            sglang_output: Tensor produced by SGLang.
            hf_output: Reference tensor produced by HuggingFace; must have the
                same shape as ``sglang_output``.
            rtol: Relative tolerance forwarded to ``np.testing.assert_allclose``
                (the default equals numpy's default).
            atol: Absolute tolerance forwarded to ``np.testing.assert_allclose``
                (the default equals numpy's default).

        Raises:
            AssertionError: if the shapes differ or the values are not close
                within ``rtol`` / ``atol``.
        """
        # Upcast so that bf16/fp16 inputs are compared and printed in float32.
        hf = hf_output.float()
        sg = sglang_output.float()

        print("\n=== Basic Properties ===")
        print(f"Shapes match: {hf.shape == sg.shape}")
        print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
        print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")
        # The element-wise metrics below would silently broadcast (or fail with
        # an unrelated error) on mismatched shapes, so stop here instead.
        if hf.shape != sg.shape:
            raise AssertionError(
                f"Shape mismatch: HF {tuple(hf.shape)} vs SGLang {tuple(sg.shape)}"
            )

        diff = hf - sg
        abs_diff = diff.abs()
        sq_diff = diff.pow(2)

        print("\n=== Statistical Metrics ===")
        print(f"Mean absolute difference: {abs_diff.mean().item():.6f}")
        print(f"Max absolute difference: {abs_diff.max().item():.6f}")
        print(f"Mean squared error: {sq_diff.mean().item():.6f}")
        print(f"Root mean squared error: {sq_diff.mean().sqrt().item():.6f}")

        # Cosine similarity along the last (feature) dimension.
        cos_sim = F.cosine_similarity(hf, sg, dim=-1)
        print(f"Mean cosine similarity: {cos_sim.mean().item():.6f}")
        print(f"Min cosine similarity: {cos_sim.min().item():.6f}")

        print("\n=== Largest Absolute Differences ===")
        flat_diffs = abs_diff.flatten()
        # torch.topk raises if k exceeds the number of elements.
        top_k = min(10, flat_diffs.numel())
        top_values, top_flat_indices = torch.topk(flat_diffs, top_k)
        # Map flat positions back to per-dimension coordinates.
        top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), abs_diff.shape)

        print(f"\nTop {top_k} largest absolute differences:")
        print(
            "Index".ljust(30)
            + "Difference".ljust(15)
            + "HF Value".ljust(15)
            + "SGLang Value"
        )
        print("-" * 75)

        for i, diff_val in enumerate(top_values.tolist()):
            idx = tuple(int(dim[i]) for dim in top_indices)
            hf_val = hf[idx].item()
            sg_val = sg[idx].item()
            print(f"{str(idx):<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")

        np.testing.assert_allclose(
            hf.cpu().numpy(), sg.cpu().numpy(), rtol=rtol, atol=atol
        )

    def get_completion_request(self) -> "ChatCompletionRequest":
        """Build a single-turn chat request about ``self.image_url``.

        Returns:
            ChatCompletionRequest: a request for ``self.model_path`` whose user
            message holds the example image followed by a text question.
        """
        import json

        # Serialise a Python dict instead of filling an f-string JSON template,
        # so values containing quotes or backslashes cannot yield invalid JSON.
        payload = {
            "model": self.model_path,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": self.image_url}},
                        {"type": "text", "text": "What's in this picture?"},
                    ],
                }
            ],
        }
        return ChatCompletionRequest.model_validate_json(json.dumps(payload))

    def get_processor_output(self, req: Optional["ChatCompletionRequest"] = None):
        """Render a chat prompt and run the HF processor on it with the example image.

        Args:
            req: Request to render. Defaults to ``self.get_completion_request()``.

        Returns:
            The output of ``self.processor`` (with ``return_tensors="pt"``),
            moved to ``self.device``.
        """
        if req is None:
            req = self.get_completion_request()

        conv = generate_chat_conv(req, template_name=self.chat_template)
        text = conv.get_prompt()

        # FIXME: the formal arguments may differ
        inputs = self.processor(
            text=[text],
            images=[self.main_image],
            return_tensors="pt",
        ).to(self.device)

        return inputs

    def get_sglang_model(self):
        """Create a single-GPU SGLang ``ModelRunner`` for ``self.model_path``.

        The runner is stored as ``self.model_runner`` (CUDA graphs disabled,
        tp/pp size 1), and its ``model`` attribute is returned.
        """
        self.model_runner = ModelRunner(
            model_config=ModelConfig(self.model_path, model_override_args="{}"),
            mem_fraction_static=0.8,
            gpu_id=0,
            tp_rank=0,
            tp_size=1,
            pp_rank=0,
            pp_size=1,
            nccl_port=12435,
            server_args=ServerArgs(
                model_path=self.model_path,
                disable_cuda_graph=True,
            ),
        )
        return self.model_runner.model


class TestMiniCPMV2_6Logits(VisionLLMLogitsBase):
    """Compare MiniCPM-V-2.6 multimodal input embeddings between HF and SGLang."""

    @classmethod
    def setUpClass(cls):
        """Load the MiniCPM-V-2.6 tokenizer, processor and bf16 HF reference model."""
        super().setUpClass()
        cls.model_path = "openbmb/MiniCPM-V-2_6"
        cls.tokenizer = AutoTokenizer.from_pretrained(
            cls.model_path, trust_remote_code=True
        )
        cls.processor = AutoProcessor.from_pretrained(
            cls.model_path, trust_remote_code=True
        )
        cls.chat_template = "minicpmv"

        # cls.device was already set by the base-class setUpClass.
        cls.hf_model = (
            AutoModel.from_pretrained(
                cls.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
            )
            .eval()
            .to(cls.device)
        )
        init_embedding_cache()

    async def test_vlm_embedding_output(self):
        """Compare HF and SGLang embeddings of the processed image+text prompt.

        The HF reference is ``hf_model.get_vllm_embedding``. On the SGLang
        side, ``embed_mm_inputs`` is fed the flattened pixel slices and the
        image/slice spans located in ``input_ids``.
        """
        inputs = self.get_processor_output()

        with torch.no_grad():
            # HF reference
            model_inputs = {
                "input_ids": inputs.input_ids,
                "image_bound": inputs.image_bound,
                "pixel_values": inputs.pixel_values,
                "tgt_sizes": inputs.tgt_sizes,
            }
            hf_output, _ = self.hf_model.get_vllm_embedding(model_inputs)
            hf_output = hf_output.squeeze(0)

            # SGLang
            model = self.get_sglang_model()
            input_ids = inputs["input_ids"].to(self.device).flatten()

            # Flatten the per-batch sequences of slices into two parallel lists.
            pixel_values_flat: List[torch.Tensor] = []
            tgt_sizes_flat: List[torch.Tensor] = []
            for pixel_b, tgt_b in zip(inputs["pixel_values"], inputs["tgt_sizes"]):
                if len(pixel_b) != len(tgt_b):
                    raise ValueError(
                        "Inconsistent N lengths, found: "
                        f"{len(pixel_b)} vs {len(tgt_b)}"
                    )
                pixel_values_flat.extend(pixel_b)
                tgt_sizes_flat.extend(tgt_b)

            # Locate both full-image spans and slice spans in the token stream,
            # then order all of them by position.
            image_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
                input_ids=input_ids,
                mm_start_id=self.tokenizer.im_start_id,
                mm_end_id=self.tokenizer.im_end_id,
            )
            slice_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
                input_ids=input_ids,
                mm_start_id=self.tokenizer.slice_start_id,
                mm_end_id=self.tokenizer.slice_end_id,
            )
            all_offsets = sorted(image_offsets + slice_offsets)

            pad_token_id = self.processor.tokenizer.unk_token_id
            sglang_output = embed_mm_inputs(
                mm_inputs_list=[
                    MultimodalInputs(
                        mm_items=[
                            MultimodalDataItem(
                                feature=pixel_values_flat,
                                offsets=all_offsets,
                                tgt_size=tgt_sizes_flat,
                                modality=Modality.IMAGE,
                                pad_value=pad_token_id,
                            )
                        ]
                    ),
                ],
                extend_prefix_lens=[0],
                extend_seq_lens=[input_ids.shape[0]],
                input_ids=input_ids,
                input_embedding=model.get_input_embeddings(),
                multimodal_model=model,
                placeholder_tokens={
                    Modality.IMAGE: pad_token_id,
                },
            )

        self.compare_outputs(sglang_output, hf_output)


class TestMiniCPMV4Logits(VisionLLMLogitsBase):
    """Compare MiniCPM-V-4 multimodal input embeddings between HF and SGLang."""

    @classmethod
    def setUpClass(cls):
        """Load the MiniCPM-V-4 tokenizer, processor and bf16 HF reference model."""
        super().setUpClass()
        cls.model_path = "openbmb/MiniCPM-V-4"
        cls.tokenizer = AutoTokenizer.from_pretrained(
            cls.model_path, trust_remote_code=True
        )
        cls.processor = AutoProcessor.from_pretrained(
            cls.model_path, trust_remote_code=True
        )
        cls.chat_template = "minicpmv"

        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cls.hf_model = (
            AutoModel.from_pretrained(
                cls.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
            )
            .eval()
            .to(cls.device)
        )
        init_embedding_cache()

    async def test_vlm_embedding_output(self):
        """Compare HF and SGLang embeddings of the processed image+text prompt.

        The previous version called the nonexistent ``self.get_model`` and
        ``self.vlm_func``. It also compared against plain token embeddings that
        ignored the image. This version mirrors the MiniCPM-V-2.6 comparison.
        """
        inputs = self.get_processor_output()

        with torch.no_grad():
            # HF reference: multimodal embeddings, not bare token embeddings.
            # NOTE(review): assumes the MiniCPM-V-4 remote code exposes
            # get_vllm_embedding with the same contract as MiniCPM-V-2.6 — confirm.
            model_inputs = {
                "input_ids": inputs.input_ids,
                "image_bound": inputs.image_bound,
                "pixel_values": inputs.pixel_values,
                "tgt_sizes": inputs.tgt_sizes,
            }
            hf_output, _ = self.hf_model.get_vllm_embedding(model_inputs)
            hf_output = hf_output.squeeze(0)

            # SGLang
            model = self.get_sglang_model()
            input_ids = inputs["input_ids"].to(self.device).flatten()

            # Flatten the per-batch sequences of slices into two parallel lists.
            pixel_values_flat: List[torch.Tensor] = []
            tgt_sizes_flat: List[torch.Tensor] = []
            for pixel_b, tgt_b in zip(inputs["pixel_values"], inputs["tgt_sizes"]):
                if len(pixel_b) != len(tgt_b):
                    raise ValueError(
                        "Inconsistent N lengths, found: "
                        f"{len(pixel_b)} vs {len(tgt_b)}"
                    )
                pixel_values_flat.extend(pixel_b)
                tgt_sizes_flat.extend(tgt_b)

            # Locate both full-image spans and slice spans in the token stream,
            # then order all of them by position.
            image_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
                input_ids=input_ids,
                mm_start_id=self.tokenizer.im_start_id,
                mm_end_id=self.tokenizer.im_end_id,
            )
            slice_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
                input_ids=input_ids,
                mm_start_id=self.tokenizer.slice_start_id,
                mm_end_id=self.tokenizer.slice_end_id,
            )
            all_offsets = sorted(image_offsets + slice_offsets)

            pad_token_id = self.processor.tokenizer.unk_token_id
            sglang_output = embed_mm_inputs(
                mm_inputs_list=[
                    MultimodalInputs(
                        mm_items=[
                            MultimodalDataItem(
                                feature=pixel_values_flat,
                                offsets=all_offsets,
                                tgt_size=tgt_sizes_flat,
                                modality=Modality.IMAGE,
                                pad_value=pad_token_id,
                            )
                        ]
                    ),
                ],
                extend_prefix_lens=[0],
                extend_seq_lens=[input_ids.shape[0]],
                input_ids=input_ids,
                input_embedding=model.get_input_embeddings(),
                multimodal_model=model,
                placeholder_tokens={
                    Modality.IMAGE: pad_token_id,
                },
            )

        self.compare_outputs(sglang_output, hf_output)