llava.py 29.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar
7
8

import torch
9
import torch.nn as nn
10
11
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
18
from transformers.models.llava import LlavaProcessor
19
from transformers.models.pixtral import PixtralProcessor
20

21
from vllm.config import VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.cache import BaseMultiModalProcessorCache
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
49
from vllm.multimodal.profiling import BaseDummyInputsBuilder
50
from vllm.sequence import IntermediateTensors
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
from .clip import CLIPVisionModel
54
55
56
57
58
59
60
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .module_mapping import MultiModelKeys
61
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
62
from .siglip import SiglipVisionModel
63
64
65
66
67
68
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
69
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
70
71


72
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel inputs for LLaVA-style vision towers (e.g. CLIP / SigLIP).

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (fixed to 3)
        - h: Height
        - w: Width

    Unlike `PixtralHFImagePixelInputs`, this schema accepts only a single
    batched tensor: `h` and `w` are not dynamic, so every image must share
    the same size. `LlavaForConditionalGeneration` binds both to the vision
    config's `image_size` when validating the input.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
86

87

88
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the HF-format Pixtral vision encoder.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    `h` and `w` are declared dynamic, so images in one batch may differ in
    size. When they do, `pixel_values` is a list of per-image tensors
    instead of one stacked tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    # Either a stacked tensor or one tensor per image (dynamic h / w).
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
105

106

107
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that bypass the vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size (tokens per image)
        - hs: Hidden size; must equal the language model backbone's
    """

    type: Literal["image_embeds"] = "image_embeds"
    # Fed to the language model as-is; no projection is applied to it.
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
117
118


119
120
121
# Every image-input schema accepted by LlavaForConditionalGeneration.
LlavaImageInputs: TypeAlias = (
    LlavaImagePixelInputs  # CLIP / SigLIP pixels
    | PixtralHFImagePixelInputs  # Pixtral (HF) pixels, possibly ragged
    | LlavaImageEmbeddingInputs  # precomputed embeddings
)
122
123


124
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features into the text hidden space.

    The first linear layer is column-parallel and the second row-parallel,
    with the configured activation applied in between.

    Args:
        vision_hidden_size: Width of the incoming vision features.
        text_hidden_size: Width of the language model's hidden states.
        projector_hidden_act: Name of the activation, passed to `get_act_fn`.
        multimodal_projector_bias: Whether both linear layers use a bias.
        quant_config: Optional quantization config for both layers.
        prefix: Parameter-name prefix for the two linear layers.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keyword arguments shared by both projection layers.
        shared_kwargs = dict(
            bias=multimodal_projector_bias,
            quant_config=quant_config,
        )

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_1",
            **shared_kwargs,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_2",
            **shared_kwargs,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # Each parallel linear returns a pair; only the first item is used.
        projected, _ = self.linear_1(image_features)
        projected, _ = self.linear_2(self.act(projected))
        return projected


159
160
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config fields that LLaVA-style code reads."""

    # Config of the vision encoder (CLIP, SigLIP or Pixtral in this module).
    vision_config: Final[PretrainedConfig]
    # Token id marking image positions in the prompt.
    image_token_index: Final[int]
    # Strategy name passed to the vision tower and token-count helpers.
    vision_feature_select_strategy: Final[str]
    # One (possibly negative) encoder layer index, or a list of them.
    vision_feature_layer: Final[int | list[int]]
164

165

166
167
168
169
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors used by LLaVA-style models."""

    # Text placeholder emitted once per image in dummy prompts.
    image_token: Final[str]


170
171
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by LLaVA-style models.

    Subclasses supply the concrete HF processor. Image-token counts are
    derived from the vision encoder info together with the configured
    feature-selection strategy.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Only images are supported; no explicit numeric limit is set here.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many prompt tokens an image of the given size needs."""
        hf_config = self.get_hf_config()
        encoder_token_count = self.get_vision_encoder_info().get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        # The selection strategy may change the raw encoder count
        # (presumably by dropping e.g. a class token -- see the helper).
        return get_num_selected_vision_tokens(
            encoder_token_count,
            hf_config.vision_feature_select_strategy,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square image whose side equals the encoder's image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for the largest-feature image size."""
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width, image_height=height)

214
215
216
217
218

