llava.py 34.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
7
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
                    TypeVar, Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.jsontree import json_map_leaves
22
from vllm.model_executor.layers.activation import get_act_fn
23
24
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
26
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
30
                                    MultiModalInputs, MultiModalKwargs)
31
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
                                   ImageSize, MultiModalDataItems)
33
from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
                                        BaseProcessingInfo, ProcessingCache,
35
                                        PromptReplacement, PromptUpdate)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
from vllm.sequence import IntermediateTensors
38
from vllm.utils import flatten_2d_lists
39

40
from .clip import CLIPVisionModel
41
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
42
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
43
from .siglip import SiglipVisionModel
44
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
45
                    maybe_prefix, merge_multimodal_embeddings)
46
47
from .vision import (get_vision_encoder_info, scatter_patch_features,
                     select_patch_features)
48
49


50
51
class LlavaImagePixelInputs(TypedDict):
    """Pixel-value image inputs for the non-Pixtral LLaVA path."""
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    All images must share the same `height` and `width`: the model
    concatenates the inputs into a single batched tensor and validates every
    item against `(3, image_size, image_size)` of the vision config.
    """
59

60
61
62
63
64
65
66
67
68
69
70
71

class PixtralHFImagePixelInputs(TypedDict):
    """Pixel-value image inputs for the Pixtral-HF vision tower."""
    type: Literal["pixel_values_pixtral"]
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
    A boolean mask indicating which image embeddings correspond
    to patch tokens.

    Shape: `(batch_size, num_images, num_embeds)`
    """

    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
    """Shape: `(batch_size, num_images)`"""

82
83
84
85

class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that bypass the vision tower."""
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


92
93
# Union of every image-input variant the model accepts; the literal `type`
# field of each member tells the variants apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
                         LlavaImageEmbeddingInputs]
94
95


96
97
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features to the text hidden size.

    The first linear layer projects `vision_hidden_size` to
    `text_hidden_size`; after the activation, the second linear layer maps
    `text_hidden_size` to `text_hidden_size`.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the projector layers.

        Args:
            vision_hidden_size: Input feature size of the first linear layer.
            text_hidden_size: Output size of both linear layers.
            projector_hidden_act: Activation name resolved via `get_act_fn`.
            multimodal_projector_bias: Whether both linear layers use a bias.
            quant_config: Quantization config forwarded to both layers.
            prefix: Name prefix used to build each layer's own prefix.
        """
        super().__init__()

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=multimodal_projector_bias,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=multimodal_projector_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply linear_1 -> activation -> linear_2 to `image_features`."""
        # Both parallel linear layers return a pair; only the first element
        # (the projected tensor) is used here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


126
127
class LlavaLikeConfig(Protocol):
    """Structural type for the config fields the LLaVA code paths read."""
    vision_config: Final[PretrainedConfig]
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    vision_feature_layer: Final[Union[int, list[int]]]
131

132

133
134
135
136
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an `image_token`."""
    image_token: Final[str]


137
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing info for LLaVA-style models.

    Subclasses supply the concrete HF processor via `get_hf_processor`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config, requested from the context as `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the vision-encoder info object derived from the HF config."""
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Declare `image` as the only modality, with no fixed limit."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum image token count per image item.

        `seq_len` and `mm_counts` are accepted but not used here.
        """
        return {"image": self.get_max_image_tokens()}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder token count for the feature select strategy.

        Args:
            strategy: `"default"` or `"full"`.
            encoder_num_image_tokens: Token count reported by the encoder.

        Returns:
            One fewer token for `"default"`, the unchanged count for `"full"`.

        Raises:
            NotImplementedError: For any other strategy.
        """
        if strategy == "default":
            return encoder_num_image_tokens - 1
        if strategy == "full":
            return encoder_num_image_tokens

        msg = f"Unexpected feature select strategy: {strategy!r}"
        raise NotImplementedError(msg)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size."""
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` using the encoder's image size."""
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Return the token count for `get_image_size_with_most_features`."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

202
203
204
205
206
207

# Type variable for processing-info classes derived from
# `BaseLlavaProcessingInfo`.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy processor inputs made of image tokens and dummy images."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Build a dummy prompt and images for `mm_counts["image"]` images.

        The prompt is the processor's image token repeated once per image and
        the images use the size from `get_image_size_with_most_features`.
        `seq_len` is accepted but not used here.
        """
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=mm_data,
        )


233
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info that uses the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `LlavaProcessor`, forwarding `kwargs` to the context."""
        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
237
238


239
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor that expands each image token per image."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field config per keyword; implemented by subclasses."""
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with one token per image feature.

        For embedding inputs the count is the item's feature size; for image
        inputs it is computed from the image's width and height.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


283
284
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for LLaVA models using `LlavaProcessingInfo`."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Mark `pixel_values` and `image_embeds` as batched image fields."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )


