llava.py 29.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
7
                    Union)
8
9

import torch
10
import torch.nn as nn
11
12
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
13
                          SiglipVisionConfig)
14
from transformers.models.llava import LlavaProcessor
15
from transformers.models.pixtral import PixtralProcessor
16

17
from vllm.config import VllmConfig
18
from vllm.config.multimodal import BaseDummyOptions
19
from vllm.model_executor.layers.activation import get_act_fn
20
21
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
22
from vllm.model_executor.layers.quantization import QuantizationConfig
23
from vllm.multimodal import MULTIMODAL_REGISTRY
24
from vllm.multimodal.cache import BaseMultiModalProcessorCache
25
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
26
27
                                    MultiModalInputs, MultiModalKwargsItems,
                                    MultiModalUUIDDict)
28
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
29
                                   ImageSize, MultiModalDataItems)
30
from vllm.multimodal.processing import (BaseMultiModalProcessor,
31
32
33
34
                                        BaseProcessingInfo,
                                        InputProcessingContext,
                                        PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
35
from vllm.multimodal.profiling import BaseDummyInputsBuilder
36
from vllm.sequence import IntermediateTensors
37
from vllm.utils.tensor_schema import TensorSchema, TensorShape
38

39
from .clip import CLIPVisionModel
40
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
41
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
42
from .siglip import SiglipVisionModel
43
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
44
                    init_vllm_registered_model, maybe_prefix)
45
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
46
47


48
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel values for the non-Pixtral (CLIP / SigLIP) LLaVA vision towers.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Every image shares the same `h` and `w`, so the data is always a single
    batched tensor. `LlavaForConditionalGeneration` concatenates the inputs
    with `flatten_bn(..., concat=True)` and binds both `h` and `w` to
    `vision_config.image_size`. Images whose size differs per item (passed
    as a list) are only supported through `PixtralHFImagePixelInputs`.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel values for the Pixtral (HF-format) vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    `h` and `w` are listed in `dynamic_dims`, presumably so the schema
    accepts list entries of differing size (TODO confirm against
    `TensorShape`).
    """
    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"})]


class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, fed to the language model as-is.

    These skip the vision tower and the projector entirely. They are
    rejected for Pixtral-HF vision towers.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """
    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
                         LlavaImageEmbeddingInputs]
93
94


95
96
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features to the text hidden size.

    `linear_1` (column-parallel) goes from the vision hidden size to the
    text hidden size. The configured activation follows. `linear_2`
    (row-parallel) keeps the text hidden size.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` into the language model's space."""
        # The parallel linear layers return a pair; only the first element
        # (the layer output) is used.
        projected = self.linear_1(image_features)[0]
        activated = self.act(projected)
        return self.linear_2(activated)[0]


class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes read by this module."""
    # Config of the vision encoder; CLIP, SigLIP and Pixtral configs are
    # handled by `init_vision_tower_for_llava`.
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder that prompt updates target.
    image_token_index: Final[int]
    # Feature selection strategy forwarded to the vision tower and to
    # `get_num_selected_vision_tokens`.
    vision_feature_select_strategy: Final[str]
    # Vision encoder layer(s) to take features from; negative values count
    # from the end (see `_get_layer_index`).
    vision_feature_layer: Final[Union[int, list[int]]]


class LlavaLikeProcessor(Protocol):
    image_token: Final[str]


136
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing metadata shared by LLaVA-style models.

    Subclasses supply `get_hf_processor`. Image sizes and token counts are
    derived from the vision encoder info of the HF config.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the model's HF config, loaded as a `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the encoder info object for the configured vision tower."""
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Only images are supported; their limit is reported as `None`."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens for one image of this size.

        The vision encoder's token count is adjusted by the configured
        `vision_feature_select_strategy`.
        """
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        num_encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = hf_config.vision_feature_select_strategy
        return get_num_selected_vision_tokens(num_encoder_tokens, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the square image size the vision encoder works at."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count of the largest supported image."""
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width,
                                         image_height=height)