# Type variable for processing-info classes: the dummy-input builder and the
# multimodal processor base class in this module are generic over a concrete
# BaseLlavaProcessingInfo subclass.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Build dummy prompts and images for profiling LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one processor image token per requested image."""
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized to yield the most image features."""
        image_count = mm_counts.get("image", 0)
        best_size = self.info.get_image_size_with_most_features()
        width, height = best_size

        overrides = None
        if mm_options:
            overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


249
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints (LlavaProcessor)."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints omit patch_size from `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v); fall back to
        # the vision encoder's patch size in that case.
        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor
258
259


260
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared multimodal processor logic for LLaVA-style models.

    Each image token in the prompt is expanded into as many image tokens as
    the image contributes features.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            # Embeddings already know their length; raw images are measured.
            if isinstance(images, ImageEmbeddingItems):
                return [image_token_id] * images.get_feature_size(item_idx)

            size = images.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [image_token_id] * token_count

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]


304
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixels and embeddings are both batched along the image axis.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


316
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for HF-format Pixtral checkpoints."""

    def get_hf_processor(self, **kwargs: object):
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
319

320

321
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for HF-format Pixtral checkpoints.

    Pixtral lays image tokens out as a grid: each row of patch tokens is
    followed by a break token, and the final break is replaced by an end
    token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Crop off padding so each image's tensor depends only on that image;
        # the processor cache needs per-image outputs to be independent.
        image_sizes = outputs["image_sizes"]
        assert len(pixel_values) == len(image_sizes)
        outputs["pixel_values"] = [
            pixels[:, :height, :width]
            for pixels, (height, width) in zip(pixel_values, image_sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            size = mm_items.get_items("image", ImageProcessorItems).get_image_size(
                item_idx
            )
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # One row = ncols image tokens followed by a break token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id

            # Only the image tokens themselves receive embeddings.
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

399

400
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick Pixtral or plain LLaVA processing info from the vision config."""
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    info_cls = (
        PixtralHFProcessingInfo
        if isinstance(vision_config, PixtralVisionConfig)
        else LlavaProcessingInfo
    )
    return info_cls(ctx)


411
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Instantiate the multimodal processor matching the kind of `info`.

    Raises:
        NotImplementedError: If `info` is neither Pixtral nor LLaVA info.
    """
    # Checked in order; the first matching info class wins.
    dispatch = (
        (PixtralHFProcessingInfo, PixtralHFMultiModalProcessor),
        (LlavaProcessingInfo, LlavaMultiModalProcessor),
    )
    for info_cls, processor_cls in dispatch:
        if isinstance(info, info_cls):
            return processor_cls(info, dummy_inputs, cache=cache)  # type: ignore

    raise NotImplementedError(type(info))
432
433
434
435
436


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed to reach the deepest requested
        feature layer.

    Raises:
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple of ints.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    # Name the actual config field so users can locate the bad value.
    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
452
453
454


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
455
    """Given a signed vision feature layer, get the number of hidden layers
456
457
458
459
460
461
462
463
464
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
465
    return feature_layer_index
466
467
468
469


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel:
    """Build the vision tower that matches `hf_config.vision_config`.

    The tower is truncated to the deepest feature layer the config asks for.

    Raises:
        NotImplementedError: If the vision config type is not CLIP, SigLIP
            or Pixtral.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Ordered (config type -> tower class) table; first isinstance match wins.
    tower_table = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_table:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    raise NotImplementedError(f"Unsupported vision config: {type(vision_config)}")


509
510
511
512
513
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
):
    """LLaVA-style model: vision tower + MLP projector + language model.

    The vision tower and projector are only built when the multimodal config
    allows images; otherwise both are `None` and their weights are skipped
    on load.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if text_config.architectures is None and text_config.model_type == "mistral":
            text_config.architectures = ["MistralForCausalLM"]
        if (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        ):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        if not multimodal_config.get_limit_per_prompt("image"):
            self.vision_tower = None
            self.multi_modal_projector = None
        else:
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaImageInputs | None:
        """Pop image kwargs and wrap them in the matching input schema.

        Pixel values take precedence over image embeddings. Returns `None`
        when neither is provided.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        is_pixtral = self.config.vision_config.model_type == "pixtral"

        if pixel_values is not None:
            if is_pixtral:
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=pixel_values,
                )

            # Non-Pixtral towers expect square images of the configured size.
            side = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                resolve_bindings={"h": side, "w": side},
            )

        # Only embeddings were given at this point.
        if is_pixtral:
            raise ValueError("Pixtral-HF does not support image_embeds.")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds,
        )

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # NOTE: feature-layer selection happens inside the vision tower, so
        # only the select strategy is passed here.
        strategy = self.config.vision_feature_select_strategy
        return vision_tower(pixel_values, feature_select_strategy=strategy)

    def _process_image_pixels(
        self,
        inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        assert self.vision_tower is not None
        return self._image_pixels_to_features(
            self.vision_tower, inputs["pixel_values"]
        )

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn validated image inputs into language-model embeddings."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        features = self._process_image_pixels(image_input)

        if isinstance(features, torch.Tensor):
            return self.multi_modal_projector(features)

        # Ragged per-image features: project them in one concatenated call,
        # then split back into the original per-image lengths.
        lengths = [feature.shape[0] for feature in features]
        projected = self.multi_modal_projector(torch.cat(features))
        return torch.split(projected, lengths)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model; the example below uses LLaVA-1.5.

        `input_ids` already has room for every image embedding, so no
        sequence expansion happens in this method.

        For example, the text prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
        is tokenized to
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve KV-cache space, the input processor inserts extra image
        tokens (`32000`) before the model sees the prompt:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        Inserting 575 tokens gives, together with the original image token,
        576 (24 * 24) image tokens. That equals the number of image tokens
        the visual encoder outputs, so `positions` and `attn_metadata` stay
        consistent with `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass;
                when given, `inputs_embeds` is discarded.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Text-only mode: the vision modules were never built, so skip them.
        text_only = self.vision_tower is None and self.multi_modal_projector is None
        skip_prefixes = (
            ["vision_tower.", "multi_modal_projector."] if text_only else []
        )
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        # The LLaVA encoder emits one token per patch (no spatial merging or
        # pixel shuffle), so the count passes through unchanged.
        return num_image_tokens

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        # The MLP projector maps tokens 1:1.
        return num_vision_tokens

771

772
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis.

    Fills in `patch_size` and `vision_feature_select_strategy` when the
    caller does not provide them.
    """

    def get_hf_processor(self, **kwargs: object):
        hf_config = self.get_hf_config()
        patch_size = self.get_vision_encoder_info().get_patch_size()

        # Caller-supplied values win; these are only defaults.
        defaults = {
            "patch_size": patch_size,
            "vision_feature_select_strategy": (
                hf_config.vision_feature_select_strategy
            ),
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
784
785


786
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Multimodal processor that adds Mantis-style image wrappers.

    The standard LLaVA processing runs first. Each expanded block of image
    tokens is then wrapped as `(image N: <Image>...</Image>)`, and the
    placeholders are recomputed on the rewritten prompt.
    """

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        image_token_id = self.info.get_hf_config().image_token_index

        # The token count is assumed not to depend on image size, hence the
        # dummy (-1) dimensions.
        tokens_per_image = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base = super().apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        mm_items = self._to_mm_items(mm_data)
        item_counts = mm_items.get_all_counts()
        mm_kwargs = base["mm_kwargs"]
        mm_hashes = base["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            prefix_text = f"(image {item_idx + 1}: <Image>"  # 7 tokens
            suffix_text = "</Image>)"  # 3 tokens
            return prefix_text + "<image>" * tokens_per_image + suffix_text

        wrap_updates = self._bind_and_group_updates(
            [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id] * tokens_per_image,
                    replacement=get_replacement_mantis,
                )
            ],
            item_counts,
        )
        prompt_ids, _ = self._apply_prompt_updates(
            base["prompt_token_ids"],
            wrap_updates,
        )

        # Recompute placeholder positions on the wrapped prompt.
        original_updates = self._get_mm_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        placeholders_by_modality = self._find_mm_placeholders(
            prompt_ids, original_updates
        )
        self._validate_mm_placeholders(placeholders_by_modality, item_counts)

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders={
                modality: [item.to_range() for item in items]
                for modality, items in placeholders_by_modality.items()
            },
        )
864
865
866
867


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis: the LLaVA model registered with Mantis-specific processing."""