297
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info that uses the HF `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `PixtralProcessor`, forwarding `kwargs` to the context."""
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
301

302

303
304
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for LLaVA models with a Pixtral-HF vision tower."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalize its image outputs.

        When `pixel_values` is present, it is converted to a per-image
        sequence and `num_embeds` / `embed_is_patch` are added to the outputs.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            # Before/after https://github.com/huggingface/transformers/pull/35122
            if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
                images = mm_data["images"]
                assert isinstance(images, list)

                # Original output: (1, num_images, C, H, W)
                # New output: (num_images, C, H, W)
                assert (isinstance(pixel_values, list)
                        and len(pixel_values) == 1)
                assert (isinstance(pixel_values[0], list)
                        and len(pixel_values[0]) == len(images))

                processed_outputs["pixel_values"] = pixel_values[0]
            else:
                # Avoid padding since we need the output for each image to be
                # independent of other images for the cache to work correctly
                image_sizes = processed_outputs["image_sizes"]
                assert len(pixel_values) == len(image_sizes)

                processed_outputs["pixel_values"] = [
                    p[:, :h, :w]
                    for p, (h, w) in zip(pixel_values, image_sizes)
                ]

            hf_config = self.info.get_hf_config()
            vision_config = hf_config.vision_config
            assert isinstance(vision_config, PixtralVisionConfig)
            encoder_info = PixtralHFEncoderInfo(vision_config)

            tile_sizes = [
                encoder_info.get_patch_grid_size(
                    image_width=pixel_value.shape[-1],
                    image_height=pixel_value.shape[-2],
                ) for pixel_value in processed_outputs["pixel_values"]
            ]
            # Each grid row contributes `ncols` image tokens plus one trailing
            # break/end token, matching the replacement built in
            # `_get_prompt_updates`.
            num_embeds = torch.tensor([(ncols + 1) * nrows
                                       for ncols, nrows in tile_sizes])
            # Each image may result in masks of different sizes, so we need to
            # later use `num_embeds` to get per-image masks.
            embed_is_patch = [
                torch.tensor(([True] * ncols + [False]) * nrows)
                for ncols, nrows in tile_sizes
            ]
            processed_outputs["num_embeds"] = num_embeds
            processed_outputs["embed_is_patch"] = embed_is_patch

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Mark all image-related keywords as batched image fields."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            num_embeds=MultiModalFieldConfig.batched("image"),
            embed_is_patch=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with the image's row-wise token grid.

        Every grid row is `ncols` image tokens followed by a break token; the
        final break token is replaced by the image-end token.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            tokens[-1] = image_end_id

            return tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

421

422
423
424
425
426
427
428
429
430
431
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class that matches the vision config.

    A `PixtralVisionConfig` selects `PixtralHFProcessingInfo`; any other
    vision config selects `LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    uses_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if uses_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


