llava.py 28.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar
7
8

import torch
9
import torch.nn as nn
10
11
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
18
from transformers.models.llava import LlavaProcessor
19
from transformers.models.pixtral import PixtralProcessor
20

21
from vllm.config import VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.cache import BaseMultiModalProcessorCache
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
49
from vllm.multimodal.profiling import BaseDummyInputsBuilder
50
from vllm.sequence import IntermediateTensors
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
from .clip import CLIPVisionModel
54
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
55
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
56
from .siglip import SiglipVisionModel
57
58
59
60
61
62
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
63
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
64
65


66
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel inputs for LLaVA vision encoders that take a fixed image size.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Unlike the Pixtral variant, this schema only accepts a batched tensor;
    the model binds `h` and `w` to `vision_config.image_size` when it
    validates the input.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
80

81

82
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the HF-format Pixtral vision encoder.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        # h and w vary per image, so they are not checked for consistency.
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
99

100

101
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that bypass the vision tower and projector.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
111
112


113
114
115
# Every image-input schema the LLaVA model accepts. Members are told apart at
# runtime through their `type` field (e.g. "image_embeds").
LlavaImageInputs: TypeAlias = (
    LlavaImagePixelInputs
    | PixtralHFImagePixelInputs
    | LlavaImageEmbeddingInputs
)
116
117


118
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text hidden space.

    Computes ``linear_2(act(linear_1(x)))``. ``linear_1`` is column-parallel
    and ``linear_2`` is row-parallel, so the intermediate activation stays
    sharded between the two layers.

    Args:
        vision_hidden_size: Feature size produced by the vision tower.
        text_hidden_size: Hidden size of the language model.
        projector_hidden_act: Name of the activation passed to `get_act_fn`.
        multimodal_projector_bias: Whether both linear layers use a bias.
        quant_config: Optional quantization config for the linear layers.
        prefix: Weight-name prefix for this module.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project vision features to the language model's hidden size."""
        # The parallel linear layers return (output, bias); only the output
        # is needed here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


153
154
class LlavaLikeConfig(Protocol):
    """Structural type for HF configs that expose the LLaVA config fields."""

    vision_config: Final[PretrainedConfig]
    # Token id that marks where image features go in the prompt.
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    # One layer index or several; negative values count from the last layer.
    vision_feature_layer: Final[int | list[int]]
158

159

