llava.py 32.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
7
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
                    Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.jsontree import json_map_leaves
22
from vllm.model_executor.layers.activation import get_act_fn
23
24
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
29
                                    MultiModalInputs, MultiModalKwargs)
30
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
31
                                   ImageSize, MultiModalDataItems)
32
from vllm.multimodal.processing import (BaseMultiModalProcessor,
33
                                        BaseProcessingInfo, ProcessingCache,
34
35
                                        PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder
37
from vllm.sequence import IntermediateTensors
38

39
from .clip import CLIPVisionModel
40
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
41
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
42
from .siglip import SiglipVisionModel
43
44
45
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
46
from .vision import get_vision_encoder_info
47
48


49
50
class LlavaImagePixelInputs(TypedDict):
    """Pixel-value image input used for the non-Pixtral vision towers."""
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Unlike the Pixtral-HF variant, this is always a single batched tensor,
    so every image is expected to share the same `height` and `width`.
    """
58

59
60
61
62
63
64
65
66
67
68
69

class PixtralHFImagePixelInputs(TypedDict):
    """Pixel-value image input used when the vision tower is Pixtral-HF."""
    type: Literal["pixel_values_pixtral"]
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

70
71
72
73

class LlavaImageEmbeddingInputs(TypedDict):
    """Pre-computed image embeddings supplied instead of pixel values."""
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


80
81
# Tagged union of every image input form this model accepts; the `type` key
# of each member tells them apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
                         LlavaImageEmbeddingInputs]
82
83


84
85
class LlavaMultiModalProjector(nn.Module):
    """Two-layer projector from vision features to the text hidden size.

    The projection is `linear_1 -> activation -> linear_2`, where `linear_1`
    maps `vision_hidden_size` to `text_hidden_size` and `linear_2` keeps the
    `text_hidden_size` width.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the projector layers.

        Args:
            vision_hidden_size: Input feature width of `linear_1`.
            text_hidden_size: Output width of both linear layers.
            projector_hidden_act: Activation name resolved by `get_act_fn`.
            multimodal_projector_bias: Whether both linear layers use a bias.
            quant_config: Optional quantization config passed to both layers.
            prefix: Parameter-name prefix for the two linear layers.
        """
        super().__init__()

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=multimodal_projector_bias,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=multimodal_projector_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` into the text embedding space."""
        # The parallel linear layers return a tuple; only the first element
        # (the projected activations) is used here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


114
115
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes this module reads."""
    # Config of the vision tower.
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # Strategy name; this module handles "default" and "full".
    vision_feature_select_strategy: Final[str]
    # One signed layer index, or several of them.
    vision_feature_layer: Final[Union[int, list[int]]]
119

120

121
122
123
124
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image token string."""
    image_token: Final[str]


125
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by the LLaVA-style processors in this module.

    Subclasses provide the concrete HF processor via `get_hf_processor`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config, requested from the context as `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the vision encoder info derived from the HF config."""
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits; only "image" is listed."""
        return {"image": None}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder token count for the feature select strategy.

        Args:
            strategy: Either "default" or "full".
            encoder_num_image_tokens: Token count reported by the encoder.

        Returns:
            The count minus one for "default", unchanged for "full".

        Raises:
            NotImplementedError: If `strategy` is any other value.
        """
        if strategy == "default":
            return encoder_num_image_tokens - 1
        if strategy == "full":
            return encoder_num_image_tokens

        msg = f"Unexpected feature select strategy: {strategy!r}"
        raise NotImplementedError(msg)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size."""
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` using the encoder's image size."""
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Return the token count for `get_image_size_with_most_features`."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

183
184
185
186
187
188

# Type variable for the processing-info class a builder/processor is bound to.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy text and image inputs for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized by `get_image_size_with_most_features`.

        Args:
            seq_len: Unused by this implementation.
            mm_counts: Per-modality item counts; only "image" is read.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }


