# llava.py
1
from abc import abstractmethod
2
from functools import cached_property
3
4
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
                    Protocol, Set, Tuple, TypedDict, Union)
5
6

import torch
7
import torch.nn as nn
8
9
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
10
                          SiglipVisionConfig)
11
from transformers.models.llava import LlavaProcessor
12
from transformers.models.pixtral import PixtralProcessor
13
14

from vllm.attention import AttentionMetadata
15
from vllm.config import VllmConfig
16
from vllm.model_executor.layers.activation import get_act_fn
17
18
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
19
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
20
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
21
from vllm.model_executor.sampling_metadata import SamplingMetadata
22
from vllm.multimodal import MULTIMODAL_REGISTRY
23
24
25
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputsV2, MultiModalKwargs,
                                    NestedTensors)
26
27
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize)
28
from vllm.multimodal.processing import (InputProcessingContext,
29
                                        MultiModalDataItems, ProcessingCache,
30
                                        ProcessorInputs, PromptReplacement)
31
from vllm.sequence import IntermediateTensors
32

33
from .clip import CLIPVisionModel
34
from .interfaces import SupportsMultiModal, SupportsPP
35
36
37
from .pixtral import (PixtralHFVisionModel,
                      get_pixtral_hf_image_feature_grid_size)
from .siglip import SiglipVisionModel
38
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
39
                    maybe_prefix, merge_multimodal_embeddings)
40
from .vision import BaseVisionLanguageMultiModalProcessor
41
42


43
44
class LlavaImagePixelInputs(TypedDict):
    """Image input given to the model as raw pixel values.

    This is the `"pixel_values"` variant of `LlavaImageInputs`.
    """

    type: Literal["pixel_values"]
    """Tag identifying this variant of the image input union."""

    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    `height` and `width` are allowed to vary per batch and per image; when
    they do, a list of tensors is passed in place of one batched tensor.
    """
52
53
54
55
56


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input given to the model as precomputed embeddings.

    This is the `"image_embeds"` variant of `LlavaImageInputs`.
    """

    type: Literal["image_embeds"]
    """Tag identifying this variant of the image input union."""

    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` has to equal the hidden size of the language model
    backbone.
    """


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


66
67
class LlavaMultiModalProjector(nn.Module):
    """Two-layer projector applied to image features.

    The layers are `linear_1` (vision hidden size -> text hidden size), an
    activation chosen by name, and `linear_2` (text hidden size -> text
    hidden size).
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the two linear layers and the activation between them.

        Args:
            vision_hidden_size: Input width of the first linear layer.
            text_hidden_size: Output width of both linear layers.
            projector_hidden_act: Activation name passed to `get_act_fn`.
            quant_config: Quantization config forwarded to both layers.
            prefix: Name prefix under which the sub-layers are registered.
        """
        super().__init__()

        # Options that both linear layers have in common.
        shared_kwargs = dict(bias=True, quant_config=quant_config)

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             prefix=f"{prefix}.linear_1",
                                             **shared_kwargs)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          prefix=f"{prefix}.linear_2",
                                          **shared_kwargs)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply linear_1, the activation and linear_2 in sequence."""
        # Both linear layers return a pair; only its first element is used.
        projected, _ = self.linear_1(image_features)
        projected, _ = self.linear_2(self.act(projected))
        return projected


95
96
97
98
class LlavaLikeConfig(Protocol):
    """Structural type listing the config attributes that the helpers in
    this file read from a LLaVA-style HF config."""
    # Vision encoder config; its concrete type selects the vision tower.
    vision_config: Final[PretrainedConfig]
    # Feature selection strategy name; this file handles "default" and
    # "full".
    vision_feature_select_strategy: Final[str]
    # Index (or indices) of the vision encoder layer(s) to take features
    # from; negative values are resolved relative to the number of layers.
    vision_feature_layer: Final[Union[int, List[int]]]
99

100

101
class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
    """Common multimodal-processor logic shared by the LLaVA-style
    processors in this file.

    Concrete subclasses provide the HF config (`_get_hf_config`) and the
    image placeholder string (`_get_image_token`).
    """

    @abstractmethod
    def _get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config of the model being processed."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Declare `image` as the only modality, with its limit left as
        `None`."""
        return dict(image=None)

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Return the maximum number of tokens a single image can take.

        `seq_len` is accepted for interface compatibility and is not used.
        """
        max_image_tokens = self._get_max_image_tokens()
        return dict(image=max_image_tokens)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Mark both image fields as batched over the `image` modality."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Convert an encoder token count into the count left after
        feature selection.

        Raises:
            NotImplementedError: If `strategy` is not "default" or "full".
        """
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            # The "default" strategy yields one token fewer than the encoder.
            return encoder_num_image_tokens - 1

        raise NotImplementedError(
            f"Unexpected feature select strategy: {strategy!r}")

    def _get_max_image_tokens(self) -> int:
        """Return the encoder's maximum image token count after feature
        selection."""
        strategy = self._get_hf_config().vision_feature_select_strategy
        encoder_max_tokens = self._vision_encoder_info.get_max_image_tokens()
        return self._apply_feature_select_strategy(strategy,
                                                   encoder_max_tokens)

    def _get_dummy_image_size(self) -> ImageSize:
        """Return a square size built from the encoder's image size."""
        side = self._vision_encoder_info.get_image_size()
        return ImageSize(side, side)

    @abstractmethod
    def _get_image_token(self) -> str:
        """Return the image placeholder string used in prompts."""
        raise NotImplementedError

    def _get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Build dummy inputs with one image placeholder per dummy image.

        `seq_len` is accepted for interface compatibility and is not used.
        """
        image_count = mm_counts.get("image", 0)

        placeholder = self._get_image_token()
        dummy_width, dummy_height = self._get_dummy_image_size()
        dummy_images = self._get_dummy_images(width=dummy_width,
                                              height=dummy_height,
                                              num_images=image_count)

        return ProcessorInputs(
            prompt_text=placeholder * image_count,
            mm_data={"image": dummy_images},
        )


class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
    """Multimodal processor backed by `LlavaConfig` and `LlavaProcessor`."""

    def _get_hf_config(self) -> LlavaConfig:
        """Return the model's HF config as a `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def _get_hf_processor(self) -> LlavaProcessor:
        """Return the HF `LlavaProcessor` for this model."""
        return self.ctx.get_hf_processor(LlavaProcessor)

    def _get_image_token(self) -> str:
        """Return the image placeholder string of the HF processor."""
        processor = self._get_hf_processor()
        return processor.image_token

    def _get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given
        size, after feature selection."""
        strategy = self._get_hf_config().vision_feature_select_strategy
        encoder_tokens = self._vision_encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(strategy, encoder_tokens)

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Replace each image token with as many image tokens as the
        corresponding item needs."""
        image_token_id = self._get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if not isinstance(images, ImageEmbeddingItems):
                # Raw image: the token count follows from the image size.
                size = images.get_image_size(item_idx)
                token_count = self._get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                # Embedding item: it reports its own feature size.
                token_count = images.get_feature_size(item_idx)

            return [image_token_id] * token_count

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]


