llava.py 28.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union
7
8

import torch
9
import torch.nn as nn
10
11
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
18
from transformers.models.llava import LlavaProcessor
19
from transformers.models.pixtral import PixtralProcessor
20

21
from vllm.config import VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.cache import BaseMultiModalProcessorCache
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
49
from vllm.multimodal.profiling import BaseDummyInputsBuilder
50
from vllm.sequence import IntermediateTensors
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
from .clip import CLIPVisionModel
54
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
55
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
56
from .siglip import SiglipVisionModel
57
58
59
60
61
62
63
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
64
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
65
66


67
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel inputs for LLaVA vision towers that take fixed-size images.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    `pixel_values` is always a single batched tensor. When the model
    validates these inputs it binds `h` and `w` to
    `vision_config.image_size`, so every image must share that size.
    Variable-size images go through `PixtralHFImagePixelInputs` instead.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
81

82

83
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the HF-format Pixtral vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    `h` and `w` are dynamic dimensions. When images differ in height or
    width, `pixel_values` is a list of per-image tensors rather than one
    batched tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
100

101

102
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, used as-is instead of running the vision
    tower and projector.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
112
113


114
115
116
# Any one of the image input variants understood by the LLaVA model:
# fixed-size pixels, variable-size Pixtral pixels, or precomputed embeddings.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
117
118


119
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    Structure: column-parallel linear -> activation -> row-parallel linear.

    Args:
        vision_hidden_size: Hidden size of the vision tower output.
        text_hidden_size: Hidden size of the language model.
        projector_hidden_act: Name of the activation passed to `get_act_fn`.
        multimodal_projector_bias: Whether both linear layers use a bias.
        quant_config: Optional quantization config for the linear layers.
        prefix: Parameter-name prefix for the submodules.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # vision_hidden_size -> text_hidden_size
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        # text_hidden_size -> text_hidden_size
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a tuple; only its first element
        # (the projected tensor) is used here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


154
155
class LlavaLikeConfig(Protocol):
    """Structural interface for the HF config fields the LLaVA code reads."""

    # Config of the vision encoder (CLIP, SigLIP or Pixtral in this file).
    vision_config: Final[PretrainedConfig]
    # Token id used as the image placeholder in prompts.
    image_token_index: Final[int]
    # Forwarded to the vision tower as `feature_select_strategy`.
    vision_feature_select_strategy: Final[str]
    # Vision-encoder layer index (or indices) whose features are used;
    # negative values count back from the last layer.
    vision_feature_layer: Final[Union[int, list[int]]]
159

160

161
162
163
164
class LlavaLikeProcessor(Protocol):
    """Structural interface for the HF processor fields the LLaVA code reads."""

    # Text form of the image placeholder token.
    image_token: Final[str]


