llava.py 37.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
7
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
                    TypeVar, Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.jsontree import JSONTree, json_map_leaves
22
from vllm.model_executor.layers.activation import get_act_fn
23
24
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
26
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
30
                                    MultiModalInputs, MultiModalKwargs,
31
                                    NestedTensors)
32
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
33
                                   ImageSize, MultiModalDataItems)
34
from vllm.multimodal.processing import (BaseMultiModalProcessor,
35
                                        BaseProcessingInfo, ProcessingCache,
36
                                        PromptReplacement, PromptUpdate)
37
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
38
from vllm.sequence import IntermediateTensors
39
from vllm.utils import flatten_2d_lists
40

41
from .clip import CLIPVisionModel
42
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
43
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
44
from .siglip import SiglipVisionModel
45
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
46
                    maybe_prefix, merge_multimodal_embeddings)
47
from .vision import get_vision_encoder_info
48
49


50
51
class LlavaImagePixelInputs(TypedDict):
    """Raw pixel inputs for the non-Pixtral (CLIP / SigLIP) vision towers.

    Unlike the Pixtral-HF variant, images here are not passed as a list of
    differently-sized tensors: the model concatenates them into one tensor
    (``flatten_bn(..., concat=True)``) and validates every image against
    ``(3, image_size, image_size)`` from the vision config.
    """

    type: Literal["pixel_values"]

    # Shape: `(batch_size * num_images, num_channels, height, width)`, where
    # `height == width == vision_config.image_size`.
    pixel_values: torch.Tensor
59

60
61
62
63
64
65
66
67
68
69
70
71

class PixtralHFImagePixelInputs(TypedDict):
    """Pixel inputs for the HF-format Pixtral vision tower.

    Pixtral images may differ in size, so every tensor-valued field may be
    either one batched tensor or a list of per-item tensors.
    """

    type: Literal["pixel_values_pixtral"]

    # Shape: `(batch_size * num_images, num_channels, height, width)`.
    # If `height` / `width` differ between images, a list of per-image
    # tensors is passed instead of a single batched tensor.
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]

    # Boolean mask marking which image *features* are patch tokens (the
    # remaining slots line up with row-break / end tokens).
    # Shape: `(batch_size, num_crops, num_patch)`.
    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]

    # Boolean mask marking which image *embeddings* are patch tokens.
    # Shape: `(batch_size, num_embeds)`.
    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]

    # Shape: `(batch_size, num_images)`.
    num_crops: Union[torch.Tensor, list[torch.Tensor]]

90
91
92
93

class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, used as-is without the vision tower."""

    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
    # `hidden_size` has to equal the hidden size of the language model
    # backbone, because no projection is applied to these embeddings.
    data: torch.Tensor


100
101
# Every image-input format the model accepts; the variants are told apart at
# runtime by their "type" key.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
102
103


104
105
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text hidden space.

    Computes ``linear_2(act(linear_1(x)))`` using tensor-parallel linear
    layers; the parallel layers' second return value is discarded.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Both layers share the same bias / quantization settings.
        shared_kwargs = dict(bias=multimodal_projector_bias,
                             quant_config=quant_config)

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             prefix=f"{prefix}.linear_1",
                                             **shared_kwargs)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          prefix=f"{prefix}.linear_2",
                                          **shared_kwargs)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` to the text hidden size."""
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


134
135
class LlavaLikeConfig(Protocol):
    """Structural type for HF configs that follow the LLaVA layout.

    Any config exposing these read-only attributes can drive the helpers in
    this module, whether or not it subclasses ``LlavaConfig``.
    """

    vision_config: Final[PretrainedConfig]
    image_token_index: Final[int]
    # Either "default" or "full" in the code paths of this module.
    vision_feature_select_strategy: Final[str]
    # A single (possibly negative) layer index or a list of them.
    vision_feature_layer: Final[Union[int, list[int]]]
139

140

