# llava.py 34.8 KB
# Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
7
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
                    TypeVar, Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.jsontree import json_map_leaves
22
from vllm.model_executor.layers.activation import get_act_fn
23
24
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.quantization import QuantizationConfig
# Joe Runde's avatar
# Joe Runde committed
26
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
30
                                    MultiModalInputs, MultiModalKwargs)
31
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
                                   ImageSize, MultiModalDataItems)
33
from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
                                        BaseProcessingInfo, ProcessingCache,
35
                                        PromptReplacement, PromptUpdate)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
from vllm.sequence import IntermediateTensors
38
from vllm.utils import flatten_2d_lists
39

40
from .clip import CLIPVisionModel
41
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
42
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
43
from .siglip import SiglipVisionModel
44
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
45
                    maybe_prefix, merge_multimodal_embeddings)
46
47
from .vision import (get_vision_encoder_info, scatter_patch_features,
                     select_patch_features)
48
49


50
51
class LlavaImagePixelInputs(TypedDict):
    """Raw pixel inputs for the standard (CLIP / SigLIP) LLaVA vision tower."""

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    All images share the same `height` and `width` (the vision config's
    `image_size`), so the data is always a single batched tensor.
    Variable-size images are only handled by `PixtralHFImagePixelInputs`.
    """
59

60
61
62
63
64
65
66
67
68
69
70
71

class PixtralHFImagePixelInputs(TypedDict):
    """Pixel inputs for the Pixtral vision tower (HF checkpoint format)."""

    type: Literal["pixel_values_pixtral"]

    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Every image may have its own `height` and `width`; when they differ,
    the images arrive as a list of tensors instead of one batched tensor.
    """

    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
    Boolean mask over the image embeddings: `True` marks an embedding that
    belongs to a patch token (as opposed to a row break / end token).

    Shape: `(batch_size, num_images, num_embeds)`
    """

    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
    """Shape: `(batch_size, num_images)`"""

82
83
84
85

class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, used instead of running the vision
    tower."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    The `hidden_size` has to equal the hidden size of the language model
    backbone.
    """


92
93
# Any of the image input formats produced by
# `LlavaForConditionalGeneration._parse_and_validate_image_input`.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
94
95


96
97
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP projecting vision-tower features into the text hidden
    size: linear -> activation -> linear.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # First projection: vision hidden size -> text hidden size.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        # Second projection: text hidden size -> text hidden size.
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a pair; only the first element
        # (the layer output) is used.
        projected, _ = self.linear_1(image_features)
        projected, _ = self.linear_2(self.act(projected))
        return projected


126
127
class LlavaLikeConfig(Protocol):
    """Structural type for HF configs shaped like a LLaVA config."""

    # Config of the vision encoder.
    vision_config: Final[PretrainedConfig]
    # Token id used as the image placeholder in prompts.
    image_token_index: Final[int]
    # Feature selection strategy, e.g. "default" or "full".
    vision_feature_select_strategy: Final[str]
    # Index (or indices) of the vision encoder layer(s) features come from.
    vision_feature_layer: Final[Union[int, list[int]]]
131

132

133
134
135
136
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image token."""

    # Text placeholder standing for a single image inside a prompt.
    image_token: Final[str]


137
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by LLaVA-style models.

    Subclasses provide the HF processor; image token counts are derived
    from the vision encoder and the config's feature select strategy.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Images are reported with a limit of `None`, i.e. no limit set here.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        # The per-image maximum ignores both sequence length and counts.
        max_image_tokens = self.get_max_image_tokens()
        return {"image": max_image_tokens}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            # "default" keeps one token fewer than the encoder produces.
            return encoder_num_image_tokens - 1

        raise NotImplementedError(
            f"Unexpected feature select strategy: {strategy!r}")

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        strategy = hf_config.vision_feature_select_strategy
        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(strategy, encoder_tokens)

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square image whose side is the vision encoder's image size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=width,
            image_height=height,
        )

202
203
204
205
206
207

# Type variable for processing-info classes derived from
# `BaseLlavaProcessingInfo`.
_I = TypeVar(
    "_I",
    bound=BaseLlavaProcessingInfo,
)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompt text and images for LLaVA-style models."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        image_token = self.info.get_hf_processor().image_token
        width, height = self.info.get_image_size_with_most_features()

        # Every dummy image uses the size that yields the most features, and
        # the prompt holds exactly one image token per image.
        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=num_images)

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data={"image": dummy_images},
        )


233
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA models using `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints omit patch_size from `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v), so fall back
        # to the vision encoder's patch size.
        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor
243
244


245
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor shared by LLaVA-style models.

    Each image placeholder token in the prompt is expanded into one
    placeholder token per image feature.
    """

    # Re-declared as abstract (copied from BaseMultiModalProcessor).
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                # Precomputed embeddings report their own feature size.
                count = images.get_feature_size(item_idx)

            return [image_token_id] * count

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]


