llava.py 24.9 KB
Newer Older
1
from functools import cached_property
2
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
3
                    Tuple, TypedDict, Union)
4
5

import torch
6
import torch.nn as nn
7
8
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
9
                          SiglipVisionConfig)
10
from transformers.models.llava import LlavaProcessor
11
from transformers.models.pixtral import PixtralProcessor
12
13

from vllm.attention import AttentionMetadata
14
from vllm.config import VllmConfig
15
from vllm.inputs import InputContext
16
from vllm.model_executor.layers.activation import get_act_fn
17
18
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
19
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
20
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
21
from vllm.model_executor.sampling_metadata import SamplingMetadata
22
from vllm.multimodal import MULTIMODAL_REGISTRY
23
24
25
26
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputsV2, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import ImageProcessorItems
27
from vllm.multimodal.processing import (BaseMultiModalProcessor,
28
29
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement,
30
                                        full_groupby_modality)
31
from vllm.sequence import IntermediateTensors
32

33
from .clip import (CLIPVisionModel, dummy_image_for_clip,
34
                   get_max_clip_image_tokens)
35
from .interfaces import SupportsMultiModal, SupportsPP
36
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
37
38
                      get_max_pixtral_hf_image_tokens,
                      get_pixtral_hf_image_feature_size)
39
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
40
                     get_max_siglip_image_tokens)
41
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
42
                    maybe_prefix, merge_multimodal_embeddings)
43
44


45
46
class LlavaImagePixelInputs(TypedDict):
    """Raw image pixels to be encoded by the vision tower.

    ``data`` has shape ``(batch_size * num_images, num_channels, height,
    width)``. When ``height``/``width`` differ between images, the images
    cannot be stacked, so ``data`` is a list of per-image tensors instead of
    a single batched tensor.
    """

    # Discriminator tag for the ``LlavaImageInputs`` union.
    type: Literal["pixel_values"]
    # Batched tensor, or a list of tensors for variable-size images.
    data: Union[torch.Tensor, List[torch.Tensor]]


class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that bypass the vision tower.

    ``data`` has shape ``(batch_size * num_images, image_feature_size,
    hidden_size)``, where ``hidden_size`` must equal the hidden size of the
    language model backbone.
    """

    # Discriminator tag for the ``LlavaImageInputs`` union.
    type: Literal["image_embeds"]
    data: torch.Tensor


# Either form of image input accepted by LLaVA-style models; dispatch on the
# ``"type"`` key.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


68
69
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features into the text hidden size.

    Computes ``linear_2(act(linear_1(x)))``, projecting from
    ``vision_hidden_size`` to ``text_hidden_size``. ``linear_1`` is
    column-parallel and ``linear_2`` row-parallel, both with bias.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Options shared by both linear layers.
        common_kwargs = dict(bias=True, quant_config=quant_config)

        # Registration order (linear_1, act, linear_2) is kept as-is.
        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             prefix=f"{prefix}.linear_1",
                                             **common_kwargs)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          prefix=f"{prefix}.linear_2",
                                          **common_kwargs)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` into the language model's hidden size."""
        # The parallel linear layers return (output, bias); only the output
        # is used here.
        intermediate, _ = self.linear_1(image_features)
        activated = self.act(intermediate)
        projected, _ = self.linear_2(activated)
        return projected


97
98
99
100
101
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens produced per image.

    The raw count comes from the helper matching the type of
    ``vision_config``. The ``vision_feature_select_strategy`` then adjusts it:
    ``"default"`` drops the first feature (one fewer token), and ``"full"``
    keeps all of them.

    Raises:
        NotImplementedError: If the vision config type is not supported.
        ValueError: If the feature-select strategy is unknown.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    # Checked in order; the first matching config type wins.
    token_counters = (
        (CLIPVisionConfig, get_max_clip_image_tokens),
        (SiglipVisionConfig, get_max_siglip_image_tokens),
        (PixtralVisionConfig, get_max_pixtral_hf_image_tokens),
    )
    for config_type, count_tokens in token_counters:
        if isinstance(vision_config, config_type):
            num_image_tokens = count_tokens(vision_config)
            break
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy not in ("default", "full"):
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    return num_image_tokens - 1 if strategy == "default" else num_image_tokens
118
119