141
142
143
144
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image placeholder."""

    # Text placeholder inserted once per image in the prompt.
    image_token: Final[str]


145
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing metadata for LLaVA-style models.

    Subclasses only need to provide the concrete HF processor; token counts
    are derived from the vision encoder and the feature-select strategy.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No upper bound on the number of images per prompt.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_max_image_tokens()}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder token count for the feature-select strategy."""
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            # The "default" strategy drops the first encoder token
            # (presumably the CLS token).
            return encoder_num_image_tokens - 1

        raise NotImplementedError(
            f"Unexpected feature select strategy: {strategy!r}")

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of placeholder tokens produced for an image of this size."""
        strategy_config = self.get_hf_config()
        encoder = self.get_vision_encoder_info()

        encoder_tokens = encoder.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(
            strategy_config.vision_feature_select_strategy,
            encoder_tokens,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        # The encoder works on square images of a fixed side length.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        largest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=largest.width,
            image_height=largest.height,
        )

210
211
212
213
214
215

# Generic parameter for builders/processors that work with any
# LLaVA-style processing-info class.
_I = TypeVar(
    "_I",
    bound=BaseLlavaProcessingInfo,
)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds profiling inputs made of max-size dummy images."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return a prompt with one image placeholder per dummy image."""
        num_images = mm_counts.get("image", 0)

        image_token = self.info.get_hf_processor().image_token
        largest = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=largest.width,
                                              height=largest.height,
                                              num_images=num_images)

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data={"image": dummy_images},
        )


241
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for classic LLaVA checkpoints (``LlavaProcessor``)."""

    def get_hf_processor(self, **kwargs: object):
        processor_cls = LlavaProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
245
246


247
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor that expands each image token into N copies."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace every image token with one token per image feature."""
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            # Precomputed embeddings already know their feature count.
            if isinstance(images, ImageEmbeddingItems):
                return [image_token_id] * images.get_feature_size(item_idx)

            size = images.get_image_size(item_idx)
            count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [image_token_id] * count

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]


291
292
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Processor for classic LLaVA: one pixel tensor or embedding per image."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        per_image = MultiModalFieldConfig.batched("image")
        return {
            "pixel_values": per_image,
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


305
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for HF-format Pixtral checkpoints."""

    def get_hf_processor(self, **kwargs: object):
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
309

310

311
312
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for HF-format Pixtral checkpoints.

    Besides the per-image pixel values, it records how many placeholder
    slots each image occupies (``num_crops``) and boolean masks telling
    which of those slots hold patch embeddings rather than row-break / end
    tokens.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        raw_pixels = outputs.get("pixel_values")
        if raw_pixels is None:
            # No images in this request: nothing else to compute.
            return outputs

        # Before/after https://github.com/huggingface/transformers/pull/35122
        if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert isinstance(raw_pixels, list) and len(raw_pixels) == 1
            assert (isinstance(raw_pixels[0], list)
                    and len(raw_pixels[0]) == len(images))

            outputs["pixel_values"] = raw_pixels[0]
        else:
            # Crop each image back to its own size so its output does not
            # depend on the other images (needed for the processing cache).
            image_sizes = outputs["image_sizes"]
            assert len(raw_pixels) == len(image_sizes)

            outputs["pixel_values"] = [
                pixels[:, :height, :width]
                for pixels, (height, width) in zip(raw_pixels, image_sizes)
            ]

        vision_config = self.info.get_hf_config().vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        # (ncols, nrows) of the patch grid, one entry per image.
        grids = [
            encoder_info.get_patch_grid_size(
                image_width=pixels.shape[-1],
                image_height=pixels.shape[-2],
            ) for pixels in outputs["pixel_values"]
        ]

        # Every row of patches is followed by one extra non-patch slot.
        num_crops = torch.tensor(
            [nrows * (ncols + 1) for ncols, nrows in grids])

        # Masks of different images may differ in length, so they are
        # flattened here and split per image later using `num_crops`.
        patch_mask = torch.tensor(
            flatten_2d_lists([([True] * ncols + [False]) * nrows
                              for ncols, nrows in grids]))

        outputs["num_crops"] = num_crops
        outputs["embed_is_patch"] = patch_mask
        outputs["feat_is_patch"] = patch_mask

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        crops_per_image = hf_inputs.get("num_crops", torch.empty(0)).view(-1)

        return {
            "feat_is_patch":
            MultiModalFieldConfig.flat_from_sizes("image", crops_per_image),
            "embed_is_patch":
            MultiModalFieldConfig.flat_from_sizes("image", crops_per_image),
            "num_crops": MultiModalFieldConfig.batched("image"),
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Expand each image token into a grid of image/break/end tokens."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        def get_replacement(item_idx: int):
            size = mm_items.get_items(
                "image", ImageProcessorItems).get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # Each row ends with a break token; the very last one is turned
            # into the end-of-image token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id
            return tokens

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]

433

434
435
436
437
438
439
440
441
442
443
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class based on the vision config type."""
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)

    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