# Type variable for processing-info classes. It lets the dummy-inputs
# builder and the processors below be parameterized by a concrete subclass
# of `BaseLlavaProcessingInfo`.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompts and images for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder string per requested image."""
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized to produce the most image features.

        The number of images comes from `mm_counts["image"]` (default 0).
        Any `"image"` entry in `mm_options` is passed through as overrides.
        """
        num_images = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        overrides = None
        if mm_options:
            overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=num_images,
                                              overrides=overrides)
        return {"image": dummy_images}


class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints (`LlavaProcessor`)."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `LlavaProcessor`, filling in a missing `patch_size`.

        Some checkpoints omit `patch_size` from `processor_config.json`
        (e.g. E5-V: https://huggingface.co/royokong/e5-v). In that case the
        value is taken from the vision encoder info.
        """
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
        if processor.patch_size is not None:
            return processor

        processor.patch_size = self.get_vision_encoder_info().get_patch_size()
        return processor


class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor shared by LLaVA-style models.

    Each image placeholder token in the prompt is expanded to as many
    placeholder tokens as that image contributes features.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace every image token with one token per image feature."""
        image_token_id = self.info.get_hf_config().image_token_index

        def _num_tokens_for(item_idx: int) -> int:
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))
            if isinstance(items, ImageEmbeddingItems):
                # Precomputed embeddings carry their own feature length.
                return items.get_feature_size(item_idx)

            size = items.get_image_size(item_idx)
            return self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )

        def get_replacement(item_idx: int):
            return [image_token_id] * _num_tokens_for(item_idx)

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]


class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Processor for standard LLaVA models (pixel values or embeddings)."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both `pixel_values` and `image_embeds` are batched per image."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in the HF LLaVA format."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `PixtralProcessor` built with `kwargs`."""
        processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return processor


class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral in the HF LLaVA format.

    Each image expands into `nrows` rows of `ncols` image tokens. Every row
    is followed by a break token, and the final break is replaced by an
    end token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then crop each image to its own size."""
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Avoid padding since we need the output for each image to be
        # independent of other images for the cache to work correctly
        image_sizes = outputs["image_sizes"]
        assert len(pixel_values) == len(image_sizes)

        outputs["pixel_values"] = [
            image[:, :height, :width]
            for image, (height, width) in zip(pixel_values, image_sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both `pixel_values` and `image_embeds` are batched per image."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with the image's patch-grid tokens."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            items = mm_items.get_items("image", ImageProcessorItems)
            size = items.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # One row of patches plus its trailing break token; the last
            # break of the whole grid becomes the end token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]


def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Return the processing info matching the checkpoint's vision tower.

    A `PixtralVisionConfig` selects `PixtralHFProcessingInfo`; every other
    vision config selects `LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor that matches the type of `info`.

    Raises:
        NotImplementedError: If `info` is neither a `PixtralHFProcessingInfo`
            nor a `LlavaProcessingInfo`.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
    )


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.
    
    Args:
        hf_config: Model config with vision feature layer(s).
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
428
    """Given a signed vision feature layer, get the number of hidden layers
429
430
431
432
433
434
435
436
437
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
438
    return feature_layer_index
