"vscode:/vscode.git/clone" did not exist on "fba89069302e9b4d0457bc8eeddeeec76f27f0b1"
llava.py 31.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
7
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
                    Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.jsontree import json_map_leaves
22
from vllm.model_executor.layers.activation import get_act_fn
23
24
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
29
                                    MultiModalInputs, MultiModalKwargs)
30
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
31
                                   ImageSize, MultiModalDataItems)
32
from vllm.multimodal.processing import (BaseMultiModalProcessor,
33
                                        BaseProcessingInfo, ProcessingCache,
34
35
                                        PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder
37
from vllm.sequence import IntermediateTensors
38

39
from .clip import CLIPVisionModel
40
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
41
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
42
from .siglip import SiglipVisionModel
43
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
44
                    maybe_prefix, merge_multimodal_embeddings)
45
from .vision import get_vision_encoder_info
46
47


48
49
class LlavaImagePixelInputs(TypedDict):
    """Pixel-value image inputs for the standard (non-Pixtral) LLaVA path."""

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Unlike the Pixtral variant, every image here has the same `height` and
    `width` (the vision encoder's configured `image_size`), so the data is
    always a single batched tensor rather than a list.
    """
57

58
59
60
61
62
63
64
65
66
67
68

class PixtralHFImagePixelInputs(TypedDict):
    """Pixel-value image inputs for the HF-format Pixtral vision encoder."""

    # Tag distinguishing this variant from the other `LlavaImageInputs` members.
    type: Literal["pixel_values_pixtral"]
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Images may differ in `height` or `width` across the batch; when they do,
    the per-image tensors are passed as a list instead of one batched tensor.
    """

69
70
71
72

class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, used as-is instead of running the
    vision tower and projector."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    The `hidden_size` dimension has to equal the hidden size of the language
    model backbone.
    """


79
80
# Every image-input variant the LLaVA family accepts; the "type" key of each
# member tells them apart at runtime.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
81
82


83
84
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP projecting vision features into the text hidden space.

    Computes ``linear_2(act(linear_1(x)))``, where ``linear_1`` is
    column-parallel and ``linear_2`` is row-parallel.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Registration order (linear_1, act, linear_2) is kept as-is so the
        # module/parameter ordering matches the checkpoint layout.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # Each parallel linear layer returns a 2-tuple; only the first element
        # (the layer output) is used here.
        projected, _ = self.linear_1(image_features)
        projected, _ = self.linear_2(self.act(projected))
        return projected


113
114
class LlavaLikeConfig(Protocol):
    """Structural type for the HF configs consumed by the LLaVA helpers.

    Any config exposing these read-only attributes is accepted.
    """

    # Config of the vision encoder (e.g. CLIP, SigLIP or Pixtral).
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # "default" drops the first image feature, "full" keeps all features.
    vision_feature_select_strategy: Final[str]
    # Encoder layer(s) to take features from; negative values count from the
    # end of the encoder.
    vision_feature_layer: Final[Union[int, list[int]]]
118

119

120
121
122
123
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors used by the LLaVA helpers."""

    # Placeholder string marking where an image goes in the prompt text.
    image_token: Final[str]


124
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing metadata shared by LLaVA-style models.

    The only abstract method declared here is `get_hf_processor`. Image token
    counts are derived from the HF config and the vision encoder info.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No fixed limit on the number of images.
        return {"image": None}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Convert the encoder's token count into the number of tokens kept
        under `strategy`.

        Raises:
            NotImplementedError: For any strategy other than "default" or
                "full".
        """
        if strategy not in ("default", "full"):
            msg = f"Unexpected feature select strategy: {strategy!r}"
            raise NotImplementedError(msg)

        # "default" discards the first encoder token; "full" keeps them all.
        if strategy == "default":
            return encoder_num_image_tokens - 1
        return encoder_num_image_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        strategy = hf_config.vision_feature_select_strategy
        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(strategy, encoder_tokens)

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square image whose side is the encoder's reported image size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width,
                                         image_height=height)

182
183
184
185
186
187

# Generic parameter for builders/processors, restricted to subclasses of
# `BaseLlavaProcessingInfo`.
_I = TypeVar(
    "_I",
    bound=BaseLlavaProcessingInfo,
)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompt text and images for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        # One copy of the processor's image placeholder per dummy image.
        return self.info.get_hf_processor().image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        # Every dummy image uses the size that yields the most features.
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=num_images)
        return {"image": dummy_images}


214
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for plain LLaVA checkpoints using `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints omit `patch_size` from `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v); fall back to the
        # vision encoder's patch size in that case.
        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor
224
225


226
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor shared by LLaVA-style models.

    Each image placeholder token in the prompt is replaced by a run of
    placeholder tokens, one per image feature.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            # Raw images derive their token count from their size;
            # precomputed embeddings report their feature size directly.
            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                num_image_tokens = images.get_feature_size(item_idx)

            return [image_token_id] * num_image_tokens

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]


270
271
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for plain LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both fields are batched along the image modality.
        return {
            field_name: MultiModalFieldConfig.batched("image")
            for field_name in ("pixel_values", "image_embeds")
        }


284
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for HF-format Pixtral checkpoints."""

    def get_hf_processor(self, **kwargs: object):
        # Pixtral checkpoints ship a `PixtralProcessor` rather than a
        # `LlavaProcessor`.
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
288

289

290
291
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for HF-format Pixtral checkpoints.

    Each image becomes a grid of image tokens. A break token follows each
    row, and the final break is replaced by an end-of-image token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Before/after https://github.com/huggingface/transformers/pull/35122
        legacy_layout = Version(TRANSFORMERS_VERSION) <= Version("4.48.3")
        if legacy_layout:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert (isinstance(pixel_values, list)
                    and len(pixel_values) == 1)
            per_image = pixel_values[0]
            assert isinstance(per_image, list) and len(per_image) == len(images)

            outputs["pixel_values"] = per_image
        else:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            outputs["pixel_values"] = [
                image[:, :height, :width]
                for image, (height, width) in zip(pixel_values, image_sizes)
            ]

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Each field is batched along the image modality.
        return {
            field_name: MultiModalFieldConfig.batched("image")
            for field_name in ("pixel_values", "image_embeds")
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            size = mm_items.get_items(
                "image", ImageProcessorItems).get_image_size(item_idx)
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # Each row is `ncols` image tokens plus a break token; the last
            # break is swapped for the end-of-image token.
            row = [image_token_id] * ncols + [image_break_id]
            grid = row * nrows
            grid[-1] = image_end_id

            # Select the `image_token_id` positions within the replacement
            # (presumably the ones that receive image embeddings).
            return PromptUpdateDetails.select_token_id(grid, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

383

384
385
386
387
388
389
390
391
392
393
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class from the vision encoder config.

    HF-format Pixtral checkpoints get `PixtralHFProcessingInfo`. Any other
    config is treated as plain LLaVA.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config

    if not isinstance(vision_config, PixtralVisionConfig):
        return LlavaProcessingInfo(ctx)
    return PixtralHFProcessingInfo(ctx)


394
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor that matches `info`.

    Raises:
        NotImplementedError: If `info` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    if not isinstance(info, (PixtralHFProcessingInfo, LlavaProcessingInfo)):
        raise NotImplementedError(type(info))

    if isinstance(info, PixtralHFProcessingInfo):
        return PixtralHFMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    # Any LlavaProcessingInfo (including subclasses) uses the LLaVA processor.
    return LlavaMultiModalProcessor(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
    )
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed to reach the deepest requested
        feature layer.

    Raises:
        ValueError: If `vision_feature_layer` is an empty list/tuple.
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple of ints.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # Without this guard max() fails with an opaque message.
            raise ValueError(
                "vision_feature_layer must not be an empty list of layers")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given a signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Negative indices count from the end: ``-1`` maps to ``num_hidden_layers``
    and ``-(num_hidden_layers + 1)`` maps to ``0``.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.

    Returns:
        The number of hidden layers to initialize, within
        ``[0, num_hidden_layers]``.

    Raises:
        ValueError: If `feature_layer_index` is outside
            ``[-(num_hidden_layers + 1), num_hidden_layers]``.
    """
    if feature_layer_index < 0:
        layer_index = num_hidden_layers + feature_layer_index + 1
    else:
        layer_index = feature_layer_index

    # Without this check an out-of-range index silently yields a negative or
    # too-large layer count for the vision tower override.
    if not 0 <= layer_index <= num_hidden_layers:
        raise ValueError(
            f"Vision feature layer {feature_layer_index} is out of range for "
            f"a visual encoder with {num_hidden_layers} hidden layers")
    return layer_index
449
450
451
452
453
454
455


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision encoder for a LLaVA-style model.

    The encoder class is chosen from the type of ``hf_config.vision_config``.
    The encoder is only built up to the deepest feature layer requested by
    the config.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral config.
    """
    vision_config = hf_config.vision_config

    # Stop building layers past the deepest feature layer actually needed.
    depth = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=depth,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    if isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=depth,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    if isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=depth,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")


492
493
494
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style model: vision tower, multimodal projector, language model.

    Also serves HF-format Pixtral checkpoints, whose vision config is a
    `PixtralVisionConfig`.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        text_config = config.text_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Ensure each image in `data` is `(3, image_size, image_size)`."""
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        # Raw pixels take precedence when both inputs are present.
        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            is_pixtral = self.config.vision_config.model_type == "pixtral"
            if is_pixtral:
                # Flattened without concatenation: Pixtral images may differ
                # in size.
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                )

            batched_pixels = flatten_bn(pixel_values, concat=True)
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(batched_pixels),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            # Drop the first feature position of every image.
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        # Apply the feature-select strategy to every tensor leaf.
        selected = json_map_leaves(
            lambda leaf: self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            ),
            image_features,
        )
        return cast(Union[torch.Tensor, tuple[torch.Tensor, ...]], selected)

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None
        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["pixel_values"])

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # Precomputed embeddings skip both the vision tower and projector.
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Per-image features: project them in one concatenated batch, then
        # split the result back into per-image chunks.
        chunk_sizes = [features.shape[0] for features in image_features]
        projected = self.multi_modal_projector(torch.cat(image_features))
        return torch.split(projected, chunk_sizes)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is not None:
            return self._process_image_input(image_input)
        return None

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        text_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return text_embeds

        # Merge the image embeddings in at the positions of the image
        # placeholder token.
        return merge_multimodal_embeddings(
            input_ids,
            text_embeds,
            multimodal_embeddings,
            self.config.image_token_index,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        The key point is that `input_ids` already reserves positions for the
        image embeddings that will be inserted.

        For example, take the text prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        The tokenizer produces:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        The model needs space in the KV cache for every image embedding. The
        input processor therefore expands the image token (`32000`) into a
        run of placeholders before the ids reach the model:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        575 extra tokens are added. Together with the original image token
        this gives 576 (24 * 24) image tokens, matching the number of image
        tokens the visual encoder outputs and the language model consumes.

        As a result, `positions` and `attn_metadata` line up with
        `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Positions matching `input_ids`.
            intermediate_tensors: Intermediate hidden states (presumably from
                a previous pipeline-parallel stage). When given, embeddings
                are not used.
            inputs_embeds: Optional precomputed input embeddings.
            pixel_values: The pixels in each input image (passed via
                `kwargs`).

        Info:
            [LlavaImageInputs][]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        return self.language_model.model(input_ids,
                                          positions,
                                          intermediate_tensors,
                                          inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        return AutoWeightsLoader(self).load_weights(weights)
758
759


760
761
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints, which reuse `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        if Version(TRANSFORMERS_VERSION) >= Version("4.48"):
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            kwargs.setdefault(
                "vision_feature_select_strategy",
                hf_config.vision_feature_select_strategy,
            )
        else:
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            kwargs.setdefault("vision_feature_select_strategy", None)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
780
781


782
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that rewrites image placeholders into Mantis format.

    After standard LLaVA processing, each run of image tokens is replaced by
    ``(image N: <Image>`` + image placeholders + ``</Image>)``.
    """

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        tokens_per_image = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        llava_result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                     return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = llava_result["mm_kwargs"]
        mm_hashes = llava_result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            opening = f"(image {item_idx+1}: <Image>"  # 7 tokens
            closing = "</Image>)"  # 3 tokens
            return opening + "<image>" * tokens_per_image + closing

        # Step 1: rewrite each run of image tokens into the Mantis format.
        mantis_updates = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * tokens_per_image,
                replacement=get_replacement_mantis,
            )
        ])
        new_prompt_ids, new_prompt, _ = self._apply_prompt_updates(
            llava_result["prompt_token_ids"],
            mantis_updates,
            mm_item_counts,
        )

        # Step 2: find the image placeholders inside the rewritten prompt
        # using the regular LLaVA prompt updates.
        llava_updates = self._bind_and_group_updates(
            self._get_prompt_updates(
                mm_items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))
        mm_placeholders = self._find_mm_placeholders(
            llava_updates,
            new_prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        return MultiModalInputs(
            type="multimodal",
            prompt=new_prompt,
            prompt_token_ids=new_prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders={
                modality: [item.to_range() for item in placeholders]
                for modality, placeholders in mm_placeholders.items()
            },
        )
858
859
860
861


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
862
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis multimodal processor.

    The model itself is inherited unchanged from
    `LlavaForConditionalGeneration`. Only the registered processor and
    processing info differ.
    """