289
290
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA models."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Raw pixels and precomputed embeddings are both batched per image.
        return {
            field: MultiModalFieldConfig.batched("image")
            for field in ("pixel_values", "image_embeds")
        }


303
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in the HF format."""

    def get_hf_processor(self, **kwargs: object):
        hf_processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return hf_processor
307

308

309
310
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral checkpoints in the HF format.

    Each image is laid out as `nrows` rows of `ncols` image tokens. Every
    row ends with an image-break token, except the last row, which ends
    with an image-end token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is None:
            return processed_outputs

        # Before/after https://github.com/huggingface/transformers/pull/35122
        if Version(TRANSFORMERS_VERSION) > Version("4.48.3"):
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            processed_outputs["pixel_values"] = [
                image[:, :h, :w]
                for image, (h, w) in zip(pixel_values, image_sizes)
            ]
        else:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Older versions return (1, num_images, C, H, W); drop the
            # leading dimension to get (num_images, C, H, W).
            assert (isinstance(pixel_values, list)
                    and len(pixel_values) == 1)
            assert (isinstance(pixel_values[0], list)
                    and len(pixel_values[0]) == len(images))

            processed_outputs["pixel_values"] = pixel_values[0]

        vision_config = self.info.get_hf_config().vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        # (ncols, nrows) patch grid of every image.
        grids = [
            encoder_info.get_patch_grid_size(
                image_width=image.shape[-1],
                image_height=image.shape[-2],
            ) for image in processed_outputs["pixel_values"]
        ]

        # Every row has `ncols` patch embeddings plus one break/end
        # embedding. Mask sizes differ per image, so `num_embeds` is kept
        # to recover the per-image masks later.
        processed_outputs["num_embeds"] = torch.tensor(
            [(ncols + 1) * nrows for ncols, nrows in grids])
        processed_outputs["embed_is_patch"] = [
            torch.tensor(([True] * ncols + [False]) * nrows)
            for ncols, nrows in grids
        ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        field_names = (
            "pixel_values",
            "num_embeds",
            "embed_is_patch",
            "image_embeds",
        )
        return {
            name: MultiModalFieldConfig.batched("image")
            for name in field_names
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # Rows of image tokens closed by a break token; the very last
            # break token is swapped for the end token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id
            return tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

427

428
429
430
431
432
433
434
435
436
437
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class from the vision config type."""
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config

    info_cls = (PixtralHFProcessingInfo
                if isinstance(vision_config, PixtralVisionConfig) else
                LlavaProcessingInfo)
    return info_cls(ctx)


438
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor matching the type of `info`.

    Raises:
        NotImplementedError: If `info` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,  # type: ignore
        dummy_inputs,  # type: ignore
        cache=cache,
        enable_sanity_checks=enable_sanity_checks,
    )
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484


def _get_num_hidden_layers(hf_config: "LlavaLikeConfig") -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed to produce the deepest
        requested feature layer.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple of ints.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        # `max()` on an empty sequence gives an opaque error; be explicit.
        if not feature_layers:
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given a signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Negative indices count from the end: ``-1`` maps to
    ``num_hidden_layers`` and ``-2`` to ``num_hidden_layers - 1``.
    Non-negative indices are returned unchanged.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.

    Returns:
        The number of encoder layers to initialize.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index
