llava.py 24.8 KB
Newer Older
1
from functools import cached_property
2
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
3
                    Tuple, TypedDict, Union)
4
5

import torch
6
import torch.nn as nn
7
8
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
9
                          SiglipVisionConfig)
10
from transformers.models.llava import LlavaProcessor
11
from transformers.models.pixtral import PixtralProcessor
12
13

from vllm.attention import AttentionMetadata
14
from vllm.config import VllmConfig
15
from vllm.inputs import InputContext
16
from vllm.model_executor.layers.activation import get_act_fn
17
18
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
19
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
20
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
21
from vllm.model_executor.sampling_metadata import SamplingMetadata
22
from vllm.multimodal import MULTIMODAL_REGISTRY
23
24
25
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
                                    MultiModalFieldConfig, MultiModalInputsV2,
                                    MultiModalKwargs, NestedTensors)
26
from vllm.multimodal.processing import (BaseMultiModalProcessor,
27
28
                                        ProcessorInputs, PromptReplacement,
                                        full_groupby_modality)
29
from vllm.sequence import IntermediateTensors
30

31
from .clip import (CLIPVisionModel, dummy_image_for_clip,
32
                   get_max_clip_image_tokens)
33
from .interfaces import SupportsMultiModal, SupportsPP
34
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
35
36
                      get_max_pixtral_hf_image_tokens,
                      get_pixtral_hf_image_feature_size)
37
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
38
                     get_max_siglip_image_tokens)
39
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
40
                    maybe_prefix, merge_multimodal_embeddings)
41
42