432
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Build the multimodal processor that matches the type of `info`.

    Args:
        info: Processing info; its concrete class selects the processor.
        dummy_inputs: Dummy inputs builder passed to the processor.
        cache: Optional processing cache passed to the processor.
        enable_sanity_checks: Passed through to the processor.

    Raises:
        NotImplementedError: If `info` is neither a `PixtralHFProcessingInfo`
            nor a `LlavaProcessingInfo`.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        return PixtralHFMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
            enable_sanity_checks=enable_sanity_checks,
        )

    if isinstance(info, LlavaProcessingInfo):
        return LlavaMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
            enable_sanity_checks=enable_sanity_checks,
        )

    raise NotImplementedError(type(info))
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.
    
    Args:
        hf_config: Model config with vision feature layer(s).
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
479
    """Given a signed vision feature layer, get the number of hidden layers
480
481
482
483
484
485
486
487
488
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
489
    return feature_layer_index
490
491
492
493
494
495
496


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Instantiate the vision tower that matches `hf_config.vision_config`.

    Args:
        hf_config: Model config providing the vision config and feature
            layer(s).
        quant_config: Quantization config forwarded to the vision model.
        require_post_norm: Forwarded to the vision model unchanged.
        prefix: Name prefix forwarded to the vision model.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


533
534
535
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model: a vision tower, a projector and a language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Provides the HF config, quantization config and
                multimodal config.
            prefix: Name prefix for the submodules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Return the language model's sampler, or a default one if absent."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every item in `data` is `(3, image_size, image_size)`.

        Raises:
            ValueError: If the trailing dimensions of `data` differ.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Build a typed image input from the multimodal keyword arguments.

        Returns:
            `None` when neither `pixel_values` nor `image_embeds` is given;
            otherwise the matching `LlavaImageInputs` variant. `pixel_values`
            takes precedence when both are given.

        Raises:
            ValueError: If an input has an unexpected type, or if
                `image_embeds` is given for a Pixtral vision config.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if self.config.vision_config.model_type == "pixtral":
                # Both keys are required on the Pixtral path; a missing key
                # raises KeyError from `pop`.
                embed_is_patch = kwargs.pop("embed_is_patch")
                if not isinstance(embed_is_patch, (torch.Tensor, list)):
                    raise ValueError("Incorrect type of embed_is_patch. "
                                     f"Got type: {type(embed_is_patch)}")

                num_embeds = kwargs.pop("num_embeds")
                if not isinstance(num_embeds, (torch.Tensor, list)):
                    raise ValueError("Incorrect type of num_embeds. "
                                     f"Got type: {type(num_embeds)}")

                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                    embed_is_patch=embed_is_patch,
                    num_embeds=num_embeds,
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy to `image_features`.

        `"default"` drops the first position along dimension 1 and `"full"`
        keeps everything.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run the vision tower and apply the feature select strategy.

        The strategy is applied to every leaf tensor of the tower's output.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        return cast(
            Union[torch.Tensor, tuple[torch.Tensor, ...]],
            json_map_leaves(select_features, image_features),
        )

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Compute vision features from the `pixel_values` of `inputs`."""
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Turn a parsed image input into projected image embeddings.

        Precomputed embeddings are returned as-is; pixel inputs go through
        the vision tower and the multimodal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Per-image features: project them in one concatenated call, then
        # split the result back using each image's leading dimension.
        feature_sizes = [
            image_feature.shape[0] for image_feature in image_features
        ]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute image embeddings from the multimodal keyword arguments.

        Returns:
            `None` when no image input is given. The projected embeddings
            directly when `v0_path` is truthy or the input is not Pixtral
            pixel values; otherwise the flattened result of
            `scatter_patch_features` applied per item.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)

        if (kwargs.get("v0_path", False)
                or image_input["type"] != "pixel_values_pixtral"):
            # The path is used for pixtral (V0 only) and llava (V0/V1)
            return vision_embeddings

        return flatten_2d_lists(
            scatter_patch_features(*args) for args in zip(
                vision_embeddings,
                image_input["num_embeds"],
                image_input["embed_is_patch"],
            ))

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging image embeddings at image tokens.

        When `multimodal_embeddings` is `None`, the plain text embeddings
        are returned.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                select_patch_features(multimodal_embeddings),
                self.config.image_token_index,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed to the language model unchanged.
            intermediate_tensors: When given, `inputs_embeds` is discarded.
            inputs_embeds: Precomputed input embeddings; when `None` (and no
                intermediate tensors are given) they are computed here.
            **kwargs: Multimodal inputs such as `pixel_values` (the pixels in
                each input image).

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            kwargs.update({"v0_path": True})
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `weights` via `AutoWeightsLoader` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
835
836


837
838
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info that builds a ``LlavaProcessor`` with Mantis defaults."""

    def get_hf_processor(self, **kwargs: object):
        """Return a ``LlavaProcessor`` obtained from the processing context.

        Defaults are filled in for ``patch_size`` and
        ``vision_feature_select_strategy``; values explicitly supplied by the
        caller in ``kwargs`` are left untouched (``setdefault`` semantics).
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        has_token_count_bug = (Version(TRANSFORMERS_VERSION)
                               < Version("4.48"))
        if has_token_count_bug:
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            select_strategy = None
        else:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            select_strategy = hf_config.vision_feature_select_strategy

        kwargs.setdefault("vision_feature_select_strategy", select_strategy)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
857
858


859
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that rewrites image placeholders in the Mantis format."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Process ``prompt`` and wrap each image placeholder Mantis-style.

        The parent processor is run first. Every run of
        ``num_image_tokens`` image tokens in its output is then replaced by
        ``(image N: <Image>`` + image tokens + ``</Image>)``, and the
        placeholder positions are located again in the rewritten prompt.

        Args:
            prompt: Prompt text or token IDs, forwarded to the parent.
            mm_data: Multi-modal data, forwarded to the parent.
            hf_processor_mm_kwargs: Processor kwargs, forwarded to the parent.
            return_mm_hashes: Forwarded to the parent ``apply``.

        Returns:
            A ``MultiModalInputs`` holding the rewritten prompt, its token
            IDs, the parent's ``mm_kwargs`` and the recomputed placeholder
            ranges, plus the parent's ``mm_hashes`` when it returned any.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        # First pass: swap the parent's image-token runs for the Mantis text.
        mantis_mm_repls = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        prompt_ids, prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Second pass: the rewrite moved the image tokens, so look up the
        # placeholders again in the new token IDs using the regular updates.
        unbound_orig_repls = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_updates(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        outputs = MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholder_ranges,
        )

        # ``return_mm_hashes`` is forwarded to the parent, so the rebuilt
        # result must not drop the hashes it produced.
        # NOTE(review): assumes the parent stores them under "mm_hashes";
        # this is a no-op when the key is absent -- confirm against the
        # base processor's ``apply``.
        if "mm_hashes" in result:
            outputs["mm_hashes"] = result["mm_hashes"]

        return outputs
933
934
935
936


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """``LlavaForConditionalGeneration`` registered with the Mantis processor.

    The class body adds nothing of its own; it differs from its base only in
    the processor, processing info and dummy-inputs builder passed to
    ``MULTIMODAL_REGISTRY.register_processor``.
    """