llava.py 23.6 KB
Newer Older
1
from functools import cached_property
2
from types import MethodType
3
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
4
                    Tuple, TypedDict, Union)
5
6

import torch
7
import torch.nn as nn
8
9
10
11
12
from PIL.Image import Image
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
                          ProcessorMixin, SiglipVisionConfig)
from transformers.models.pixtral import PixtralProcessor
13
14

from vllm.attention import AttentionMetadata
15
from vllm.config import VllmConfig
16
from vllm.inputs import InputContext
17
from vllm.model_executor.layers.activation import get_act_fn
18
19
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
20
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
21
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
22
from vllm.model_executor.sampling_metadata import SamplingMetadata
23
from vllm.multimodal import MULTIMODAL_REGISTRY
24
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
25
26
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        InputProcessingContext,
27
28
                                        ModalityProcessingMetadata,
                                        MultiModalProcessingMetadata,
29
                                        PromptReplacement)
30
from vllm.sequence import IntermediateTensors
31

32
from .clip import (CLIPVisionModel, dummy_image_for_clip,
33
                   get_max_clip_image_tokens)
34
from .interfaces import SupportsMultiModal, SupportsPP
35
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
36
                      get_max_pixtral_hf_image_tokens)
37
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
38
                     get_max_siglip_image_tokens)
39
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
40
                    maybe_prefix, merge_multimodal_embeddings)
41
42


43
44
class LlavaImagePixelInputs(TypedDict):
    """Image input carried as pixel values, tagged `type="pixel_values"`."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
52
53
54
55
56


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input carried as embeddings, tagged `type="image_embeds"`."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


# An image input is either pixel values or image embeddings; the "type" key of
# each TypedDict tells the two apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


66
67
class LlavaMultiModalProjector(nn.Module):
    """Projects vision features to the text hidden size.

    The projection is `linear_1` -> activation -> `linear_2`.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the two linear layers and the activation between them.

        Args:
            vision_hidden_size: Input size of `linear_1`.
            text_hidden_size: Output size of `linear_1`, and both the input
                and output size of `linear_2`.
            projector_hidden_act: Activation name resolved via `get_act_fn`.
            quant_config: Forwarded to both linear layers.
            prefix: Name prefix; the layers are registered as
                `{prefix}.linear_1` and `{prefix}.linear_2`.
        """
        super().__init__()

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=True,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=True,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply `linear_1`, the activation, then `linear_2`."""
        # Both linear layers return a 2-tuple; only the first element is used.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


95
96
97
98
99
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for the configured model.

    The count comes from the helper matching the vision config type (CLIP,
    SigLIP or Pixtral-HF) and is then adjusted by
    `hf_config.vision_feature_select_strategy`.

    Raises:
        NotImplementedError: If the vision config type is not supported.
        ValueError: If the select strategy is neither "default" nor "full".
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    elif isinstance(vision_config, PixtralVisionConfig):
        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        # "default" yields one token fewer than the vision-encoder maximum.
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
116
117