120
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
    """Multi-modal processor for LLaVA-style models.

    Supports both the standard HF ``LlavaProcessor`` and ``PixtralProcessor``.
    Pixtral images get a grid of image/break/end tokens sized to the image.
    Other encoders get a fixed-length run of image tokens.
    """

    def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
        """Return the HF processor (either LLaVA or Pixtral)."""
        accepted_types = (LlavaProcessor, PixtralProcessor)
        return self.ctx.get_hf_processor(accepted_types)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalize Pixtral's pixel-value nesting.

        ``PixtralProcessor`` wraps its per-image pixel values in an extra
        outer list of length one. That level is stripped so the output holds
        one entry per image.
        """
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        # NOTE: pixel_values=None for MLlavaProcessor
        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        images = mm_data["images"]
        assert isinstance(images, list)

        if not isinstance(self._get_hf_processor(), PixtralProcessor):
            return outputs

        # Pixtral layout: (1, num_images, C, H, W) -> (num_images, C, H, W)
        assert (isinstance(pixel_values, list)
                and len(pixel_values) == 1
                and isinstance(pixel_values[0], list)
                and len(pixel_values[0]) == len(images))
        outputs["pixel_values"] = pixel_values[0]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both pixel values and image embeddings are batched per image."""
        per_image = MultiModalFieldConfig.batched
        return {
            "pixel_values": per_image("image"),
            "image_embeds": per_image("image"),
        }

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Describe how each image token in the prompt is expanded.

        For Pixtral, every image token becomes a per-image grid:
        ``num_width_tokens`` image tokens followed by a break token, repeated
        ``num_height_tokens`` times. The final break token is replaced by the
        end token.

        For other encoders, each image token becomes
        ``get_max_llava_image_tokens`` copies of ``image_token_index``.
        """
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        image_token_id = hf_config.image_token_index

        processor = self._get_hf_processor()
        if not isinstance(processor, PixtralProcessor):
            max_image_tokens = get_max_llava_image_tokens(self.ctx)
            return [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id],
                    replacement=[image_token_id] * max_image_tokens,
                )
            ]

        img_tok = processor.image_token
        break_tok = processor.image_break_token
        end_tok = processor.image_end_token

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)

        def get_replacement_pixtral(item_idx: int):
            # Looked up lazily so it only runs when an image is replaced.
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)

            num_cols, num_rows = get_pixtral_hf_image_feature_size(
                vision_config,
                image_width=size.width,
                image_height=size.height,
            )

            grid = ([img_tok] * num_cols + [break_tok]) * num_rows
            grid[-1] = end_tok
            return "".join(grid)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_pixtral,
            ),
        ]

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Build a dummy prompt and images for memory profiling.

        The prompt is one image token per requested image. The images come
        from the dummy-image helper matching the vision encoder.
        """
        vision_config = self.ctx.get_hf_config(LlavaConfig).vision_config
        num_images = mm_counts.get("image", 0)

        dummy_builders = (
            (CLIPVisionConfig, dummy_image_for_clip),
            (SiglipVisionConfig, dummy_image_for_siglip),
            (PixtralVisionConfig, dummy_image_for_pixtral_hf),
        )
        for config_type, build_dummy in dummy_builders:
            if isinstance(vision_config, config_type):
                data = build_dummy(vision_config, num_images)
                break
        else:
            msg = f"Unsupported vision config: {type(vision_config)}"
            raise NotImplementedError(msg)

        image_token = self._get_hf_processor().image_token

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=data,
        )
245
246


247
248
class LlavaLikeConfig(Protocol):
    """Structural type for configs that can drive a LLaVA-style vision tower.

    Any config object exposing these two attributes is accepted, regardless
    of its concrete class.
    """

    # Config of the vision encoder; its concrete type selects the encoder.
    vision_config: PretrainedConfig
    # Encoder layer index (or indices) whose outputs are used as image
    # features; negative values count from the end of the encoder.
    vision_feature_layer: Union[int, List[int]]


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed so that every layer referenced
        by ``hf_config.vision_feature_layer`` is computed.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple of ints.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # max() of an empty sequence would raise an opaque error.
            raise ValueError(
                "vision_feature_layer must contain at least one layer index")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given an signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1