444
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Instantiate the multimodal processor that matches ``info``.

    Raises:
        NotImplementedError: If ``info`` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
        enable_sanity_checks=enable_sanity_checks,
    )
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Only the layers up to the deepest one referenced by
    ``hf_config.vision_feature_layer`` are needed.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of visual-encoder layers to build.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple of ints.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # max() would otherwise fail with an opaque message.
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
491
    """Given a signed vision feature layer, get the number of hidden layers
492
493
494
495
496
497
498
499
500
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
501
    return feature_layer_index
502
503
504
505
506
507
508


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision tower that matches ``hf_config.vision_config``.

    The tower is truncated to the deepest layer needed for feature
    extraction.

    Raises:
        NotImplementedError: For vision configs other than CLIP, SigLIP or
            Pixtral.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Checked in order; the first matching config type wins.
    tower_by_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_by_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")


545
546
547
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
548
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
549
550
551
552

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
553
    }
554

555
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model."""
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Pipeline parallelism reuses the language model's factory.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """The language model's own sampler if it has one, else the default."""
        if not hasattr(self.language_model, "sampler"):
            return get_sampler()
        return self.language_model.sampler
603

604
605
606
607
608
609
610
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every image is ``(3, image_size, image_size)``.

        Raises:
            ValueError: If the per-image dimensions do not match.
        """
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *(str(dim) for dim in expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Turn raw multimodal kwargs into one of the typed image inputs.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is present. Pixel values take precedence over embeddings.

        Raises:
            ValueError: On unexpected input types, or when embeddings are
                supplied for a Pixtral-HF vision tower.
            KeyError: If a Pixtral-HF request lacks ``feat_is_patch``,
                ``embed_is_patch`` or ``num_crops``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if self.config.vision_config.model_type != "pixtral":
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    pixel_values=self._validate_pixel_values(
                        flatten_bn(pixel_values, concat=True)),
                )

            # Pixtral-HF additionally needs the masks / crop counts attached
            # by its multimodal processor; pop and check them in order.
            extras: dict[str, object] = {}
            for key in ("feat_is_patch", "embed_is_patch", "num_crops"):
                value = kwargs.pop(key)
                if not isinstance(value, (torch.Tensor, list)):
                    raise ValueError(f"Incorrect type of {key}. "
                                     f"Got type: {type(value)}")
                extras[key] = value

            return PixtralHFImagePixelInputs(
                type="pixel_values_pixtral",
                pixel_values=flatten_bn(pixel_values),
                feat_is_patch=extras["feat_is_patch"],
                embed_is_patch=extras["embed_is_patch"],
                num_crops=extras["num_crops"],
            )

        # Only `image_embeds` is present from here on.
        if not isinstance(image_embeds, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image embeddings. "
                             f"Got type: {type(image_embeds)}")

        if self.config.vision_config.model_type == "pixtral":
            raise ValueError("Pixtral-HF does not support image_embeds.")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=flatten_bn(image_embeds, concat=True),
        )
674
675
676
677
678
679
680
681
682
683
684

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the config's feature-select strategy along dimension 1.

        Raises:
            ValueError: For strategies other than "default" and "full".
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            # Drop the first token of every image.
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