43
44
class LlavaImagePixelInputs(TypedDict):
    """Image input given as raw pixel values (tagged by ``type``)."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
52
53
54
55
56


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings (tagged by ``type``)."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# An image input is either raw pixel values or precomputed embeddings; the
# two variants are told apart by their literal "type" key.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


66
67
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that projects vision features to the text hidden size.

    The first linear layer maps `vision_hidden_size` to `text_hidden_size`,
    the activation named by `projector_hidden_act` is applied, and the second
    linear layer maps `text_hidden_size` to `text_hidden_size`.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the projector layers.

        Args:
            vision_hidden_size: Input feature size of the first linear layer.
            text_hidden_size: Output size of both linear layers.
            projector_hidden_act: Activation name passed to `get_act_fn`.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            prefix: Name prefix used for the sub-layers.
        """
        super().__init__()

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=True,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=True,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` through linear_1 -> act -> linear_2."""
        # Both parallel linear layers return a 2-tuple; only the first
        # element (the output tensor) is used here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


95
96
97
98
99
def get_max_llava_image_tokens(ctx: InputContext) -> int:
    """Return the maximum number of image tokens for the configured model.

    The count is obtained from the vision-backbone-specific helper (CLIP,
    SigLIP or Pixtral-HF) and then adjusted for the configured
    `vision_feature_select_strategy`.

    Args:
        ctx: Input context used to fetch the `LlavaConfig`.

    Raises:
        NotImplementedError: If the vision config type is not supported.
        ValueError: If the feature select strategy is neither "default"
            nor "full".
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    elif isinstance(vision_config, PixtralVisionConfig):
        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        # The "default" strategy yields one token fewer than "full".
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
116
117


118
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
    """Multi-modal processor for LLaVA models (Llava or Pixtral processor)."""

    def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
        """Fetch the HF processor, accepting either supported type."""
        return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the parent HF processor call and normalize `pixel_values`.

        For a `PixtralProcessor`, the single outer list wrapping the
        per-image pixel values is removed.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        # NOTE: pixel_values=None for MLlavaProcessor
        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            if isinstance(self._get_hf_processor(), PixtralProcessor):
                # Original output: (1, num_images, C, H, W)
                # New output: (num_images, C, H, W)
                assert (isinstance(pixel_values, list)
                        and len(pixel_values) == 1
                        and isinstance(pixel_values[0], list)
                        and len(pixel_values[0]) == len(images))

                processed_outputs["pixel_values"] = pixel_values[0]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `image_embeds` as batched image fields."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Build the replacement rule that expands each image token.

        With a `PixtralProcessor`, the replacement is computed per image
        from its size; otherwise the image token is repeated
        `get_max_llava_image_tokens(self.ctx)` times.
        """
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        image_token_id = hf_config.image_token_index

        processor = self._get_hf_processor()
        if isinstance(processor, PixtralProcessor):
            image_token = processor.image_token
            image_break_token = processor.image_break_token
            image_end_token = processor.image_end_token

            vision_config = hf_config.vision_config
            assert isinstance(vision_config, PixtralVisionConfig)

            def get_replacement_pixtral(item_idx: int):
                image_size = mm_items.get_image_size(item_idx)
                (
                    num_width_tokens,
                    num_height_tokens,
                ) = get_pixtral_hf_image_feature_size(
                    vision_config,
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

                # Each row is `num_width_tokens` image tokens followed by a
                # break token; the final break is swapped for the end token.
                tokens = ([image_token] * num_width_tokens +
                          [image_break_token]) * num_height_tokens
                tokens[-1] = image_end_token

                return "".join(tokens)

            return [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id],
                    replacement=get_replacement_pixtral,
                ),
            ]

        max_image_tokens = get_max_llava_image_tokens(self.ctx)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=[image_token_id] * max_image_tokens,
            )
        ]

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create dummy processor inputs holding `mm_counts["image"]` images.

        Raises:
            NotImplementedError: If the vision config type is not supported.
        """
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        vision_config = hf_config.vision_config
        num_images = mm_counts.get("image", 0)

        if isinstance(vision_config, CLIPVisionConfig):
            data = dummy_image_for_clip(vision_config, num_images)
        elif isinstance(vision_config, SiglipVisionConfig):
            data = dummy_image_for_siglip(vision_config, num_images)
        elif isinstance(vision_config, PixtralVisionConfig):
            data = dummy_image_for_pixtral_hf(vision_config, num_images)
        else:
            msg = f"Unsupported vision config: {type(vision_config)}"
            raise NotImplementedError(msg)

        hf_processor = self._get_hf_processor()
        image_token = hf_processor.image_token

        # One image token per dummy image in the prompt text.
        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=data,
        )
241
242


243
244
class LlavaLikeConfig(Protocol):
    """Structural type for configs usable by the LLaVA vision-tower helpers."""

    # Config of the vision backbone.
    vision_config: PretrainedConfig
    # Index (or indices) of the vision encoder layer(s) whose features are
    # used; may be negative.
    vision_feature_layer: Union[int, List[int]]


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer count needed for the single feature layer, or for the
        deepest one when several feature layers are given.

    Raises:
        TypeError: If `vision_feature_layer` is not an int, list or tuple.
        ValueError: If `vision_feature_layer` is an empty list or tuple
            (raised by `max` on an empty sequence).
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    # The message names the actual config attribute (`vision_feature_layer`).
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given an signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1
280
281
282
283
284
285
286


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Instantiate the vision tower matching `hf_config.vision_config`.

    Args:
        hf_config: Config providing `vision_config` and
            `vision_feature_layer`.
        quant_config: Optional quantization config forwarded to the tower.
        require_post_norm: Forwarded unchanged to the vision model.
        prefix: Name prefix forwarded to the vision model.

    Returns:
        A `CLIPVisionModel`, `SiglipVisionModel` or `PixtralHFVisionModel`,
        depending on the type of the vision config.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


323
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model: vision tower + projector feeding a language model."""

    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: Source of the HF config, quantization config and
                multimodal config.
            prefix: Name prefix applied to the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Sampler of the language model if it has one, else a new sampler."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that `data` has shape `(N, 3, image_size, image_size)`.

        Returns:
            `data` unchanged when the trailing dimensions match.

        Raises:
            ValueError: If the trailing dimensions differ from
                `(3, image_size, image_size)`.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract the image input from `kwargs`, if any.

        `pixel_values` takes precedence over `image_embeds` when both are
        given. Returns None when neither is present.

        Raises:
            ValueError: If the provided value is neither a tensor nor a list,
                or if the pixel values fail shape validation.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy to `image_features`.

        "default" drops the first position along dimension 1; "full" keeps
        everything.

        Raises:
            ValueError: If `strategy` is neither "default" nor "full".
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run `vision_tower` on `pixel_values` and select the features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Compute image features from pixel inputs via the vision tower."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn an image input into embeddings for the language model.

        Precomputed embeddings are returned as-is; pixel inputs go through
        the vision tower and then the multi-modal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings parsed from `kwargs`, or None if absent."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge in `multimodal_embeddings` if given.

        The merge is keyed on `self.config.image_token_index`.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image (passed through
                `**kwargs`).

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            # The embeddings now carry the input; token ids are not passed on.
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `weights` into this module via `AutoWeightsLoader`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
575
576


577
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor variant that rewrites image placeholders for Mantis."""

    def _get_hf_processor(self):
        """Fetch the HF processor, restricted to `LlavaProcessor`."""
        return self.ctx.get_hf_processor(LlavaProcessor)

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Process the prompt, then wrap each image span in Mantis markers.

        The parent `apply` runs first. Each run of `max_image_tokens` image
        tokens in its output is then replaced by the Mantis-style text, and
        the image placeholders are located again in the rewritten prompt.
        """
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        image_token_id = hf_config.image_token_index
        max_image_tokens = get_max_llava_image_tokens(self.ctx)

        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)

        mm_items = self._get_mm_items(mm_data)
        mm_item_counts = mm_items.get_item_counts()
        mm_kwargs = result["mm_kwargs"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * max_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        mantis_repls = self._bind_prompt_replacements([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * max_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
            result["prompt_token_ids"],
            mantis_repls,
            mm_item_counts,
        )

        # Locate the placeholders again in the rewritten prompt, using the
        # parent class's replacement rules.
        unbound_orig_repls = self._get_prompt_replacements(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_prompt_replacements(unbound_orig_repls)

        all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
                                                   mm_item_counts)
        assert len(all_placeholders) == mm_item_counts.get("image", 0)

        mm_placeholders = {
            modality: [item.to_range() for item in items]
            for modality, items in full_groupby_modality(all_placeholders)
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )
644
645
646
647
648


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis multi-modal processor."""