class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
    """Multimodal processor backed by `LlavaConfig` and `PixtralProcessor`,
    for configs whose vision config is a `PixtralVisionConfig`."""

    def _get_hf_config(self) -> LlavaConfig:
        """Return the model's HF config as a `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def _get_hf_processor(self) -> PixtralProcessor:
        """Return the HF `PixtralProcessor` for this model."""
        return self.ctx.get_hf_processor(PixtralProcessor)

    def _get_image_token(self) -> str:
        """Return the image placeholder string of the HF processor."""
        processor = self._get_hf_processor()
        return processor.image_token

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and strip the outer singleton level from
        `pixel_values` when it is present."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is None:
            return processed_outputs

        images = mm_data["images"]
        assert isinstance(images, list)

        # Original output: (1, num_images, C, H, W)
        # New output: (num_images, C, H, W)
        assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
        assert (isinstance(pixel_values[0], list)
                and len(pixel_values[0]) == len(images))

        processed_outputs["pixel_values"] = pixel_values[0]
        return processed_outputs

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Replace each image token with a grid of image tokens.

        Every grid row ends with the break token, except the last row,
        which ends with the end token instead.
        """
        hf_config = self._get_hf_config()
        image_token_id = hf_config.image_token_index

        hf_processor = self._get_hf_processor()
        image_token = hf_processor.image_token
        break_token = hf_processor.image_break_token
        end_token = hf_processor.image_end_token

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)

            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
                vision_config,
                image_width=size.width,
                image_height=size.height,
            )

            # One row is `ncols` image tokens followed by a break token.
            row = [image_token] * ncols + [break_token]
            tokens = row * nrows
            # The final break token is swapped for the end token.
            tokens[-1] = end_token

            return "".join(tokens)

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]

313