685
686
    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> torch.Tensor:
        """Run the vision tower and apply the feature-select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        features = vision_tower(pixel_values)

        select_strategy = self.config.vision_feature_select_strategy
        return self._select_image_features(features, strategy=select_strategy)

701
702
703
704
    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> torch.Tensor:
        """Encode the pixel values of ``inputs`` with the vision tower."""
        tower = self.vision_tower
        assert tower is not None

        return self._image_pixels_to_features(tower, inputs["pixel_values"])

711
712
713
714
    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Map image inputs to language-model embeddings.

        Precomputed embeddings are returned unchanged. Pixel inputs are
        encoded and projected. When the tower yields a list of per-image
        features, they are projected together and split back per image.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        features = self._process_image_pixels(image_input)

        if not isinstance(features, torch.Tensor):
            # Variable-sized images: project in one batch, then split back.
            sizes = [per_image.shape[0] for per_image in features]
            projected = self.multi_modal_projector(torch.cat(features))
            return torch.split(projected, sizes)

        return self.multi_modal_projector(features)

    def _get_mm_embeds(
            self,
            features: torch.Tensor,
            feat_is_patch: torch.Tensor,
            num_crops: torch.Tensor,
            embed_is_patch: torch.Tensor,
    ) -> list[torch.Tensor]:
        """Scatter the patch features into a contiguous tensor that corresponds
        to the embedding tokens defined by the multimodal processor.

        Slots that are not patch tokens are filled with NaN. The NaN rows
        are stripped out again in `get_input_embeddings`.

        Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.

        Args:
            features: Projected features, one row per `True` entry of
                `feat_is_patch` (previously annotated as
                `(num_crop, num_patch, d)`; NOTE(review): confirm layout).
            feat_is_patch: Boolean mask over feature slots; flattened here.
            num_crops: Slot counts per image; their sum is the number of
                feature slots.
            embed_is_patch: Boolean mask over embedding positions; flattened.

        Returns:
            One `(len(embed_is_patch), d)` tensor per chunk of `num_crops`.
        """

        # Insert columns of nan values according to `feat_is_patch`. This work
        # ideally should be done in `_process_image_input`, but
        # `_process_image_input` is used in both V0 and V1 path. It's safer to
        # put the logic here.
        # FIXME: Move this logic to `_process_image_input` when v0 is
        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
        feat_is_patch = feat_is_patch.view(-1)
        embed_is_patch = embed_is_patch.view(-1)

        # Allocate straight on the features' device instead of building the
        # buffer on the CPU and copying it over afterwards.
        expanded_embedding = torch.full(
            (int(num_crops.sum()), *features.shape[1:]),
            torch.nan,
            dtype=features.dtype,
            device=features.device,
        )
        expanded_embedding[feat_is_patch] = features

        num_crops_per_image = num_crops.tolist()
        feats_per_image = expanded_embedding.split(num_crops_per_image)
        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)

        embed_dim = expanded_embedding.shape[-1]
        num_embeds = embed_is_patch.shape[0]

        embeds_in_batch: list[torch.Tensor] = []
        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
            embeds[embed_is_patch] = feats[f_is_patch]
            embeds_in_batch.append(embeds)

        return embeds_in_batch

774
775
    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Build image embeddings from the multimodal keyword arguments.

        Returns ``None`` when no image input was parsed from ``kwargs``.
        For llava inputs (V0/V1) and for pixtral on the V0 path
        (``v0_path`` truthy), the processed vision embeddings are returned
        unchanged. Otherwise (pixtral on V1), each image's embeddings are
        scattered into processor-aligned slots via ``_get_mm_embeds`` and the
        resulting per-image lists are flattened into one list.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)

        # The path is used for pixtral (V0 only) and llava (V0/V1)
        passthrough = (kwargs.get("v0_path", False)
                       or image_input["type"] != "pixel_values_pixtral")
        if passthrough:
            return vision_embeddings

        per_image = zip(
            vision_embeddings,
            image_input["feat_is_patch"],
            image_input["num_crops"],
            image_input["embed_is_patch"],
        )
        nested_emb = [
            self._get_mm_embeds(feats, feat_mask, crops, embed_mask)
            for feats, feat_mask, crops, embed_mask in per_image
        ]
        return flatten_2d_lists(nested_emb)
793
794
795
796

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in multimodal embeddings, if given.

        NaN entries in each multimodal embedding tensor (the padding written
        by ``_get_mm_embeds``) are removed first. The remaining values are then
        merged at the positions of ``self.config.image_token_index`` tokens.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        def _keep_patch_tokens(tensor: torch.Tensor) -> torch.Tensor:
            # Drop NaN padding, then restore the trailing dimensions.
            kept = tensor[~tensor.isnan()]
            return kept.view(-1, *tensor.shape[1:])

        patch_embeddings = json_map_leaves(
            _keep_patch_tokens,
            cast(JSONTree[torch.Tensor], multimodal_embeddings),
        )
        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            cast(NestedTensors, patch_embeddings),
            self.config.image_token_index,
        )

813
814
815
816
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        ``input_ids`` already reserves the positions of the image embeddings
        that will be inserted. For example, the text prompt
        ``"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"``
        is tokenized as
        ``[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]``.

        To reserve space in the KV cache, the input processor expands the
        single image token (``32000``) into a run of image tokens before the
        ids reach the model:
        ``[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]``.

        575 tokens are inserted so that, together with the original image
        token, there are 576 (24 * 24) image tokens -- the number of image
        embeddings produced by the vision encoder. This keeps ``positions``
        consistent with ``input_ids``.

        Args:
            input_ids: Flattened (concatenated) input ids for the batch.
            positions: Positions corresponding to ``input_ids``.
            intermediate_tensors: Passed through to the language model. When
                it is set, ``inputs_embeds`` is discarded.
            inputs_embeds: Precomputed input embeddings. When ``None`` (and
                ``intermediate_tensors`` is ``None``), they are computed here
                from ``input_ids`` and the multimodal ``kwargs``.
            **kwargs: Multimodal inputs (e.g. ``pixel_values``) forwarded to
                ``get_multimodal_embeddings``.

        Returns:
            The output of ``self.language_model.model``.

        See also:
            :class:`LlavaImageInputs`
        """
        # NOTE: In v1, inputs_embeds is always generated at model runner, the
        # first branch below is for v0 compatibility.
        if intermediate_tensors is None and inputs_embeds is None:
            kwargs["v0_path"] = True
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None
        elif intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

