# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only LLaVa model compatible with HuggingFace weights."""

import math
import re
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Tuple, Type, Union

import numpy as np
import torch
from torch import nn
from transformers import (
    CLIPVisionConfig,
    CLIPVisionModel,
    LlavaConfig,
    MistralConfig,
    Qwen2Config,
    SiglipVisionModel,
)
from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

# Keep this import last and import only the module symbol, to avoid circular imports.
import sglang.srt.models as sgl_models
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.multimodal.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
    unpad_image_shape,
)
from sglang.srt.utils import add_prefix, flatten_nested_list, logger


class LlavaBaseForCausalLM(nn.Module):
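    """Shared base for the LLaVA variants below. It pads image placeholder tokens,
    encodes images with the vision tower and projector, and splices the projected
    features into the text embeddings during extend (prefill).
    """
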
    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
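        """Expand each image token in input_ids to `pad_value` repeated for the expected
        number of vision-feature tokens, recording per-image offsets and pad lengths on
        `image_inputs`.
        """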
        image_sizes = flatten_nested_list(
            [item.image_sizes for item in image_inputs.mm_items]
        )

        pad_values = [item.pad_value for item in image_inputs.mm_items]

        # hardcode for spatial_unpad + anyres
        if any(
            item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
            for item in image_inputs.mm_items
        ):
            image_aspect_ratio = "pad"
        else:
            image_aspect_ratio = "anyres"
        offset_list = []
        image_inputs.image_pad_len = []
        for image_idx, image_s in enumerate(image_sizes):
            if len(image_sizes) > 16:
                # 2x2 pooling with stride 2
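                # e.g. with a 336 px tower and 14 px patches (typical CLIP-L/336
                # settings, for illustration): ceil(336 / 14 / 2) ** 2 = 144 tokens per frame.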
                new_image_feature_len = (
                    math.ceil(self.image_size / self.patch_size / 2) ** 2
                )
            else:
                new_image_feature_len = self.image_feature_len  # multi-image

            height = width = self.num_patches_per_side
            if "anyres" in image_aspect_ratio:
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_s,
                    self.image_grid_pinpoints,
                    self.vision_tower.config.image_size,
                )
                h = num_patch_height * height
                w = num_patch_width * width
                new_h, new_w = unpad_image_shape(h, w, image_s)
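                # new_h/new_w count feature-map rows/cols after removing the padding that
                # was added to fit the anyres tile grid to the original aspect ratio.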

                if "anyres_max" in self.config.image_aspect_ratio:
                    matched_anyres_max_num_patches = re.match(
                        r"anyres_max_(\d+)", self.config.image_aspect_ratio
                    )
                    if matched_anyres_max_num_patches:
                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
                    # times = math.sqrt(h * w / (max_num_patches * unit**2))
                    times = math.sqrt(
                        new_h * new_w / (max_num_patches * self.image_feature_len)
                    )
                    if times > 1.1:
                        new_h = int(new_h // times)
                        new_w = int(new_w // times)
                new_image_feature_len += new_h * (new_w + 1)
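                # The "+ 1" per row accounts for the image_newline token appended at the
                # end of each row of the unpadded feature grid.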

            try:
                offset = input_ids.index(self.config.image_token_index)
            except ValueError:
                offset = 0
            # old_len + pad_len - 1, because we need to remove image_token_id
            input_ids = (
                input_ids[:offset]
                + [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
                + input_ids[offset + 1 :]
            )
            offset_list.append(offset)
            image_inputs.image_pad_len.append(new_image_feature_len)

        image_inputs.image_offsets = offset_list
        return input_ids

    def encode_images(
        self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
    ) -> torch.Tensor:
        """
        encode images by vision tower and multimodal projector
        Args:
            pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
        Returns:
            torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
        """
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient; output_hidden_states=True stores all hidden states.
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
            )
        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        image_inputs = forward_batch.mm_inputs

        if forward_batch.forward_mode.is_extend():
            # Clamp input ids. The input_ids for image tokens are filled with the image
            # hash values used for prefix matching in radix attention. These values are
            # meaningless here because their embeddings are replaced by vision embeddings.
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

            # Embed text inputs
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Flatten the per-request modalities into a single list,
            # one entry per multimodal item across the batch.
            modalities_list = []
            max_image_offset = []
            for im in image_inputs:
                if im:
                    modalities_list.extend([item.modality for item in im.mm_items])
                if im and im.image_offsets:
                    max_image_offset.append(
                        np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
                    )
                else:
                    max_image_offset.append(-1)

            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
            need_vision = start_positions <= np.array(max_image_offset)
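            # Vision encoding is only needed if this extend chunk still overlaps an
            # image placeholder; requests whose image tokens are fully covered by the
            # cached prefix are skipped.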

            if need_vision.any():
                bs = forward_batch.batch_size
                pixel_values = flatten_nested_list(
                    [
                        [item.feature for item in image_inputs[i].mm_items]
                        for i in range(bs)
                        if need_vision[i]
                    ]
                )
                image_sizes = [
                    flatten_nested_list(
                        [item.image_sizes for item in image_inputs[i].mm_items]
                    )
                    for i in range(bs)
                    if need_vision[i]
                ]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                    # Concatenate patches from all images along the batch dim (ndim=4)
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                if self.mm_patch_merge_type.startswith("spatial"):
                    new_image_features = []
                    height = width = self.num_patches_per_side
                    for image_idx, image_feature in enumerate(image_features):
                        if modalities_list[image_idx] == Modality.IMAGE:
                            image_aspect_ratio = (
                                self.config.image_aspect_ratio
                            )  # single image
                        elif (
                            modalities_list[image_idx] == Modality.MULTI_IMAGES
                            or modalities_list[image_idx] == Modality.VIDEO
                        ):
                            image_aspect_ratio = "pad"  # multi image
                        # image_aspect_ratio = (
                        #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
                        # )
                        if (
                            image_feature.shape[0] > 1
                            and "anyres" in image_aspect_ratio
                            and modalities_list[image_idx] == Modality.IMAGE
                        ):
                            base_image_feature = image_feature[0]
                            image_feature = image_feature[1:]
                            assert height * width == base_image_feature.shape[0]

                            if "anyres_max" in image_aspect_ratio:
                                matched_anyres_max_num_patches = re.match(
                                    r"anyres_max_(\d+)", image_aspect_ratio
                                )
                                if matched_anyres_max_num_patches:
                                    max_num_patches = int(
                                        matched_anyres_max_num_patches.group(1)
                                    )

                            if (
                                image_aspect_ratio == "anyres"
                                or "anyres_max" in image_aspect_ratio
                            ):
                                vision_tower_image_size = self.image_size
                                try:
                                    num_patch_width, num_patch_height = (
                                        get_anyres_image_grid_shape(
                                            image_sizes[image_idx][0],
                                            self.config.image_grid_pinpoints,
                                            vision_tower_image_size,
                                        )
                                    )
                                except Exception as e:
                                    logger.error(f"Error computing anyres grid shape: {e}")
                                    num_patch_width, num_patch_height = 2, 2
                                image_feature = image_feature.view(
                                    num_patch_height, num_patch_width, height, width, -1
                                )
                            else:
                                image_feature = image_feature.view(
                                    2, 2, height, width, -1
                                )

                            # (
                            #     num_patch_width,
                            #     num_patch_height,
                            # ) = get_anyres_image_grid_shape(
                            #     image_sizes[image_idx][0],
                            #     self.image_grid_pinpoints,
                            #     self.vision_tower.config.image_size,
                            # )

                            # image_feature = image_feature.view(
                            #     num_patch_height, num_patch_width, height, width, -1
                            # )

                            if "unpad" in self.mm_patch_merge_type:
                                unit = image_feature.shape[2]
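                                # unit = patches per side of one tile (the vision tower's
                                # grid size); used below to cap the total token count.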
                                image_feature = image_feature.permute(
                                    4, 0, 2, 1, 3
                                ).contiguous()
                                image_feature = image_feature.flatten(1, 2).flatten(
                                    2, 3
                                )
                                image_feature = unpad_image(
                                    image_feature, image_sizes[image_idx][0]
                                )
                                if (
                                    "anyres_max" in image_aspect_ratio
                                    and matched_anyres_max_num_patches
                                ):
                                    c, h, w = image_feature.shape
                                    times = math.sqrt(
                                        h * w / (max_num_patches * unit**2)
                                    )
                                    if times > 1.1:
                                        image_feature = image_feature[None]
                                        image_feature = nn.functional.interpolate(
                                            image_feature,
                                            [int(h // times), int(w // times)],
                                            mode="bilinear",
                                        )[0]
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[
                                            :, None, None
                                        ].expand(*image_feature.shape[:-1], 1),
                                    ),
                                    dim=-1,
                                )
                                image_feature = image_feature.flatten(1, 2).transpose(
                                    0, 1
                                )
                            else:
                                image_feature = image_feature.permute(
                                    0, 2, 1, 3, 4
                                ).contiguous()
                                image_feature = image_feature.flatten(0, 3)
                            image_feature = torch.cat(
                                (base_image_feature, image_feature), dim=0
                            )
                            image_feature = image_feature.unsqueeze(0)
                        else:
                            if modalities_list[image_idx] == Modality.VIDEO:  # video
                                # 2x2 pooling
                                num_of_frames = image_feature.shape[0]
                                image_feature = image_feature.view(
                                    num_of_frames, height, width, -1
                                )
                                image_feature = image_feature.permute(
                                    0, 3, 1, 2
                                ).contiguous()  # N, C, H, W
                                height, width = image_feature.shape[2:]
                                scaled_shape = [
                                    math.ceil(height / 2),
                                    math.ceil(width / 2),
                                ]
                                image_feature = nn.functional.interpolate(
                                    image_feature, size=scaled_shape, mode="bilinear"
                                )
                                image_feature = (
                                    image_feature.flatten(2)
                                    .transpose(1, 2)
                                    .contiguous()
                                )  # N, C, H*W
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
                                        self.language_model.model.image_newline[
                                            None, None
                                        ].expand(
                                            image_feature.shape[0],
                                            1,
                                            image_feature.shape[-1],
                                        ),
                                    ),
                                    dim=1,
                                )

                        new_image_features.append(image_feature)
                    image_features = new_image_features

                # Fill in the placeholder for the image
                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
                extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
                prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    seq_len = extend_seq_lens[i]
                    prefix_len = prefix_lens_cpu[i]

                    # Multiple images
                    for image_idx, image_offset in enumerate(
                        image_inputs[i].image_offsets
                    ):
                        if (
                            image_offset + image_inputs[i].image_pad_len[image_idx]
                            <= prefix_len
                        ):
                            continue
                        if image_offset >= prefix_len + seq_len:
                            break

                        tmp_image_feature = image_features[pt][image_idx]
                        pad_len = tmp_image_feature.shape[0]

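                        # Scatter the image feature into input_embeds, clipping any part
                        # already covered by the cached prefix (left) or spilling past
                        # this extend chunk (right).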
                        input_offset = image_offset - prefix_len
                        left_idx = start_idx + input_offset
                        right_idx = left_idx + pad_len
                        assert right_idx > start_idx
                        if input_offset < 0:
                            left_idx = start_idx
                            tmp_image_feature = tmp_image_feature[-input_offset:]
                        if right_idx > start_idx + seq_len:
                            tmp_image_feature = tmp_image_feature[
                                : start_idx + seq_len - right_idx
                            ]
                            right_idx = start_idx + seq_len
                        try:
                            input_embeds[left_idx:right_idx] = tmp_image_feature
                        except RuntimeError as e:
                            logger.error(f"RuntimeError in image encoding: {e}")
                            logger.error(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
                            logger.error(
                                f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
                            )
                    pt += 1

            return self.language_model(
                input_ids, positions, forward_batch, input_embeds=input_embeds
            )
        elif forward_batch.forward_mode.is_decode():
            return self.language_model(input_ids, positions, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Load clip vision model by cfg['mm_vision_tower']:
        # huggingface_name or path_of_clip_relative_to_llava_model_dir
        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
        vision_path = self.config.mm_vision_tower
        if "clip" in vision_path:
            self.vision_tower = CLIPVisionModel.from_pretrained(
                vision_path, torch_dtype=torch.float16
            ).cuda()
        elif "siglip" in vision_path:
            self.vision_tower = SiglipVisionModel.from_pretrained(
                vision_path, torch_dtype=torch.float16
            ).cuda()
            # Siglip needs all feature tokens
            self.config.mm_vision_select_feature = "full"
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
        if (
            self.vision_feature_select_strategy == "patch"
            or self.vision_feature_select_strategy == "full"
        ):
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_tower.vision_tower": "vision_tower",
            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
            "model.image_newline": "language_model.model.image_newline",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                self.language_model.load_weights([(name, loaded_weight)])

    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size


class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.vision_tower = None
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = LlamaForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )


class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.vision_tower = None

        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = Qwen2Config(self.config._name_or_path)

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 151646

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = Qwen2ForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )


class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.vision_tower = None

        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = MistralConfig(self.config._name_or_path)

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 32000

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = MistralForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )


class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
    """
    An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
    It follows the structure of (vision_tower, multi_modal_projector, language_model)

    Once a model config is loaded, text_config and vision_config will be extracted, and
    LlavaForConditionalGeneration will load the language_model and vision_tower models
    according to config.
    """

    MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector

    @property
    def dtype(self):
        return self.torch_dtype

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        if hasattr(self.vision_tower, "pad_input_ids"):
            return self.vision_tower.pad_input_ids(input_ids, image_inputs)
        else:
            return super().pad_input_ids(input_ids, image_inputs)

    def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
        """
        Get the SGLang model implementation class according to config.

        Args:
            config: The config object of the model.
            auto_model_type: The type of the auto model.

        Returns:
            The SGLang model implementation class.
        """
        config_cls_name = config.__class__.__name__
        arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
        if arch := arch_name_mapping.get(config_cls_name):
            if isinstance(arch, tuple):
                arch = arch[0]
                logger.warning(
                    f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
                )
            try:
                return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
            except Exception as e:
                raise ValueError(
                    f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
                )
        else:
            raise ValueError(
                f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
            )

    @lru_cache
    def _config_cls_name_to_arch_name_mapping(
        self, auto_model_type: Type[AutoModel]
    ) -> Dict[str, str]:
        mapping = {}
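        # Build {config class name -> architecture class name(s)} from HF's auto-model
        # registry (note: this relies on the private `_model_mapping` attribute and is
        # cached by lru_cache).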
        for config_cls in auto_model_type._model_mapping.keys():
            archs = auto_model_type._model_mapping.get(config_cls, None)
            if archs is not None:
                if isinstance(archs, tuple):
                    mapping[config_cls.__name__] = tuple(
                        arch.__name__ for arch in archs
                    )
                else:
                    mapping[config_cls.__name__] = archs.__name__
        return mapping

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        assert hasattr(config, "text_config")
        assert hasattr(config, "vision_config")
        self.config = config
        self.text_config = self.config.text_config
        self.vision_config = self.config.vision_config
        self.torch_dtype = getattr(self.config, "torch_dtype")

        if not getattr(self.text_config, "torch_dtype", None):
            self.text_config.torch_dtype = self.torch_dtype
        if not getattr(self.vision_config, "torch_dtype", None):
            self.vision_config.torch_dtype = self.torch_dtype

        if not hasattr(self.config, "vocab_size"):
            self.config.vocab_size = self.text_config.vocab_size
        if not hasattr(self.config, "image_aspect_ratio"):
            self.config.image_aspect_ratio = "anyres"
        if not hasattr(self.config, "image_grid_pinpoints"):
            # from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
            # self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
            self.config.image_grid_pinpoints = [
                [96, 96],
                [224, 224],
                [384, 384],
                [512, 512],
                [768, 768],
                [1024, 1024],
            ]
        if not hasattr(self.config, "mm_patch_merge_type"):
            self.config.mm_patch_merge_type = "flat"
        if not hasattr(self.config, "image_token_index"):
            self.config.image_token_index = 10
        if not hasattr(self.config, "projector_hidden_act"):
            self.config.projector_hidden_act = "gelu"

        self.vision_feature_layer = getattr(self.config, "vision_feature_layer", -1)
        self.vision_feature_select_strategy = getattr(
            self.config, "vision_feature_select_strategy", "full"
        )
        self.image_size = self.vision_config.image_size
        self.patch_size = self.vision_config.patch_size

        self.mm_patch_merge_type = self.config.mm_patch_merge_type
        self.image_aspect_ratio = self.config.image_aspect_ratio
        self.image_grid_pinpoints = self.config.image_grid_pinpoints

        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)

        self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)

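        # Resolve the SGLang implementations of the text and vision submodules from
        # their HF config classes, then instantiate them with this model's quant config.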
        language_model_cls = self._get_sgl_model_cls(
            self.text_config, AutoModelForCausalLM
        )
        vision_model_cls = self._get_sgl_model_cls(self.vision_config, AutoModel)
        self.language_model = language_model_cls(
            self.text_config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        self.vision_tower = vision_model_cls(
            self.vision_config,
            quant_config=quant_config,
            prefix=add_prefix("vision_tower", prefix),
        )

        if "unpad" in getattr(self.config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(self.text_config.hidden_size, dtype=self.torch_dtype)
            )

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Extract features from image inputs.

        Args:
            items: List of MultimodalDataItem objects containing image data
                Note that an item can be either "image" or "multi-images"

        Returns:
            torch.Tensor: features from image inputs, concatenated
        """
        features = []
        for item in items:
            # in each item, we assume pixel_values is always batched
            pixel_values, image_sizes = item.feature, item.image_sizes
            image_outputs = self.vision_tower(
                pixel_values, image_sizes, output_hidden_states=True
            )
            selected_image_feature = image_outputs.hidden_states[
                self.vision_feature_layer
            ]

            if self.vision_feature_select_strategy in ["default", "patch"]:
                selected_image_feature = selected_image_feature[:, 1:]
            elif self.vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature
            else:
                raise ValueError(
                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
                )
            features.append(
                self.multi_modal_projector(selected_image_feature.squeeze(0))
            )
        ret = torch.cat(features, dim=0)
        return ret

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            get_embedding=get_embedding,
            language_model=self.language_model,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
            },
            placeholder_tokens=None,  # using mm_item.pad_value
            positions=positions,
        )

        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for LlavaForConditionalGeneration.

        Unlike the base class implementation, this one doesn't need to handle
        weight name remapping as the weights are already properly structured with
        'language_model' and 'vision_tower' prefixes in the safetensors files.
        """
        if (
            self.vision_feature_select_strategy == "patch"
            or self.vision_feature_select_strategy == "full"
        ):
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # Create dictionaries for direct parameter loading
        params_dict = dict(self.named_parameters())

        # Load weights directly without remapping
        for name, loaded_weight in weights:
            for part in ("language_model", "vision_tower"):
                if name.startswith(part):
                    name = name[len(part + ".") :]
                    getattr(self, part).load_weights([(name, loaded_weight)])
                    break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


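# EntryClass lists the model implementations this module registers with SGLang's model loader.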
EntryClass = [
    LlavaLlamaForCausalLM,
    LlavaQwenForCausalLM,
    LlavaMistralForCausalLM,
    LlavaForConditionalGeneration,
]