llava.py 32 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
7
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
                    Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.jsontree import json_map_leaves
22
from vllm.model_executor.layers.activation import get_act_fn
23
24
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
29
                                    MultiModalInputs, MultiModalKwargs)
30
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
31
                                   ImageSize, MultiModalDataItems)
32
from vllm.multimodal.processing import (BaseMultiModalProcessor,
33
                                        BaseProcessingInfo, ProcessingCache,
34
35
                                        PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder
37
from vllm.sequence import IntermediateTensors
38

39
from .clip import CLIPVisionModel
40
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
41
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
42
from .siglip import SiglipVisionModel
43
44
45
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
46
from .vision import get_vision_encoder_info
47
48


49
50
class LlavaImagePixelInputs(TypedDict):
    """Pixel inputs for LLaVA models whose vision tower is not Pixtral-HF."""

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Always a single batched tensor: `_validate_pixel_values` requires every
    image to have the shape `(3, image_size, image_size)`.
    """
58

59
60
61
62
63
64
65
66
67
68
69

class PixtralHFImagePixelInputs(TypedDict):
    """Pixel inputs used when the vision tower is Pixtral-HF."""
    # Discriminator used to tell the members of `LlavaImageInputs` apart.
    type: Literal["pixel_values_pixtral"]
    # Either one batched tensor or a list of tensors (see the note below).
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

70
71
72
73

class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings passed in place of pixel values."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


80
81
# Any of the image input forms accepted by the model; the `type` field of
# each member tells them apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
                         LlavaImageEmbeddingInputs]
82
83


84
85
class LlavaMultiModalProjector(nn.Module):
    """Two-layer projector from vision features to the text hidden size.

    The module is a `ColumnParallelLinear`, an activation and a
    `RowParallelLinear`, applied in that order.

    Args:
        vision_hidden_size: Input size of the first linear layer.
        text_hidden_size: Output size of both linear layers.
        projector_hidden_act: Activation name resolved via `get_act_fn`.
        multimodal_projector_bias: Whether both linear layers use a bias.
        quant_config: Optional quantization config passed to both layers.
        prefix: Name prefix of this module; `.linear_1` / `.linear_2` are
            appended for the two layers.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=multimodal_projector_bias,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=multimodal_projector_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply `linear_1`, the activation and `linear_2` in sequence."""
        # Both linear layers return a pair; only the first element (the
        # output tensor) is used here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


114
115
class LlavaLikeConfig(Protocol):
    """Structural type for the config attributes this file relies on."""

    # Config of the vision tower; its concrete class selects the encoder.
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # "default" or "full"; see `_select_image_features`.
    vision_feature_select_strategy: Final[str]
    # One (possibly negative) layer index, or a list of them.
    vision_feature_layer: Final[Union[int, list[int]]]
119

120

121
122
123
124
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image token."""
    # Placeholder string for one image in the prompt text.
    image_token: Final[str]


125
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by the LLaVA-style processors in this file."""

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config, requested from the context as `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the vision encoder info built from the HF config."""
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be provided by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the supported modalities: only "image", with limit `None`."""
        return {"image": None}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Return the token count left after applying `strategy`.

        Args:
            strategy: "default" (one token fewer than the encoder output)
                or "full" (the encoder output unchanged).
            encoder_num_image_tokens: Number of tokens from the encoder.

        Raises:
            NotImplementedError: If `strategy` is neither of the above.
        """
        if strategy == "default":
            return encoder_num_image_tokens - 1
        if strategy == "full":
            return encoder_num_image_tokens

        msg = f"Unexpected feature select strategy: {strategy!r}"
        raise NotImplementedError(msg)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        The encoder's token count is adjusted by the config's
        `vision_feature_select_strategy`.
        """
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` with the encoder's image size."""
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Return the token count for `get_image_size_with_most_features`."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

183
184
185
186
187
188

# Type variable for processing-info classes derived from
# `BaseLlavaProcessingInfo`; it parameterizes the generic dummy-input
# builder and processors in this file.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompt text and dummy images for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image.

        Args:
            mm_counts: Item count per modality; a missing "image" entry
                counts as zero.
        """
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images of the size that yields the most features.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Item count per modality; a missing "image" entry
                counts as zero.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }


215
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for models that use the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `LlavaProcessor`, filling in a missing `patch_size`."""
        hf_processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        if hf_processor.patch_size is None:
            patch_size = self.get_vision_encoder_info().get_patch_size()
            hf_processor.patch_size = patch_size
        return hf_processor
225
226


227
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Base processor that expands each image token to its feature count."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field config of the HF outputs; set by subclasses."""
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return one replacement rule for the image placeholder token.

        Each single image token in the prompt is replaced by as many image
        tokens as the corresponding item contributes features.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                # Precomputed embeddings carry their own feature size.
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                # Raw images: derive the count from the image dimensions.
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


271
272
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multi-modal processor for models described by `LlavaProcessingInfo`."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `image_embeds` as batched image fields.

        Neither argument is inspected by this implementation.
        """
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )


285
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for models that use the HF `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `PixtralProcessor` obtained from the context."""
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
289

290

291
292
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multi-modal processor for LLaVA models with a Pixtral-HF encoder."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalize its `pixel_values` output.

        The result always holds one entry per image: older `transformers`
        versions wrap the images in an extra list, newer ones pad them, and
        both are undone here.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            # Before/after https://github.com/huggingface/transformers/pull/35122
            if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
                images = mm_data["images"]
                assert isinstance(images, list)

                # Original output: (1, num_images, C, H, W)
                # New output: (num_images, C, H, W)
                assert (isinstance(pixel_values, list)
                        and len(pixel_values) == 1)
                assert (isinstance(pixel_values[0], list)
                        and len(pixel_values[0]) == len(images))

                processed_outputs["pixel_values"] = pixel_values[0]
            else:
                # Avoid padding since we need the output for each image to be
                # independent of other images for the cache to work correctly
                image_sizes = processed_outputs["image_sizes"]
                assert len(pixel_values) == len(image_sizes)

                processed_outputs["pixel_values"] = [
                    p[:, :h, :w]
                    for p, (h, w) in zip(pixel_values, image_sizes)
                ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `image_embeds` as batched image fields.

        Neither argument is inspected by this implementation.
        """
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return one replacement rule for the image placeholder token.

        Each image token is replaced by a grid of image tokens: every row is
        followed by the break token, and the very last token is replaced by
        the end token.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            # One row = `ncols` image tokens plus a break token; the break
            # token of the final row is swapped for the end token.
            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

384

385
386
387
388
389
390
391
392
393
394
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Build the processing info that matches the model's vision config.

    Returns a `PixtralHFProcessingInfo` when the vision config is a
    `PixtralVisionConfig`, and a `LlavaProcessingInfo` otherwise.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    uses_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if uses_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


395
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor:
    """Build the multi-modal processor that matches the type of `info`.

    Args:
        info: Processing info; its concrete class selects the processor.
        dummy_inputs: Dummy-input builder passed to the processor.
        cache: Optional processing cache passed to the processor.

    Raises:
        NotImplementedError: If `info` is neither a
            `PixtralHFProcessingInfo` nor a `LlavaProcessingInfo`.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        return PixtralHFMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    if isinstance(info, LlavaProcessingInfo):
        return LlavaMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    raise NotImplementedError(type(info))
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer count needed for the single feature layer, or for the
        deepest one when several feature layers are configured.

    Raises:
        TypeError: If `vision_feature_layer` is not an int, list or tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    # The message names the actual config attribute, `vision_feature_layer`.
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given a signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.

    Returns:
        `feature_layer_index` itself when it is non-negative; otherwise
        `num_hidden_layers + feature_layer_index + 1`, so that `-1` maps to
        `num_hidden_layers`.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index
450
451
452
453
454
455
456


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Create the vision tower that matches `hf_config.vision_config`.

    Args:
        hf_config: Model config providing the vision config and the
            vision feature layer(s).
        quant_config: Quantization config passed to the vision model.
        require_post_norm: Passed through to the vision model.
        prefix: Parameter-name prefix passed to the vision model.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


493
494
495
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model: vision tower, multi-modal projector and language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: Provides the HF config, the quantization config and
                the multi-modal config.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Return `data` unchanged after checking its per-image shape.

        Raises:
            ValueError: If `data.shape[1:]` is not
                `(3, image_size, image_size)` for the vision config's
                `image_size`.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Turn the `pixel_values` / `image_embeds` kwargs into typed inputs.

        `pixel_values` takes precedence when both are given.

        Returns:
            `None` when neither kwarg is present, otherwise the matching
            member of `LlavaImageInputs`.

        Raises:
            ValueError: If a value is neither a tensor nor a list, if the
                pixel values have an unexpected shape, or if `image_embeds`
                is given for a Pixtral vision config.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if self.config.vision_config.model_type == "pixtral":
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Select features according to `strategy`.

        "default" drops index 0 along dimension 1; "full" keeps everything.

        Raises:
            ValueError: If `strategy` is neither "default" nor "full".
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run the vision tower and apply the feature select strategy.

        The strategy from the config is applied to every tensor leaf of
        the vision tower output.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        return cast(
            Union[torch.Tensor, tuple[torch.Tensor, ...]],
            json_map_leaves(select_features, image_features),
        )

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Compute vision features for the `pixel_values` of `inputs`."""
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Return the image embeddings for `image_input`.

        Precomputed embeddings are returned as they are. Pixel inputs go
        through the vision tower and the projector; when the vision tower
        yields several tensors, they are projected in one concatenated
        batch and split back to their original sizes.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        feature_sizes = [
            image_feature.shape[0] for image_feature in image_features
        ]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        """Return the language model backbone."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return the image embeddings for the image inputs in `kwargs`.

        Returns an empty list when `kwargs` carries no image input.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge in the image embeddings, if any.

        Args:
            input_ids: Token ids to embed with the language model.
            multimodal_embeddings: Image embeddings to merge in at the
                positions of the image token. `None` or an empty
                collection leaves the text embeddings untouched.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        # `get_multimodal_embeddings` returns an empty list for text-only
        # input, so only merge when there is something to merge.
        if (multimodal_embeddings is not None
                and len(multimodal_embeddings) != 0):
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.config.image_token_index,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model.
            intermediate_tensors: Passed through to the language model; when
                given, `inputs_embeds` is discarded.
            inputs_embeds: Precomputed input embeddings. When `None` (and no
                `intermediate_tensors` are given), they are computed here
                from `input_ids` and the image inputs in `kwargs`.
            **kwargs: May carry `pixel_values` or `image_embeds`; see
                `_parse_and_validate_image_input`.

        Info:
            [LlavaImageInputs][]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load `weights` through `AutoWeightsLoader`.

        Checkpoint names are remapped with `hf_to_vllm_mapper`; the return
        value is whatever the loader reports.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
768
769


770
771
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, layered on the LLaVA processing info."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor` with Mantis defaults filled in.

        Only keys the caller did not supply are defaulted: `patch_size` comes
        from the vision encoder info, and `vision_feature_select_strategy`
        depends on the installed Transformers version.
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        uses_legacy_transformers = (Version(TRANSFORMERS_VERSION)
                                    < Version("4.48"))
        if uses_legacy_transformers:
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            default_strategy = None
        else:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            default_strategy = hf_config.vision_feature_select_strategy

        kwargs.setdefault("vision_feature_select_strategy", default_strategy)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
790
791


792
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA multimodal processor that wraps image placeholders Mantis-style."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Process the prompt, then rewrite image placeholders for Mantis.

        The parent `apply` runs first. Each run of image tokens in its
        output is then replaced by `(image N: <Image>...</Image>)`, and the
        placeholder ranges are located again in the rewritten token ids
        using the original prompt updates.

        Args:
            prompt: Prompt text or token ids, forwarded to the parent.
            mm_data: Multimodal data, forwarded to the parent and also
                parsed here to obtain per-modality item counts.
            hf_processor_mm_kwargs: Keyword arguments for the HF processor.
            return_mm_hashes: Forwarded to the parent `apply`.

        Returns:
            A `MultiModalInputs` built from the rewritten prompt and token
            ids, the parent's `mm_kwargs` and `mm_hashes`, and the
            recomputed placeholder ranges.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        tokens_per_image = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base_result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                    return_mm_hashes)

        parsed_items = self._to_mm_items(mm_data)
        item_counts = parsed_items.get_all_counts()
        mm_kwargs = base_result["mm_kwargs"]
        mm_hashes = base_result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            opening = f"(image {item_idx + 1}: <Image>"  # 7 tokens
            closing = "</Image>)"  # 3 tokens
            return opening + "<image>" * tokens_per_image + closing

        mantis_update = PromptReplacement(
            modality="image",
            target=[image_token_id] * tokens_per_image,
            replacement=get_replacement_mantis,
        )
        mantis_mm_repls = self._bind_and_group_updates([mantis_update])

        prompt_ids, prompt, _ = self._apply_prompt_updates(
            base_result["prompt_token_ids"],
            mantis_mm_repls,
            item_counts,
        )

        # Placeholders are located with the original (non-Mantis) updates.
        orig_repls = self._bind_and_group_updates(
            self._get_prompt_updates(
                parsed_items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, item_counts)

        placeholder_ranges = {
            modality: [found.to_range() for found in found_items]
            for modality, found_items in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=placeholder_ranges,
        )
868
869
870
871


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
872
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis multimodal processor.

    Defines no members of its own. It differs from
    `LlavaForConditionalGeneration` only in the processor and processing
    info registered for it.
    """