875
876
877
878
879
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits by delegating to the wrapped language model."""
        lm = self.language_model
        return lm.compute_logits(hidden_states, sampling_metadata)
882
883
884
885
886
887

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample from ``logits`` by delegating to the wrapped language
        model."""
        lm = self.language_model
        return lm.sample(logits, sampling_metadata)
889

890
891
    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Delegates to ``AutoWeightsLoader`` and returns its result (annotated
        as the set of loaded parameter names).
        """
        weight_loader = AutoWeightsLoader(self)
        loaded_names = weight_loader.load_weights(weights)
        return loaded_names
894
895


896
897
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis models.

    Builds a ``LlavaProcessor`` whose ``patch_size`` defaults to the vision
    encoder's patch size and whose ``vision_feature_select_strategy`` default
    depends on the installed transformers version.
    """

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``LlavaProcessor``; explicit ``kwargs`` win over the
        defaults filled in here."""
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            default_strategy = None
        else:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            default_strategy = hf_config.vision_feature_select_strategy
        kwargs.setdefault("vision_feature_select_strategy", default_strategy)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
916
917


918
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA multimodal processor that rewrites each image's placeholder run
    into the Mantis prompt format ``(image N: <Image><image>...</Image>)``."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Run the LLaVA processing pipeline, then apply the Mantis prompt
        format on top of its output.

        Args:
            prompt: Text prompt or pre-tokenized prompt ids.
            mm_data: Multimodal items (images) for this prompt.
            hf_processor_mm_kwargs: Extra kwargs for the HF processor.
            return_mm_hashes: Forwarded to the parent ``apply``. Any
                ``mm_hashes`` in the parent result are carried into the
                returned inputs.

        Returns:
            Multimodal inputs holding the Mantis-formatted prompt and token
            ids, the parent's ``mm_kwargs`` and ``mm_hashes``, and the
            placeholder ranges located in the rewritten token ids.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        # Previously the rebuilt MultiModalInputs below omitted the parent's
        # hashes, so `return_mm_hashes=True` had no effect for Mantis.
        # NOTE(review): assumes the hashes do not depend on the prompt text
        # rewritten below -- confirm against the base processor.
        mm_hashes = result.get("mm_hashes")

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        # Replace each run of image tokens produced by the parent with the
        # Mantis-wrapped version.
        mantis_mm_repls = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        prompt_ids, prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Re-locate the original placeholders inside the rewritten ids, since
        # the Mantis wrapper shifted their offsets.
        unbound_orig_repls = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_updates(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
992
993
994
995


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
996
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis multimodal processor.

    The model is inherited unchanged from ``LlavaForConditionalGeneration``.
    Only the registered processing (``MantisMultiModalProcessor`` /
    ``MantisProcessingInfo``) differs.
    """