165
166
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by LLaVA-style models.

    Subclasses must implement `get_hf_processor`. Image token counts are
    derived from the vision encoder info and the config's feature-select
    strategy.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None`: no fixed limit on the number of images per prompt.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens one image expands to."""
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return get_num_selected_vision_tokens(
            encoder_tokens,
            config.vision_feature_select_strategy,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square image at the vision encoder's configured size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count of the largest-feature image size."""
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=width,
            image_height=height,
        )

209
210
211
212
213

# Processing-info type parameter shared by the generic builders/processors.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompts and images for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        num_images = mm_counts.get("image", 0)
        image_token = self.info.get_hf_processor().image_token
        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized for the maximum number of features."""
        num_images = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()
        overrides = None if not mm_options else mm_options.get("image")

        return {
            "image": self._get_dummy_images(
                width=width,
                height=height,
                num_images=num_images,
                overrides=overrides,
            )
        }


244
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    def get_hf_processor(self, **kwargs: object):
        """Load the HF `LlavaProcessor`, filling in a missing `patch_size`."""
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # In case patch_size is omitted from `processor_config.json`
        # (e.g. for E5-V: https://huggingface.co/royokong/e5-v), fall back to
        # the vision encoder's patch size.
        if processor.patch_size is None:
            processor.patch_size = self.get_vision_encoder_info().get_patch_size()

        return processor
253
254


255
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared prompt-update logic for LLaVA-style processors.

    Each image placeholder token in the prompt is replaced by as many image
    tokens as that image contributes features.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            # Embeddings report their own feature count; raw images are
            # measured from their size.
            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                num_image_tokens = images.get_feature_size(item_idx)

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


299
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard (non-Pixtral) LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixel values and precomputed embeddings are both batched per image.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


311
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for LLaVA checkpoints with a Pixtral vision tower."""

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
314

315

316
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for HF-format Pixtral checkpoints."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Avoid padding since we need the output for each image to be
        # independent of other images for the cache to work correctly:
        # crop each image tensor back to its own (height, width).
        image_sizes = outputs["image_sizes"]
        assert len(pixel_values) == len(image_sizes)

        outputs["pixel_values"] = [
            pixels[:, :height, :width]
            for pixels, (height, width) in zip(pixel_values, image_sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            # Every row of image tokens ends with a break token, except the
            # last row, whose trailing token becomes the end token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id

            # Presumably marks only the `image_token_id` positions (not the
            # break/end tokens) as embedding slots.
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

394

395
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class matching the checkpoint's vision tower.

    Returns `PixtralHFProcessingInfo` for a Pixtral vision config and
    `LlavaProcessingInfo` otherwise.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    if not isinstance(vision_config, PixtralVisionConfig):
        return LlavaProcessingInfo(ctx)
    return PixtralHFProcessingInfo(ctx)


406
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor that matches the type of `info`.

    Raises:
        NotImplementedError: If `info` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        return PixtralHFMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    if isinstance(info, LlavaProcessingInfo):
        return LlavaMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    raise NotImplementedError(type(info))
427
428
429
430
431


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers required so that every requested
        feature layer is computed.

    Raises:
        ValueError: If `vision_feature_layer` is an empty list/tuple.
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # max() would otherwise fail with an opaque "empty sequence" error.
            raise ValueError("vision_feature_layer must not be empty")
        return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
447
448
449


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
450
    """Given a signed vision feature layer, get the number of hidden layers
451
452
453
454
455
456
457
458
459
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
460
    return feature_layer_index
461
462
463
464
465
466
467


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision tower described by ``hf_config.vision_config``.

    Only the encoder layers up to the deepest requested feature layer are
    instantiated.

    Args:
        hf_config: Model config with the vision config and feature layer(s).
        quant_config: Optional quantization config for the tower.
        require_post_norm: Forwarded to the vision model constructor.
        prefix: Parameter-name prefix for the tower.

    Raises:
        NotImplementedError: If the vision config is not CLIP, SigLIP or
            Pixtral.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Checked in order; the first matching config type wins. All towers
    # share the same constructor keyword arguments.
    tower_classes = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, model_cls in tower_classes:
        if isinstance(vision_config, config_cls):
            return model_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


504
505
506
507
508
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA: vision tower + projector feeding a registered language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        text_config = config.text_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if text_config.architectures is None and text_config.model_type == "mistral":
            text_config.architectures = ["MistralForCausalLM"]
        if (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        ):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        if not multimodal_config.get_limit_per_prompt("image"):
            # No images allowed: skip building the vision components
            # (their weights are also skipped in `load_weights`).
            self.vision_tower = None
            self.multi_modal_projector = None
        else:
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[LlavaImageInputs]:
        """Turn raw multimodal kwargs into a validated image input schema.

        Returns `None` when neither `pixel_values` nor `image_embeds` is
        given; pixel values take precedence when both are present.

        Raises:
            ValueError: On wrongly-typed inputs, or `image_embeds` with a
                Pixtral vision tower.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
                )

            if self.config.vision_config.model_type == "pixtral":
                # Images may differ in size, so they are not concatenated.
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                )

            image_size = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values, concat=True),
                resolve_bindings={"h": image_size, "w": image_size},
            )

        # Here `pixel_values` is None, so `image_embeds` must be set.
        if not isinstance(image_embeds, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of image embeddings. "
                f"Got type: {type(image_embeds)}"
            )

        if self.config.vision_config.model_type == "pixtral":
            raise ValueError("Pixtral-HF does not support image_embeds.")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=flatten_bn(image_embeds, concat=True),
        )

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        strategy = self.config.vision_feature_select_strategy
        return vision_tower(pixel_values, feature_select_strategy=strategy)

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None
        return self._image_pixels_to_features(
            self.vision_tower, inputs["pixel_values"]
        )

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Return projected image embeddings (batched, or one per image)."""
        if image_input["type"] == "image_embeds":
            # Precomputed embeddings are returned without projection.
            return image_input["data"]

        assert self.vision_tower is not None
        features = self._process_image_pixels(image_input)

        if isinstance(features, torch.Tensor):
            return self.multi_modal_projector(features)

        # Per-image features of different lengths: project them in a single
        # call, then split back into per-image chunks.
        sizes = [feature.shape[0] for feature in features]
        projected = self.multi_modal_projector(torch.cat(features))
        return torch.split(projected, sizes)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, the input processor expands the image
        token (denoted as `32000`) into a run of image tokens before the
        prompt is passed to the model, resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        # When intermediate tensors are supplied, input embeddings are ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Without a vision tower and projector there is nowhere to load the
        # vision weights, so skip them.
        text_only = self.vision_tower is None and self.multi_modal_projector is None
        skip_prefixes = ["vision_tower.", "multi_modal_projector."] if text_only else []

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
748
749


750
class MantisProcessingInfo(LlavaProcessingInfo):
    def get_hf_processor(self, **kwargs: object):
        """Load a `LlavaProcessor` for Mantis.

        `patch_size` and `vision_feature_select_strategy` default to values
        from the vision encoder info and HF config unless the caller passes
        them explicitly.
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())
        kwargs.setdefault(
            "vision_feature_select_strategy",
            hf_config.vision_feature_select_strategy,
        )

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
762
763


764
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Process inputs as LLaVA does, then wrap each image in Mantis markers.

        After the standard LLaVA expansion, each run of image tokens is
        rewritten as ``(image {k}: <Image><image>...</Image>)``. Image
        placeholder ranges are then recomputed against the rewritten prompt.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return (
                f"(image {item_idx + 1}: <Image>"  # 7 tokens
                + "<image>" * num_image_tokens
                + "</Image>)"  # 3 tokens
            )

        mantis_mm_repls = self._bind_and_group_updates(
            [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id] * num_image_tokens,
                    replacement=get_replacement_mantis,
                )
            ],
            mm_item_counts,
        )

        prompt_ids, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
        )

        # Re-locate the image placeholders in the rewritten prompt using the
        # regular prompt updates of this processor.
        orig_repls = self._get_mm_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=placeholder_ranges,
        )
842
843
844
845


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with Mantis-specific input processing.

    All model code is inherited unchanged from `LlavaForConditionalGeneration`.
    """