284
285
286
287
288
289
290


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Construct the vision encoder that matches ``hf_config.vision_config``.

    The encoder is truncated to the deepest layer referenced by
    ``hf_config.vision_feature_layer``.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Checked in order; the first matching config type wins.
    tower_classes = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_classes:
        if not isinstance(vision_config, config_cls):
            continue
        return tower_cls(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


327
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style vision-language model.

    Composes three parts:
    - a vision tower (CLIP, SigLIP or Pixtral-HF);
    - a :class:`LlavaMultiModalProjector`;
    - a registered text-only language model.

    Projected image features are merged into the token embeddings at the
    positions holding ``config.image_token_index``.
    """

    # BitsAndBytes specific attributes: maps each checkpoint shard name to
    # (fused parameter name, shard index within the fused parameter).
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Reuse the language model's sampler if it has one, else the default."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` is ``(batch, 3, image_size, image_size)``.

        Raises:
            ValueError: If the trailing dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract image inputs from ``kwargs``.

        ``pixel_values`` takes precedence over ``image_embeds``. Returns
        ``None`` when neither is present.

        Raises:
            ValueError: If the input has the wrong type or (for fixed-size
                encoders) the wrong shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Pixtral accepts images of varying resolution. Its per-image
            # tensors cannot be concatenated into one batch, nor checked
            # against a single fixed ``image_size``, so keep them as a list.
            if isinstance(self.config.vision_config, PixtralVisionConfig):
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=flatten_bn(pixel_values),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature-select strategy.

        ``"default"`` drops the first feature along dim 1; ``"full"`` keeps
        all features.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Encode pixels with ``vision_tower`` and apply feature selection.

        ``pixel_values`` is a list of per-image tensors for Pixtral (variable
        image sizes) and a batched tensor otherwise.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on pixel inputs."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Return embeddings for ``image_input``.

        Precomputed embeddings are passed through unchanged. Pixels are
        encoded by the vision tower and then projected.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings parsed from ``kwargs``, or ``None``."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``.

        When multimodal embeddings are given, they are written into the
        positions holding ``config.image_token_index``.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights; returns the names of the loaded params."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
579
580


581
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Processor for Mantis checkpoints built on the LLaVA processor.

    It re-implements the prompt format of Mantis' ``MLlavaProcessor``
    (https://github.com/TIGER-AI-Lab/Mantis.git). Each image is rendered as
    ``(image i: <Image>`` + image tokens + ``</Image>)``.
    """

    def _get_hf_processor(self):
        return self.ctx.get_hf_processor(LlavaProcessor)

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Run the LLaVA pipeline, then wrap each image in Mantis markers.

        The base ``apply`` expands every image token into
        ``max_image_tokens`` copies of ``image_token_index``. Each such run
        is rewritten into the Mantis format. Placeholder ranges are then
        recomputed on the rewritten token ids.
        """
        llava_config = self.ctx.get_hf_config(LlavaConfig)
        placeholder_id = llava_config.image_token_index
        tokens_per_image = get_max_llava_image_tokens(self.ctx)

        base_output = super().apply(prompt_text, mm_data,
                                    hf_processor_mm_kwargs)

        items = self._to_mm_items(mm_data)
        item_counts = items.get_all_counts()

        processed_kwargs = base_output["mm_kwargs"]

        def get_replacement_mantis(item_idx: int):
            opening = f"(image {item_idx+1}: <Image>"  # 7 tokens
            closing = "</Image>)"  # 3 tokens
            return opening + "<image>" * tokens_per_image + closing

        wrapper_repls = self._bind_prompt_replacements([
            PromptReplacement(
                modality="image",
                target=[placeholder_id] * tokens_per_image,
                replacement=get_replacement_mantis,
            )
        ])

        new_ids, new_text, _ = self._apply_prompt_replacements(
            base_output["prompt_token_ids"],
            wrapper_repls,
            item_counts,
        )

        # Locate the image-token runs again with the regular LLaVA rules so
        # the placeholder ranges point into the rewritten prompt.
        regular_repls = self._bind_prompt_replacements(
            self._get_prompt_replacements(
                items,
                hf_processor_mm_kwargs,
                processed_kwargs,
            ))

        found = self._find_placeholders(regular_repls, new_ids, item_counts)
        assert len(found) == item_counts.get("image", 0)

        placeholder_ranges = {}
        for modality, group in full_groupby_modality(found):
            placeholder_ranges[modality] = [p.to_range() for p in group]

        return MultiModalInputsV2(
            type="multimodal",
            prompt=new_text,
            prompt_token_ids=new_ids,
            mm_kwargs=processed_kwargs,
            mm_placeholders=placeholder_ranges,
        )
648
649
650
651
652


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis prompt processor.

    All modelling code is inherited from ``LlavaForConditionalGeneration``.
    Only the registered multi-modal processor (``MantisMultiModalProcessor``)
    differs.
    """