# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only LLaVa model compatible with HuggingFace weights."""

import math
import re
from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from transformers import (
    CLIPVisionConfig,
    CLIPVisionModel,
    LlavaConfig,
    MistralConfig,
    Qwen2Config,
    SiglipVisionModel,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
    unpad_image_shape,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM


class LlavaBaseForCausalLM(nn.Module):
    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
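        """Replace each image token in input_ids with a run of per-image pad
        values whose length matches the expected number of image features, and
        record the offsets in image_inputs.image_offsets so forward() can
        scatter the vision embeddings back into those positions."""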
        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values

        # hardcode for spatial_unpad + anyres
        if image_inputs.modalities is not None and (
            "multi-images" in image_inputs.modalities
            or "video" in image_inputs.modalities
        ):
            image_aspect_ratio = "pad"
        else:
            image_aspect_ratio = "anyres"
        offset_list = []
        for image_idx, image_s in enumerate(image_sizes):
            if len(image_sizes) > 16:
                # 2x2 pooling with stride 2
                new_image_feature_len = (
                    math.ceil(self.image_size / self.patch_size / 2) ** 2
                )
            else:
                new_image_feature_len = self.image_feature_len  # single image / multi-image

            height = width = self.num_patches_per_side
            if "anyres" in image_aspect_ratio:
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_s,
                    self.image_grid_pinpoints,
                    self.vision_tower.config.image_size,
                )
                h = num_patch_height * height
                w = num_patch_width * width
                new_h, new_w = unpad_image_shape(h, w, image_s)

                if "anyres_max" in self.config.image_aspect_ratio:
                    matched_anyres_max_num_patches = re.match(
                        r"anyres_max_(\d+)", self.config.image_aspect_ratio
                    )
                    if matched_anyres_max_num_patches:
                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
                    # times = math.sqrt(h * w / (max_num_patches * unit**2))
                    times = math.sqrt(
                        new_h * new_w / (max_num_patches * self.image_feature_len)
                    )
                    if times > 1.1:
                        new_h = int(new_h // times)
                        new_w = int(new_w // times)
                new_image_feature_len += new_h * (new_w + 1)

            try:
                offset = input_ids.index(self.config.image_token_index)
            except ValueError:
                offset = 0
            # old_len + pad_len - 1, because we need to remove image_token_id
            input_ids = (
                input_ids[:offset]
                + [pad_values[image_idx]] * new_image_feature_len
                + input_ids[offset + 1 :]
            )
            offset_list.append(offset)

        image_inputs.image_offsets = offset_list
        return input_ids

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
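        """Run the vision tower, pick the configured hidden layer and
        token-selection strategy, and project the selected features into the
        language model's embedding space."""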
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.

        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
            )
        image_features = self.multi_modal_projector(selected_image_feature)

        return image_features

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
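        """For prefill (extend) batches, embed the text tokens, overwrite the
        padded image positions with projected vision features, and run the
        language model on the mixed embeddings. Decode batches skip the vision
        path entirely."""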
        image_inputs = forward_batch.image_inputs

        if forward_batch.forward_mode.is_extend():
            bs = forward_batch.batch_size
            # Flatten List[List[str]] into List[str];
            # its length should equal the batch size.
            modalities_list = []
            max_image_offset = []
            for im in image_inputs:
                if im and im.modalities is not None:
                    modalities_list.extend(im.modalities)
                if im and im.image_offsets is not None:
                    max_image_offset.append(max(im.image_offsets))
                else:
                    max_image_offset.append(-1)

            # Clamp input ids. This is because the input_ids for the image tokens are
            # filled with the hash values of the image for the prefix matching in the radix attention.
            # These values are useless because their embeddings will be replaced by vision embeddings anyway.
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

            # Embed text inputs
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
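            # Vision features are only needed for requests whose prefill start
            # position has not yet passed their last image placeholder.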
            need_vision = start_positions <= np.array(max_image_offset)

            if need_vision.any():
                pixel_values = [
                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
                ]
                image_sizes = [
                    image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
                ]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                    # Concatenate patches from all images along the batch dim (still ndim=4).
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                if self.mm_patch_merge_type.startswith("spatial"):
                    new_image_features = []
                    height = width = self.num_patches_per_side
                    for image_idx, image_feature in enumerate(image_features):
                        if modalities_list[image_idx] == "image":
                            image_aspect_ratio = (
                                self.config.image_aspect_ratio
                            )  # single image
                        elif (
                            modalities_list[image_idx] == "multi-images"
                            or modalities_list[image_idx] == "video"
                        ):
                            image_aspect_ratio = "pad"  # multi image
                        # image_aspect_ratio = (
                        #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
                        # )
                        if (
                            image_feature.shape[0] > 1
                            and "anyres" in image_aspect_ratio
                            and modalities_list[image_idx] == "image"
                        ):
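                            # image_feature[0] is the base (whole-image) view; the
                            # remaining entries are the anyres tiles, laid out on a
                            # num_patch_height x num_patch_width grid below.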
                            base_image_feature = image_feature[0]
                            image_feature = image_feature[1:]
                            assert height * width == base_image_feature.shape[0]

                            if "anyres_max" in image_aspect_ratio:
                                matched_anyres_max_num_patches = re.match(
                                    r"anyres_max_(\d+)", image_aspect_ratio
                                )
                                if matched_anyres_max_num_patches:
                                    max_num_patches = int(
                                        matched_anyres_max_num_patches.group(1)
                                    )

                            if (
                                image_aspect_ratio == "anyres"
                                or "anyres_max" in image_aspect_ratio
                            ):
                                vision_tower_image_size = self.image_size
                                try:
                                    num_patch_width, num_patch_height = (
                                        get_anyres_image_grid_shape(
                                            image_sizes[image_idx][0],
                                            self.config.image_grid_pinpoints,
                                            vision_tower_image_size,
                                        )
                                    )
                                except Exception as e:
                                    print(f"Error: {e}")
                                    num_patch_width, num_patch_height = 2, 2
                                image_feature = image_feature.view(
                                    num_patch_height, num_patch_width, height, width, -1
                                )
                            else:
                                image_feature = image_feature.view(
                                    2, 2, height, width, -1
                                )

                            # (
                            #     num_patch_width,
                            #     num_patch_height,
                            # ) = get_anyres_image_grid_shape(
                            #     image_sizes[image_idx][0],
                            #     self.image_grid_pinpoints,
                            #     self.vision_tower.config.image_size,
                            # )

                            # image_feature = image_feature.view(
                            #     num_patch_height, num_patch_width, height, width, -1
                            # )

                            if "unpad" in self.mm_patch_merge_type:
                                unit = image_feature.shape[2]
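                                # unit = patches per tile side, used below to cap the
                                # unpadded grid for anyres_max_<N>.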
                                image_feature = image_feature.permute(
                                    4, 0, 2, 1, 3
                                ).contiguous()
                                image_feature = image_feature.flatten(1, 2).flatten(
                                    2, 3
                                )
                                image_feature = unpad_image(
                                    image_feature, image_sizes[image_idx][0]
                                )
                                if (
                                    "anyres_max" in image_aspect_ratio
                                    and matched_anyres_max_num_patches
                                ):
                                    c, h, w = image_feature.shape
                                    times = math.sqrt(
                                        h * w / (max_num_patches * unit**2)
                                    )
                                    if times > 1.1:
                                        image_feature = image_feature[None]
                                        image_feature = nn.functional.interpolate(
                                            image_feature,
                                            [int(h // times), int(w // times)],
                                            mode="bilinear",
                                        )[0]
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[
                                            :, None, None
                                        ].expand(*image_feature.shape[:-1], 1),
                                    ),
                                    dim=-1,
                                )
                                image_feature = image_feature.flatten(1, 2).transpose(
                                    0, 1
                                )
                            else:
                                image_feature = image_feature.permute(
                                    0, 2, 1, 3, 4
                                ).contiguous()
                                image_feature = image_feature.flatten(0, 3)
                            image_feature = torch.cat(
                                (base_image_feature, image_feature), dim=0
                            )
                            image_feature = image_feature.unsqueeze(0)
                        else:
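                            # Video / multi-image features: 2x2-pool video frames,
                            # then (for "unpad") append an image_newline token after
                            # each frame or image.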
                            if modalities_list[image_idx] == "video":  # video
                                # 2x2 pooling
                                num_of_frames = image_feature.shape[0]
                                image_feature = image_feature.view(
                                    num_of_frames, height, width, -1
                                )
                                image_feature = image_feature.permute(
                                    0, 3, 1, 2
                                ).contiguous()  # N, C, H, W
                                frame_height, frame_width = image_feature.shape[2:]
                                scaled_shape = [
                                    math.ceil(frame_height / 2),
                                    math.ceil(frame_width / 2),
                                ]
                                image_feature = nn.functional.interpolate(
                                    image_feature, size=scaled_shape, mode="bilinear"
                                )
                                image_feature = (
                                    image_feature.flatten(2)
                                    .transpose(1, 2)
                                    .contiguous()
                                )  # N, H*W, C
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
                                        self.language_model.model.image_newline[
                                            None, None
                                        ].expand(
                                            image_feature.shape[0],
                                            1,
                                            image_feature.shape[-1],
                                        ),
                                    ),
                                    dim=1,
                                )

                        new_image_features.append(image_feature)
                    image_features = new_image_features

                # Fill in the placeholder for the image
                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
                prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
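                # For each request that needs vision, overwrite the padded
                # placeholder positions (offset - prefix) in its extended segment
                # with the corresponding image features.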
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    prefix_len = prefix_lens_cpu[i]

                    # A prompt may contain multiple images; place each image's
                    # features at its recorded offset.
                    for j, image_offset in enumerate(image_inputs[i].image_offsets):
                        if image_offset < prefix_len:
                            continue

                        tmp_image_feature = image_features[pt][j]
                        pad_len = tmp_image_feature.shape[0]

                        left_idx = start_idx + (image_offset - prefix_len)
                        right_idx = start_idx + (image_offset - prefix_len) + pad_len
                        try:
                            input_embeds[left_idx:right_idx] = tmp_image_feature
                        except RuntimeError as e:
                            print(f"RuntimeError in image encoding: {e}")
                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
                            print(
                                f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
                            )
                    pt += 1

            return self.language_model(
                input_ids, positions, forward_batch, input_embeds=input_embeds
            )
        elif forward_batch.forward_mode.is_decode():
            return self.language_model(input_ids, positions, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Load the vision tower specified by cfg['mm_vision_tower']:
        # a huggingface name or a CLIP/SigLIP path relative to the llava model dir.
        # We put the initialization here instead of __init__ so it can be reused by the subclasses.
        vision_path = self.config.mm_vision_tower
        if "clip" in vision_path:
            self.vision_tower = CLIPVisionModel.from_pretrained(
                vision_path, torch_dtype=torch.float16
            ).cuda()
        elif "siglip" in vision_path:
            self.vision_tower = SiglipVisionModel.from_pretrained(
                vision_path, torch_dtype=torch.float16
            ).cuda()
            # Siglip needs all feature tokens
            self.config.mm_vision_select_feature = "full"
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size
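        # These attributes are reused by pad_input_ids and the feature
        # post-processing in forward.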

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
        if (
            self.vision_feature_select_strategy == "patch"
            or self.vision_feature_select_strategy == "full"
        ):
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(f"Unexpected select feature: {self.vision_feature_select_strategy}")

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
            "model.image_newline": "language_model.model.image_newline",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                self.language_model.load_weights([(name, loaded_weight)])

    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size


class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config=None,
    ) -> None:
        super().__init__()

        self.config = config
        self.vision_tower = None
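        # The vision tower is instantiated later in load_weights from
        # config.mm_vision_tower.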
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )


class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config=None,
    ) -> None:
        super().__init__()

        self.config = config
        self.vision_tower = None

        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = Qwen2Config(self.config._name_or_path)

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 151646

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )


class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config=None,
    ) -> None:
        super().__init__()

        self.config = config
        self.vision_tower = None

        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = MistralConfig(self.config._name_or_path)

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 32000

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )


EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
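# Exported for SGLang's model loader; the class whose name matches the
# architecture in the checkpoint config is instantiated.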