496
497
498
499
500
501
502


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision tower matching ``hf_config.vision_config``.

    Only the layers up to the deepest requested feature layer are created.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    elif isinstance(vision_config, PixtralVisionConfig):
        tower_cls = PixtralHFVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return tower_cls(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


539
540
541
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style vision-language model.

    Combines a vision tower, a multimodal projector and a language model.
    Pixtral checkpoints in the HF format are also handled by this class.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Reuse the language model's sampler when it defines one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)
        actual_dims = tuple(data.shape[1:])

        if actual_dims == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        is_pixtral = self.config.vision_config.model_type == "pixtral"

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not is_pixtral:
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    pixel_values=self._validate_pixel_values(
                        flatten_bn(pixel_values, concat=True)),
                )

            embed_is_patch = kwargs.pop("embed_is_patch")
            if not isinstance(embed_is_patch, (torch.Tensor, list)):
                raise ValueError("Incorrect type of embed_is_patch. "
                                 f"Got type: {type(embed_is_patch)}")

            num_embeds = kwargs.pop("num_embeds")
            if not isinstance(num_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of num_embeds. "
                                 f"Got type: {type(num_embeds)}")

            return PixtralHFImagePixelInputs(
                type="pixel_values_pixtral",
                pixel_values=flatten_bn(pixel_values),
                embed_is_patch=embed_is_patch,
                num_embeds=num_embeds,
            )

        # Only precomputed image embeddings were supplied.
        if not isinstance(image_embeds, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image embeddings. "
                             f"Got type: {type(image_embeds)}")

        if is_pixtral:
            raise ValueError("Pixtral-HF does not support image_embeds.")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=flatten_bn(image_embeds, concat=True),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            # Drop index 0 along the token dimension.
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        def select_leaf(leaf: torch.Tensor) -> torch.Tensor:
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        selected = json_map_leaves(select_leaf, image_features)
        return cast(Union[torch.Tensor, tuple[torch.Tensor, ...]], selected)

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None

        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["pixel_values"])

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # Precomputed embeddings are returned as-is.
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # A tuple of per-image features: project them in one call, then
        # split the result back into per-image chunks.
        feature_sizes = [features.shape[0] for features in image_features]
        projected = self.multi_modal_projector(torch.cat(image_features))
        return torch.split(projected, feature_sizes)

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)

        # The path is used for pixtral (V0 only) and llava (V0/V1)
        return_directly = (kwargs.get("v0_path", False)
                           or image_input["type"] != "pixel_values_pixtral")
        if return_directly:
            return vision_embeddings

        # Pixtral on V1: pass each image's features through
        # `scatter_patch_features` with its `num_embeds`/`embed_is_patch`.
        per_image_args = zip(
            vision_embeddings,
            image_input["num_embeds"],
            image_input["embed_is_patch"],
        )
        return flatten_2d_lists(
            scatter_patch_features(*args) for args in per_image_args)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            select_patch_features(multimodal_embeddings),
            self.config.image_token_index,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        The `input_ids` already contain placeholder tokens at the positions
        where the image embeddings will be inserted.

        For example, the text prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
        tokenizes to:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, the input processor expands the
        image token (`32000`) before the prompt reaches the model:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        575 extra tokens are inserted so that, together with the original
        image token, there are 576 (24 * 24) image tokens: the number of
        image embeddings the visual encoder produces for the language model.

        This keeps `positions` and `attn_metadata` consistent with
        `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            kwargs["v0_path"] = True
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.language_model.compute_logits(hidden_states,
                                                    sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        output = self.language_model.sample(logits, sampling_metadata)
        return output

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        return AutoWeightsLoader(self).load_weights(weights)
841
842


843
844
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints, which use ``LlavaProcessor``."""

    def get_hf_processor(self, **kwargs: object):
        """Build a ``LlavaProcessor`` configured for Mantis.

        Two keyword defaults are filled in. Values supplied by the caller
        always win, because only ``setdefault`` is used:

        * ``patch_size``: the vision encoder's patch size.
        * ``vision_feature_select_strategy``: depends on the installed
          transformers version (see the comments in the branches below).
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        if Version(TRANSFORMERS_VERSION) >= Version("4.48"):
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            default_strategy = hf_config.vision_feature_select_strategy
        else:
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            default_strategy = None
        kwargs.setdefault("vision_feature_select_strategy", default_strategy)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
863
864


865
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Multimodal processor for Mantis checkpoints.

    Runs the regular LLaVA processing pipeline first. It then wraps each
    expanded image placeholder in the ``(image N: <Image>...</Image>)``
    markup, re-implementing Mantis' ``MLlavaProcessor``. Finally it
    recomputes the placeholder ranges for the rewritten prompt.
    """

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Process ``prompt`` and ``mm_data`` into Mantis-formatted inputs.

        Args:
            prompt: Text prompt or already-tokenized prompt ids.
            mm_data: Multimodal data. Only the ``"image"`` modality gets
                the Mantis markup.
            hf_processor_mm_kwargs: Extra kwargs for the HF processor.
            return_mm_hashes: Forwarded to the base implementation. Any
                hashes it produces are carried through to the result.

        Returns:
            ``MultiModalInputs`` whose prompt contains the Mantis image
            markup and whose placeholder ranges point at the image tokens
            inside that markup.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        # Keep whatever hashes the base implementation computed. Previously
        # they were dropped when the result was rebuilt below, so
        # ``return_mm_hashes=True`` had no visible effect.
        mm_hashes = result.get("mm_hashes")

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        # Replace each run of already-expanded image tokens with the Mantis
        # markup that surrounds the same run of image tokens.
        mantis_mm_repls = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        prompt_ids, mantis_prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Locate the image tokens inside the rewritten prompt using the
        # original (non-Mantis) prompt updates.
        unbound_orig_repls = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_updates(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=mantis_prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
939
940
941
942


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis multimodal processor.

    The model itself is inherited unchanged from
    ``LlavaForConditionalGeneration``. Only the processor, processing info
    and dummy-input builder registered via the decorator differ.
    """