160
161
162
163
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image token string."""

    image_token: Final[str]


164
165
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing info for LLaVA-like models.

    Subclasses provide `get_hf_processor`. Image token counts come from the
    vision encoder info combined with the config's
    `vision_feature_select_strategy`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens for one image.

        Takes the vision encoder's raw token count for this size and applies
        the configured feature selection strategy.
        """
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        return get_num_selected_vision_tokens(
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
            hf_config.vision_feature_select_strategy,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square image at the vision encoder's reported size."""
        vision_encoder_info = self.get_vision_encoder_info()

        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Return the token count for the largest-feature image size."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

208
209
210
211
212

# Generic parameter for builders/processors that work with any
# BaseLlavaProcessingInfo subclass.
_I = TypeVar(
    "_I",
    bound=BaseLlavaProcessingInfo,
)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompts and images for LLaVA-like models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one processor image token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized to produce the most image features.

        Any per-modality options in `mm_options["image"]` are forwarded as
        overrides to `_get_dummy_images`.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


243
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints using `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        hf_processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        if hf_processor.patch_size is None:
            patch_size = self.get_vision_encoder_info().get_patch_size()
            hf_processor.patch_size = patch_size
        return hf_processor
252
253


254
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Base processor that expands each image token into its feature tokens."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace every `image_token_index` with one token per image feature.

        For precomputed embeddings the count is the embedding's feature size.
        For raw images it comes from `get_num_image_tokens` for that size.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


298
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both pixel inputs and precomputed embeddings are batched per image.
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )


310
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for HF-format Pixtral checkpoints."""

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
313

314

315
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for HF-format Pixtral checkpoints.

    Each image expands to `nrows` rows of `ncols` image tokens. Every row
    ends with an image-break token, except the last, which ends with the
    image-end token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            # Crop each padded image back to its own (height, width).
            processed_outputs["pixel_values"] = [
                p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with the Pixtral row/break token layout."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            # `ncols` image tokens and one break token per row; the final
            # break token is swapped for the end-of-image token.
            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            tokens[-1] = image_end_id

            # Only the image tokens (not break/end) receive image embeddings.
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

393

394
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class based on the vision config type.

    A `PixtralVisionConfig` selects `PixtralHFProcessingInfo`. Any other
    vision config falls back to `LlavaProcessingInfo`.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)

    if isinstance(hf_config.vision_config, PixtralVisionConfig):
        return PixtralHFProcessingInfo(ctx)

    return LlavaProcessingInfo(ctx)


405
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Build the multimodal processor that matches the processing-info type.

    Raises:
        NotImplementedError: If `info` is neither Pixtral-HF nor LLaVA info.
    """
    # Check the Pixtral subclass first so it is not handled as plain LLaVA.
    if isinstance(info, PixtralHFProcessingInfo):
        return PixtralHFMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    if isinstance(info, LlavaProcessingInfo):
        return LlavaMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    raise NotImplementedError(type(info))
426
427
428
429
430


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed to reach the deepest requested
        feature layer.

    Raises:
        ValueError: If `vision_feature_layer` is an empty list/tuple.
        TypeError: If `vision_feature_layer` is not an int, list or tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            raise ValueError("vision_feature_layer must not be empty")
        return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
446
447
448


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given a signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Negative indices count from the end, so -1 means all
    `num_hidden_layers` layers are needed. Non-negative indices are
    returned unchanged.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index
460
461
462
463


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel:
    """Build the vision encoder that matches `hf_config.vision_config`.

    The encoder is built only up to the deepest layer listed in
    `vision_feature_layer`.

    Raises:
        NotImplementedError: If the vision config type is not CLIP, SigLIP
            or Pixtral.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


503
504
505
506
507
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style model: vision tower + projector + language model.

    The vision tower and projector are built only when the per-prompt image
    limit is non-zero. Otherwise both are `None` and their checkpoint
    weights are skipped when loading.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for `modality` (images only)."""
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (
            config.text_config.architectures is None
            and config.text_config.model_type == "mistral"
        ):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        ):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )
        else:
            # Text-only mode: no vision weights are created.
            self.vision_tower = None
            self.multi_modal_projector = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaImageInputs | None:
        """Pop image kwargs and wrap them in the matching input schema.

        Returns `None` when neither `pixel_values` nor `image_embeds` is set.
        Pixel values take precedence over embeddings.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if self.config.vision_config.model_type == "pixtral":
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=pixel_values,
                )

            # Non-Pixtral encoders expect square inputs of a fixed size.
            expected_h = expected_w = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                resolve_bindings={"h": expected_h, "w": expected_w},
            )

        if image_embeds is not None:
            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        return vision_tower(
            pixel_values,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(
        self,
        inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn validated image input into language-model embeddings.

        Precomputed embeddings are returned as-is. Pixel inputs go through
        the vision tower and the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        assert self.multi_modal_projector is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Variable-sized features (e.g. Pixtral): project them in one
        # concatenated call, then split back into per-image chunks.
        feature_sizes = [image_feature.shape[0] for image_feature in image_features]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or an empty list if there are no images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        # When intermediate tensors arrive from an earlier stage, they take
        # the place of input embeddings.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping post-v4.52 HF names.

        Vision tower and projector weights are skipped when those modules
        were not built.
        """
        skip_prefixes = []
        if self.vision_tower is None and self.multi_modal_projector is None:
            skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
736
737


738
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints, which reuse `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        # Default these processor arguments from the vision encoder and the
        # model config. `setdefault` keeps any value the caller passed.
        kwargs.setdefault("patch_size", vision_info.get_patch_size())
        kwargs.setdefault(
            "vision_feature_select_strategy",
            hf_config.vision_feature_select_strategy,
        )

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
750
751


752
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that re-wraps image tokens in Mantis' prompt template."""

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Run standard LLaVA processing, then apply the Mantis image template.

        Each expanded run of image tokens is rewritten as
        `(image N: <Image><image>...</Image>)`. Placeholder ranges are then
        recomputed against the rewritten prompt.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join(
                [
                    f"(image {item_idx + 1}: <Image>",  # 7 tokens
                    "<image>" * num_image_tokens,
                    "</Image>)",  # 3 tokens
                ]
            )

        mantis_mm_repls = self._bind_and_group_updates(
            [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id] * num_image_tokens,
                    replacement=get_replacement_mantis,
                )
            ],
            mm_item_counts,
        )

        prompt_ids, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
        )

        # The template shifts token positions, so find the placeholders
        # again in the rewritten prompt.
        orig_repls = self._get_mm_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
830
831
832
833


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis multimodal processor."""