439
440
441
442
443
444
445


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision tower that matches `hf_config.vision_config`.

    The tower is truncated to the deepest layer listed in
    `hf_config.vision_feature_layer`.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    tower_for_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_for_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style vision-language model.

    A vision tower (CLIP, SigLIP or Pixtral) encodes images, and
    `LlavaMultiModalProjector` maps the features to the text hidden size.
    The language model registered for `config.text_config` consumes the
    result.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # Checkpoints saved after transformers v4.52 use new parameter names.
    # These prefixes map them onto the attribute layout of this class.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder text for `modality`.

        Raises:
            ValueError: If `modality` is not an image modality.
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.config = hf_config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = hf_config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (hf_config.projector_hidden_act is None
                and hf_config.vision_config.hidden_act == "gelu"):
            hf_config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        if not multimodal_config.get_limit_per_prompt("image"):
            # Images are disabled: build neither the vision tower nor the
            # projector (`load_weights` then skips their weights).
            self.vision_tower = None
            self.multi_modal_projector = None
        else:
            self.vision_tower = init_vision_tower_for_llava(
                hf_config,
                quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"))
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=hf_config.vision_config.hidden_size,
                text_hidden_size=text_config.hidden_size,
                projector_hidden_act=hf_config.projector_hidden_act,
                multimodal_projector_bias=hf_config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Pop image inputs from `kwargs` and wrap them in a schema object.

        Returns `None` when neither `pixel_values` nor `image_embeds` is
        given. `pixel_values` wins when both are present.

        Raises:
            ValueError: If an input has the wrong type, or if image
                embeddings are given for a Pixtral vision tower.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if self.config.vision_config.model_type == "pixtral":
                # Pixtral images keep their own sizes, so they are not
                # concatenated into a single tensor.
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                )

            image_size = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values, concat=True),
                resolve_bindings={
                    "h": image_size,
                    "w": image_size
                },
            )

        # Only `image_embeds` remains at this point.
        if not isinstance(image_embeds, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image embeddings. "
                             f"Got type: {type(image_embeds)}")

        if self.config.vision_config.model_type == "pixtral":
            raise ValueError("Pixtral-HF does not support image_embeds.")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=flatten_bn(image_embeds, concat=True),
        )

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run `vision_tower` on `pixel_values` with the configured strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        strategy = self.config.vision_feature_select_strategy
        return vision_tower(pixel_values, feature_select_strategy=strategy)

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encode the pixel values in `inputs` with the vision tower."""
        assert self.vision_tower is not None
        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["pixel_values"])

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Turn parsed image inputs into language-model embeddings.

        Embedding inputs are returned unchanged. Pixel inputs are encoded
        and projected. A sequence of per-image features is projected in one
        concatenated call and split back into per-image tensors.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        features = self._process_image_pixels(image_input)

        if isinstance(features, torch.Tensor):
            return self.multi_modal_projector(features)

        split_sizes = [feature.shape[0] for feature in features]
        projected = self.multi_modal_projector(torch.cat(features))
        return torch.split(projected, split_sizes)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for `kwargs`, or `[]` if no images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        return ([] if image_input is None else
                self._process_image_input(image_input))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        `input_ids` already contain the expanded image placeholder tokens,
        so `positions` line up with the final sequence.

        For example, the prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
        tokenizes to
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, the input processor expands the
        single image token (`32000`) before the model runs:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        With 575 extra tokens there are 576 (24 * 24) image tokens in
        total. That matches the number of features the visual encoder
        outputs. `inputs_embeds` is ignored whenever `intermediate_tensors`
        is provided.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping vision weights if disabled."""
        vision_disabled = (self.vision_tower is None
                           and self.multi_modal_projector is None)
        skip_prefixes = (["vision_tower.", "multi_modal_projector."]
                         if vision_disabled else [])

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, which reuses `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor` with Mantis-specific defaults.

        `patch_size` defaults to the vision encoder's patch size.
        `vision_feature_select_strategy` defaults to the HF config value.
        Values already present in `kwargs` take precedence.
        """
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        defaults = (
            ("patch_size", encoder_info.get_patch_size()),
            ("vision_feature_select_strategy",
             hf_config.vision_feature_select_strategy),
        )
        for key, value in defaults:
            kwargs.setdefault(key, value)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)


class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that rewrites image prompts into the Mantis format."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Process inputs like LLaVA, then wrap each image in Mantis markers.

        After the regular LLaVA expansion, each run of image tokens becomes
        `(image N: <Image>...</Image>)`. The placeholder ranges are then
        recomputed on the rewritten prompt.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt,
                               mm_data,
                               hf_processor_mm_kwargs,
                               tokenization_kwargs,
                               mm_uuids=mm_uuids)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        image_tokens_text = "<image>" * num_image_tokens

        def get_replacement_mantis(item_idx: int):
            # "(image N: <Image>" is 7 tokens, "</Image>)" is 3 tokens.
            return (f"(image {item_idx + 1}: <Image>"
                    f"{image_tokens_text}"
                    "</Image>)")

        mantis_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id] * num_image_tokens,
            replacement=get_replacement_mantis,
        )
        mantis_mm_repls = self._bind_and_group_updates([mantis_replacement],
                                                       mm_item_counts)

        prompt_ids, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
        )

        orig_repls = self._get_mm_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        placeholder_ranges = {
            modality: [placeholder.to_range() for placeholder in found]
            for modality, found in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=placeholder_ranges,
        )


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis model: the LLaVA architecture with Mantis prompt processing.

    Registered with `MantisMultiModalProcessor` and `MantisProcessingInfo`.
    The model class itself adds nothing to `LlavaForConditionalGeneration`.
    """