118
119
def dummy_mm_kwargs_for_llava(
        ctx: InputProcessingContext,
        mm_counts: Mapping[str, int]) -> MultiModalKwargs:
    """Build multimodal kwargs from dummy images for the configured model.

    Args:
        ctx: Supplies the HF config and the HF processor.
        mm_counts: Per-modality item counts; only `mm_counts["image"]` is
            read.

    Returns:
        The image processor's output for the dummy images plus an
        `is_pixtral` boolean tensor.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    if isinstance(vision_config, CLIPVisionConfig):
        data = dummy_image_for_clip(vision_config, num_images)
    elif isinstance(vision_config, SiglipVisionConfig):
        data = dummy_image_for_siglip(vision_config, num_images)
    elif isinstance(vision_config, PixtralVisionConfig):
        data = dummy_image_for_pixtral_hf(vision_config, num_images)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    hf_processor = ctx.get_hf_processor()
    image_processor = hf_processor.image_processor  # type: ignore
    hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
    is_pixtral = isinstance(hf_processor, PixtralProcessor)

    return MultiModalKwargs(
        **hf_inputs,
        is_pixtral=torch.tensor(is_pixtral),
    )
143

144

145
146
def create_metadata_for_llava(
        ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
    """Create the prompt-replacement metadata for the "image" modality.

    Each occurrence of `hf_config.image_token_index` in the prompt is
    replaced by that same token repeated `get_max_llava_image_tokens(ctx)`
    times.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    image_token_id = hf_config.image_token_index

    def get_repl_count(
        mm_items: list[Image],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> int:
        # The arguments are ignored: every image gets the same (maximum)
        # number of placeholder tokens.
        return get_max_llava_image_tokens(ctx)

    return {
        "image":
        ModalityProcessingMetadata(prompt_repls=[
            PromptReplacement(target=[image_token_id],
                              repl_unit=[image_token_id],
                              repl_count=get_repl_count),
        ]),
    }
165

166

167
168
169
170
171
172
173
class LlavaProcessor(BaseMultiModalProcessor):
    """Multimodal processor for LLaVA, with a special case for Pixtral-HF."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        """Initialise the base processor with LLaVA's replacement metadata."""
        super().__init__(
            ctx=ctx,
            metadata=create_metadata_for_llava(ctx),
        )

    def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
        """Wrap the image processor's `preprocess` to add an `is_pixtral` flag.

        The patch is applied at most once per processor object; the
        `__is_patched__` attribute marks an already patched instance.
        """
        if getattr(hf_processor, "__is_patched__", False):
            return  # Already patched

        image_processor = hf_processor.image_processor  # type: ignore
        orig_preprocess = image_processor.preprocess

        def preprocess(__self, *args, **kwargs):
            # `orig_preprocess` is already bound, so `__self` is not forwarded.
            hf_inputs = orig_preprocess(*args, **kwargs)
            hf_inputs["is_pixtral"] = torch.tensor(True)
            return hf_inputs

        image_processor.preprocess = MethodType(preprocess, image_processor)

        hf_processor.__is_patched__ = True  # type: ignore

    def _get_hf_processor(self) -> ProcessorMixin:
        """Return the HF processor, patched first when it is a Pixtral one."""
        hf_processor = self.ctx.get_hf_processor()

        if isinstance(hf_processor, PixtralProcessor):
            self._patch_pixtral_processor(hf_processor)

        return hf_processor

    def _get_dummy_mm_kwargs(
        self,
        mm_counts: Mapping[str, int],
    ) -> MultiModalKwargs:
        """Build multimodal kwargs from dummy images.

        Args:
            mm_counts: Per-modality item counts; only `mm_counts["image"]`
                is read.

        Returns:
            The image processor's output for the dummy images plus an
            `is_pixtral` boolean tensor.

        Raises:
            NotImplementedError: If the vision config type is not supported.
        """
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        vision_config = hf_config.vision_config
        num_images = mm_counts["image"]

        if isinstance(vision_config, CLIPVisionConfig):
            data = dummy_image_for_clip(vision_config, num_images)
        elif isinstance(vision_config, SiglipVisionConfig):
            data = dummy_image_for_siglip(vision_config, num_images)
        elif isinstance(vision_config, PixtralVisionConfig):
            data = dummy_image_for_pixtral_hf(vision_config, num_images)
        else:
            msg = f"Unsupported vision config: {type(vision_config)}"
            raise NotImplementedError(msg)

        hf_processor = self._get_hf_processor()
        image_processor = hf_processor.image_processor  # type: ignore
        hf_inputs = image_processor.preprocess(data['image'],
                                               return_tensors="pt")
        is_pixtral = isinstance(hf_processor, PixtralProcessor)

        # For Pixtral the patched `preprocess` has already stored
        # "is_pixtral" in `hf_inputs`. Passing it again as an explicit keyword
        # next to `**hf_inputs` raises "got multiple values for keyword
        # argument", so merge into a single dict first.
        mm_kwargs = dict(hf_inputs)
        mm_kwargs["is_pixtral"] = torch.tensor(is_pixtral)
        return MultiModalKwargs(**mm_kwargs)
227
228


229
230
class LlavaLikeConfig(Protocol):
    """Structural type for configs consumed by the vision-tower helpers.

    Any config exposing these two attributes is accepted.
    """

    # Config of the vision encoder; its `num_hidden_layers` is read when
    # resolving feature-layer indices.
    vision_config: PretrainedConfig
    # One signed layer index, or several of them.
    vision_feature_layer: Union[int, List[int]]


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer count needed for the single feature layer, or for the
        deepest one when several are configured.

    Raises:
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    # Name the attribute as it is actually spelled on the config.
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given an signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1
266
267
268
269
270
271
272


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Instantiate the vision tower matching `hf_config.vision_config`.

    Args:
        hf_config: Config providing `vision_config` and
            `vision_feature_layer`.
        quant_config: Forwarded to the vision model.
        require_post_norm: Forwarded to the vision model.
        prefix: Forwarded to the vision model.

    Returns:
        A `CLIPVisionModel`, `SiglipVisionModel` or `PixtralHFVisionModel`,
        depending on the vision config type.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


309
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model: a vision tower, a multimodal projector and a language
    model."""

    # BitsAndBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Source of the HF config, the quantization config and
                the multimodal config.
            prefix: Name prefix for the submodules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """The language model's sampler if it has one, else a new sampler."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that `data` has shape `(batch_size, 3, h, w)`.

        `h` and `w` both equal `config.vision_config.image_size`.

        Raises:
            ValueError: If the trailing dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract the image input from the multimodal keyword arguments.

        Reads `pixel_values`, `is_pixtral` and `image_embeds` from `kwargs`.
        Pixel values take precedence over embeddings when both are given.

        Returns:
            None when neither pixel values nor embeddings are present,
            otherwise the corresponding typed dict.

        Raises:
            ValueError: If an input has an unsupported type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            assert isinstance(is_pixtral, torch.Tensor)
            if is_pixtral.any():

                def flatten_to_3d_tensors(item):
                    if isinstance(item, torch.Tensor):
                        if item.dim() >= 3:
                            # Collapse all leading dims, then split along
                            # the first dim into 3D tensors.
                            return list(item.view(-1, *item.shape[-3:]))
                        else:
                            raise ValueError(
                                f"Unexpected tensor dimension: {item.dim()}")
                    elif isinstance(item, list):
                        return [
                            t for subitem in item
                            for t in flatten_to_3d_tensors(subitem)
                        ]
                    else:
                        raise ValueError(f"Unexpected type: {type(item)}")

                # Flatten the (possibly nested) batch into a flat list of
                # 3D image tensors. No shape validation is applied on this
                # path.
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=flatten_to_3d_tensors(pixel_values),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature-select strategy to the vision tower output.

        "default" drops index 0 along the second dimension; "full" returns
        the features unchanged.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the configured select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Turn pixel-value inputs into selected vision features."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Return embeddings for an image input.

        Embedding inputs are returned as-is; pixel inputs go through the
        vision tower and then the multimodal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for the given multimodal kwargs, or None
        when they contain no image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` with the language model and, when multimodal
        embeddings are given, merge them in at positions holding
        `config.image_token_index`."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Forwarded to the language model.
            kv_caches: Forwarded to the language model.
            attn_metadata: Forwarded to the language model.
            intermediate_tensors: When not None, `inputs_embeds` is discarded
                and these are forwarded to the language model.
            inputs_embeds: Precomputed input embeddings. When None (and no
                intermediate tensors are given) they are computed from
                `input_ids` and the image inputs found in `kwargs`.
            **kwargs: Multimodal inputs, e.g. `pixel_values`, `image_embeds`
                and `is_pixtral`.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            # The embeddings replace the token ids from here on.
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `(name, tensor)` weight pairs via `AutoWeightsLoader`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613


class MantisProcessor(LlavaProcessor):
    """LLaVA processor variant backed by Mantis' `MLlavaProcessor`."""

    def _get_hf_processor(self) -> ProcessorMixin:
        """Load `MLlavaProcessor` from `self.ctx.model_config.tokenizer`.

        Raises:
            ModuleNotFoundError: If `mantis.models.mllava` cannot be
                imported; the message explains how to install it.
        """
        install_hint = ("You need to `pip install "
                        "git+https://github.com/TIGER-AI-Lab/Mantis.git` "
                        "to use this model")
        try:
            from mantis.models.mllava import MLlavaProcessor
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(install_hint) from err

        tokenizer_ref = self.ctx.model_config.tokenizer
        hf_processor = MLlavaProcessor.from_pretrained(tokenizer_ref)
        assert isinstance(hf_processor, ProcessorMixin)
        return hf_processor


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with `MantisProcessor` as its processor."""