215
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info backed by the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `LlavaProcessor`, filling in a missing `patch_size`."""
        hf_processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        if hf_processor.patch_size is None:
            patch_size = self.get_vision_encoder_info().get_patch_size()
            hf_processor.patch_size = patch_size
        return hf_processor
225
226


227
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Base processor that expands each image token to its feature size."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the prompt replacement for the image placeholder token.

        Each occurrence of the image token id is replaced by that id repeated
        once per image feature of the corresponding item.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                # Embeddings already carry their own feature size.
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


271
272
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multi-modal processor for LLaVA models using `LlavaProcessingInfo`."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `image_embeds` as batched image fields."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )


285
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info backed by the HF `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `PixtralProcessor` obtained from the context."""
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
289

290

291
292
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multi-modal processor for LLaVA models with a Pixtral-HF vision tower."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalize its `pixel_values` output.

        The layout of `pixel_values` depends on the installed transformers
        version, so it is reshaped here into one entry per image.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            # Before/after https://github.com/huggingface/transformers/pull/35122
            if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
                images = mm_data["images"]
                assert isinstance(images, list)

                # Original output: (1, num_images, C, H, W)
                # New output: (num_images, C, H, W)
                assert (isinstance(pixel_values, list)
                        and len(pixel_values) == 1)
                assert (isinstance(pixel_values[0], list)
                        and len(pixel_values[0]) == len(images))

                processed_outputs["pixel_values"] = pixel_values[0]
            else:
                # Avoid padding since we need the output for each image to be
                # independent of other images for the cache to work correctly
                image_sizes = processed_outputs["image_sizes"]
                assert len(pixel_values) == len(image_sizes)

                # Crop each image back to its own (h, w).
                processed_outputs["pixel_values"] = [
                    p[:, :h, :w]
                    for p, (h, w) in zip(pixel_values, image_sizes)
                ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `image_embeds` as batched image fields."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the prompt replacement for the image placeholder token.

        Each image token is replaced by `nrows` rows of `ncols` image tokens,
        each row followed by a break token, with the final token swapped for
        the image-end token.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            # The last row ends with the image-end token instead of a break.
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

386

387
388
389
390
391
392
393
394
395
396
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class that matches the vision config.

    Returns a `PixtralHFProcessingInfo` when the vision config is a
    `PixtralVisionConfig`, and a `LlavaProcessingInfo` otherwise.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    uses_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if uses_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


397
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor:
    """Build the multi-modal processor matching the type of `info`.

    Args:
        info: Processing info; its concrete class selects the processor.
        dummy_inputs: Dummy inputs builder forwarded to the processor.
        cache: Optional processing cache forwarded to the processor.

    Raises:
        NotImplementedError: If `info` is neither a `PixtralHFProcessingInfo`
            nor a `LlavaProcessingInfo`.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        return PixtralHFMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    if isinstance(info, LlavaProcessingInfo):
        return LlavaMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    raise NotImplementedError(type(info))
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer count for the single feature layer, or the largest count
        among several feature layers.

    Raises:
        TypeError: If `vision_feature_layer` is not an int, list or tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
441
    """Given a signed vision feature layer, get the number of hidden layers
442
443
444
445
446
447
448
449
450
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
451
    return feature_layer_index
452
453
454
455
456
457
458


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision tower that matches `hf_config.vision_config`.

    Args:
        hf_config: Model config providing the vision config and the vision
            feature layer(s).
        quant_config: Quantization config forwarded to the vision model.
        require_post_norm: Forwarded to the vision model unchanged.
        prefix: Parameter-name prefix forwarded to the vision model.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


495
496
497
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model: a vision tower, a projector and a language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Source of the HF config, quantization config and
                multimodal config.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that `data` has shape `(N, 3, image_size, image_size)`.

        Returns:
            `data` unchanged when the trailing dimensions match.

        Raises:
            ValueError: If the trailing dimensions differ from the expected
                ones.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Build the typed image input from `pixel_values` / `image_embeds`.

        `pixel_values` takes precedence when both are given.

        Returns:
            `None` when neither key is present, otherwise the matching
            `LlavaImageInputs` member.

        Raises:
            ValueError: If a value is neither a tensor nor a list, or if
                `image_embeds` is supplied for a Pixtral vision config.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if self.config.vision_config.model_type == "pixtral":
                # Pixtral inputs skip the fixed-shape validation below.
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy to `image_features`.

        "default" drops the first entry along dimension 1; "full" returns
        the features unchanged.

        Raises:
            ValueError: If `strategy` is any other value.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run the vision tower and apply the feature select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        # The tower output may be one tensor or a nested structure of them,
        # so the selection is mapped over every leaf.
        return cast(
            Union[torch.Tensor, tuple[torch.Tensor, ...]],
            json_map_leaves(select_features, image_features),
        )

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Convert the `pixel_values` of `inputs` into image features."""
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Turn an image input into embeddings in the text hidden space.

        Pre-computed embeddings are returned as-is. Pixel inputs go through
        the vision tower and then the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Per-image features: project them in one concatenated batch, then
        # split the result back into one tensor per image.
        feature_sizes = [
            image_feature.shape[0] for image_feature in image_features
        ]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        """Return the underlying language model module."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return the image embeddings for `kwargs`, or `[]` without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge in any multimodal embeddings.

        The merge targets the positions holding `config.image_token_index`
        and is skipped when `multimodal_embeddings` is `None` or empty.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None \
            and len(multimodal_embeddings) != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.config.image_token_index,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model.
            intermediate_tensors: When given, `inputs_embeds` is ignored and
                these are passed to the language model instead.
            inputs_embeds: Pre-computed input embeddings; computed here from
                `input_ids` and `kwargs` when omitted.
            **kwargs: Multimodal inputs, e.g. `pixel_values` (the pixels in
                each input image) or `image_embeds`.

        Info:
            [LlavaImageInputs][]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load `weights`, renaming checkpoint keys via `hf_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
771
772


773
774
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis that customizes the HF processor kwargs."""

    def get_hf_processor(self, **kwargs: object):
        """Build a `LlavaProcessor` with Mantis-specific default kwargs.

        Caller-supplied kwargs always win: every value set here goes through
        `setdefault`, so it only applies when the key is absent.

        Args:
            **kwargs: Keyword arguments forwarded to
                `self.ctx.get_hf_processor`.

        Returns:
            The processor returned by
            `self.ctx.get_hf_processor(LlavaProcessor, **kwargs)`.
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            kwargs.setdefault("vision_feature_select_strategy", None)
        else:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            kwargs.setdefault(
                "vision_feature_select_strategy",
                hf_config.vision_feature_select_strategy,
            )

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
793
794


795
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA multimodal processor that emits Mantis-style image placeholders."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Process a prompt, then rewrite its image placeholders for Mantis.

        The parent `apply` runs first. Each run of `num_image_tokens` image
        tokens in its output is then replaced by the Mantis form
        `(image N: <Image>` + image tokens + `</Image>)`, and the placeholder
        positions are located again in the rewritten token sequence.

        Args:
            prompt: Prompt text or token IDs, forwarded to the parent `apply`.
            mm_data: Multimodal data, forwarded to the parent `apply` and
                parsed again here to count the items per modality.
            hf_processor_mm_kwargs: Forwarded to the parent `apply` and to
                `_get_prompt_updates`.
            tokenization_kwargs: Forwarded to the parent `apply`.
            return_mm_hashes: Forwarded to the parent `apply`.

        Returns:
            A `MultiModalInputs` holding the rewritten prompt and token IDs,
            the parent's `mm_kwargs` and `mm_hashes`, and the placeholder
            ranges found in the rewritten token IDs.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               tokenization_kwargs, return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            # `item_idx` is zero-based; the label shown in the prompt is
            # one-based.
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        # Target the image-token run already present in the parent's output
        # and wrap it in the Mantis markers.
        mantis_mm_repls = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        prompt_ids, prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # The placeholder positions changed after the rewrite, so find them
        # again using the original (non-Mantis) prompt updates.
        unbound_orig_repls = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_updates(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
872
873
874
875


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
876
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis model that reuses `LlavaForConditionalGeneration` unchanged.

    The class body adds nothing; it exists so that the Mantis processor and
    processing info can be registered in place of the LLaVA ones, while the
    dummy-input builder stays `LlavaDummyInputsBuilder`.
    """