314
315
316
317
318
319
320
def _build_llava_or_pixtral_hf_processor(
    ctx: InputProcessingContext,
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseLlavaMultiModalProcessor:
    """Create the multimodal processor that matches the vision config.

    A `PixtralVisionConfig` selects `PixtralHFMultiModalProcessor`; any
    other vision config selects `LlavaMultiModalProcessor`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config

    if isinstance(vision_config, PixtralVisionConfig):
        processor_cls = PixtralHFMultiModalProcessor
    else:
        processor_cls = LlavaMultiModalProcessor

    return processor_cls(
        ctx,
        cache=cache,
        enable_sanity_checks=enable_sanity_checks,
    )
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.
    
    Args:
        hf_config: Model config with vision feature layer(s).
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given an signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1
368
369
370
371
372
373
374


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Instantiate the vision tower that matches `hf_config.vision_config`.

    Args:
        hf_config: Model config providing the vision config and the
            feature layer(s).
        quant_config: Quantization config forwarded to the vision model.
        require_post_norm: Forwarded unchanged to the vision model.
        prefix: Name prefix forwarded to the vision model.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Checked in order; the first matching config type picks the model.
    tower_by_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_by_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")


411
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
412
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
413
414
415
416
417
418
419
420
421
    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }
422

423
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: Engine config; the HF config, the quantization
                config and the multimodal config are read from it.
            prefix: Name prefix under which the sub-modules are created.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Expose the language model's factory under the same name here.
        language_model = self.language_model
        self.make_empty_intermediate_tensors = (
            language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Return the language model's sampler if it has one, otherwise a
        sampler obtained from `get_sampler()`."""
        language_model = self.language_model
        if not hasattr(language_model, "sampler"):
            return get_sampler()

        return language_model.sampler
470

471
472
473
474
475
476
477
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
478
            raise ValueError(
479
480
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")
481
482
483
484

        return data

    def _parse_and_validate_image_input(
485
486
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
487
        image_embeds = kwargs.pop("image_embeds", None)
488

489
        if pixel_values is None and image_embeds is None:
490
            return None
491

492
        if pixel_values is not None:
493
            if not isinstance(pixel_values, (torch.Tensor, list)):
494
495
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
496

497
498
            return LlavaImagePixelInputs(
                type="pixel_values",
499
500
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
501
502
503
            )

        if image_embeds is not None:
504
            if not isinstance(image_embeds, (torch.Tensor, list)):
505
506
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
507

508
509
            return LlavaImageEmbeddingInputs(
                type="image_embeds",
510
                data=flatten_bn(image_embeds, concat=True),
511
512
513
            )

        raise AssertionError("This line should be unreachable.")
514
515
516
517
518
519
520
521
522
523
524

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

525
526
    def _image_pixels_to_features(
        self,
527
528
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
529
530
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
531

532
533
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
534
        image_features = vision_tower(pixel_values)
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
551
552
553
554

        if image_input["type"] == "image_embeds":
            return image_input["data"]

555
556
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
557
558
        return self.multi_modal_projector(image_features)

559
    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return the embeddings for the image input found in `kwargs`, or
        `None` when `kwargs` carries no image input."""
        parsed_input = self._parse_and_validate_image_input(**kwargs)
        if parsed_input is not None:
            return self._process_image_input(parsed_input)
        return None

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` with the language model and, when multimodal
        embeddings are given, merge them in at the image token index."""
        text_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return text_embeds

        return merge_multimodal_embeddings(input_ids, text_embeds,
                                           multimodal_embeddings,
                                           self.config.image_token_index)

578
579
580
581
582
583
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`LlavaImageInputs`
        """
        has_intermediate = intermediate_tensors is not None

        if has_intermediate:
            # Any supplied embeddings are dropped in this case.
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         kv_caches,
                                         attn_metadata,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

643
644
645
646
647
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states, sampling_metadata)
650
651
652
653
654
655

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)
657

658
659
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `(name, tensor)` pairs into this module through an
        `AutoWeightsLoader` and return that loader's result."""
        return AutoWeightsLoader(self).load_weights(weights)
662
663


664
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that additionally wraps each image's tokens in the
    Mantis prompt format."""

    def _get_hf_processor(self):
        """Return the HF `LlavaProcessor` for this model."""
        return self.ctx.get_hf_processor(LlavaProcessor)

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Run the parent processor, then rewrite the image placeholders
        into the Mantis format and recompute the placeholder ranges."""
        image_token_id = self.ctx.get_hf_config(LlavaConfig).image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self._get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base_result = super().apply(prompt_text, mm_data,
                                    hf_processor_mm_kwargs)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()

        mm_kwargs = base_result["mm_kwargs"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            opening = f"(image {item_idx+1}: <Image>"  # 7 tokens
            closing = "</Image>)"  # 3 tokens
            return opening + "<image>" * num_image_tokens + closing

        mantis_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id] * num_image_tokens,
            replacement=get_replacement_mantis,
        )
        mantis_mm_repls = self._bind_and_group_repls([mantis_replacement])

        prompt_ids, new_prompt_text, _ = self._apply_prompt_replacements(
            base_result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Locate the placeholders in the rewritten prompt with the parent's
        # own replacements.
        orig_repls = self._bind_and_group_repls(
            self._get_prompt_replacements(
                mm_items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )

        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {}
        for modality, placeholders in mm_placeholders.items():
            mm_placeholder_ranges[modality] = [
                item.to_range() for item in placeholders
            ]

        return MultiModalInputsV2(
            type="multimodal",
            prompt=new_prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholder_ranges,
        )
740
741
742
743


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
744
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """`LlavaForConditionalGeneration` registered with
    `MantisMultiModalProcessor` as its multimodal processor."""