llava.py 36.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
7
                    TypedDict, TypeVar, Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.model_executor.layers.activation import get_act_fn
22
23
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
25
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
29
                                    MultiModalInputs, MultiModalKwargs,
30
                                    NestedTensors)
31
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
                                   ImageSize, MultiModalDataItems)
33
from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
                                        BaseProcessingInfo, ProcessingCache,
35
                                        PromptReplacement, PromptUpdate)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
from vllm.sequence import IntermediateTensors
38
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
39

40
from .clip import CLIPVisionModel
41
from .interfaces import SupportsMultiModal, SupportsPP
42
43
44
from .pixtral import (PixtralHFVisionModel,
                      get_pixtral_hf_image_feature_grid_size)
from .siglip import SiglipVisionModel
45
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
46
                    maybe_prefix, merge_multimodal_embeddings)
47
from .vision import get_vision_encoder_info
48
49


50
51
class LlavaImagePixelInputs(TypedDict):
    """Raw image pixels that still have to go through the vision tower."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    feat_is_patch: Optional[Union[torch.Tensor, List[torch.Tensor]]]
    """
    A boolean mask indicating which image features correspond
    to patch tokens.

    Shape: `(batch_size, num_crops, num_patch)`

    Set by the Pixtral-HF processor; `None` when the processor does not
    provide it (e.g. plain LLaVA).
    """

    embed_is_patch: Optional[Union[torch.Tensor, List[torch.Tensor]]]
    """
    A boolean mask indicating which image embeddings correspond
    to patch tokens.

    Shape: `(batch_size, num_embeds)`

    Set by the Pixtral-HF processor; `None` when the processor does not
    provide it (e.g. plain LLaVA).
    """

    num_crops: Optional[torch.Tensor]
    """
    Shape: `(batch_size, num_images)`

    Set by the Pixtral-HF processor; `None` when the processor does not
    provide it (e.g. plain LLaVA).
    """


class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that bypass the vision tower."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """

    feat_is_patch: Optional[Union[torch.Tensor, List[torch.Tensor]]]
    """
    A boolean mask indicating which image features correspond
    to patch tokens.

    Shape: `(batch_size, num_crops, num_patch)`

    `None` when the processor does not provide it.
    """

    embed_is_patch: Optional[Union[torch.Tensor, List[torch.Tensor]]]
    """
    A boolean mask indicating which image embeddings correspond
    to patch tokens.

    Shape: `(batch_size, num_embeds)`

    `None` when the processor does not provide it.
    """

    num_crops: Optional[torch.Tensor]
    """
    Shape: `(batch_size, num_images)`

    `None` when the processor does not provide it.
    """


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    Computes ``linear_2(act(linear_1(x)))``. The submodule names are kept
    unchanged; weight loading presumably relies on them — do not rename.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # vision_hidden_size -> text_hidden_size
        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=multimodal_projector_bias,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        # text_hidden_size -> text_hidden_size
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=multimodal_projector_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` into the language model's space."""
        # The parallel linear layers return a pair; only the first element
        # (the output) is used here.
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes this module reads.

    ``BaseLlavaProcessingInfo.get_hf_config`` returns a ``LlavaConfig`` typed
    as this protocol.
    """

    # Config of the vision tower (CLIP, SigLIP or Pixtral are handled here).
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # Handled values in this module: "default" and "full".
    vision_feature_select_strategy: Final[str]
    # One layer index or several; negative indices count from the end.
    vision_feature_layer: Final[Union[int, list[int]]]


class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors exposing an image token string."""

    # Placeholder text; the dummy-input builder repeats it once per image.
    image_token: Final[str]


class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by LLaVA-style models.

    Subclasses only need to say which HF processor to load via
    `get_hf_processor`; token counting is derived from the vision encoder.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        # Neither argument affects the result: every image is bounded by the
        # token count of the largest image size the encoder accepts.
        return {"image": self.get_max_image_tokens()}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust an encoder token count for the feature-select strategy.

        ``"full"`` keeps every encoder token, while ``"default"`` drops one
        (matching the first position `_select_image_features` slices off).

        Raises:
            NotImplementedError: For any other strategy.
        """
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            return encoder_num_image_tokens - 1

        msg = f"Unexpected feature select strategy: {strategy!r}"
        raise NotImplementedError(msg)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of placeholder tokens for an image of the given size."""
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            encoder_tokens,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square image whose side is the encoder's configured image size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width,
                                         image_height=height)


_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy image prompts for profiling LLaVA-style models."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return ``mm_counts["image"]`` dummy images of the largest size,
        with one image token per image as the prompt text."""
        num_images = mm_counts.get("image", 0)

        image_token = self.info.get_hf_processor().image_token
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=num_images)

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data={"image": dummy_images},
        )


class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints (``LlavaProcessor``)."""

    def get_hf_processor(self, **kwargs: object):
        processor_cls = LlavaProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)


class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor shared by LLaVA-style models.

    Each single image placeholder token in the prompt is expanded into one
    placeholder per image feature.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            # Raw images derive the count from their size; precomputed
            # embeddings report their own feature size.
            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                token_count = images.get_feature_size(item_idx)

            return [image_token_id] * token_count

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]


class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both fields are batched along the "image" modality.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in the HF format."""

    def get_hf_processor(self, **kwargs: object):
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)


class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral checkpoints in the HF format.

    Each image expands to ``nrows`` rows of ``ncols`` image tokens. Every row
    is closed by a break token, except the last, which is closed by an end
    token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is None:
            # No pixel values were produced: nothing to post-process.
            return processed_outputs

        # Before/after https://github.com/huggingface/transformers/pull/35122
        legacy_layout = Version(TRANSFORMERS_VERSION) <= Version("4.48.3")
        if legacy_layout:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert (isinstance(pixel_values, list)
                    and len(pixel_values) == 1)
            assert (isinstance(pixel_values[0], list)
                    and len(pixel_values[0]) == len(images))

            processed_outputs["pixel_values"] = pixel_values[0]
        else:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            processed_outputs["pixel_values"] = [
                pixel_value[:, :height, :width]
                for pixel_value, (height, width) in zip(pixel_values,
                                                         image_sizes)
            ]

        hf_config = self.info.get_hf_config()
        tile_sizes = [
            get_pixtral_hf_image_feature_grid_size(
                hf_config.vision_config,
                image_width=pixel_value.shape[-1],
                image_height=pixel_value.shape[-2])
            for pixel_value in processed_outputs["pixel_values"]
        ]

        # One embedding per image token plus one per row terminator, i.e.
        # the length of the token list built in `_get_prompt_updates`.
        num_crops = torch.tensor([nrows * (ncols + 1)
                                  for ncols, nrows in tile_sizes])

        # Each image may result to masks of different sizes, so we need to
        # flatten the list and later use `num_crops` to get per-image masks.
        per_image_masks = [([True] * ncols + [False]) * nrows
                           for ncols, nrows in tile_sizes]
        embed_is_patch = torch.tensor(flatten_2d_lists(per_image_masks))

        processed_outputs["num_crops"] = num_crops
        processed_outputs["embed_is_patch"] = embed_is_patch
        processed_outputs["feat_is_patch"] = embed_is_patch

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # The masks are stored flat across images; `num_crops` gives the
        # length of each image's slice.
        num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
        return {
            "feat_is_patch":
            MultiModalFieldConfig.flat_from_sizes("image", num_crops),
            "embed_is_patch":
            MultiModalFieldConfig.flat_from_sizes("image", num_crops),
            "num_crops": MultiModalFieldConfig.batched("image"),
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)

        def get_replacement(item_idx: int):
            image_size = mm_items.get_items(
                "image", ImageProcessorItems).get_image_size(item_idx)

            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
                vision_config,
                image_width=image_size.width,
                image_height=image_size.height,
            )

            # Every row ends with a break token; the very last terminator is
            # replaced by the end token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id
            return tokens

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]


def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class from the vision tower's config type."""
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Instantiate the multimodal processor matching the type of ``info``.

    Raises:
        NotImplementedError: If ``info`` is neither Pixtral-HF nor LLaVA info.
    """
    processor_cls: type[BaseMultiModalProcessor]
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
        enable_sanity_checks=enable_sanity_checks,
    )


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed to reach the deepest requested
        feature layer.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # `max()` would otherwise fail with an opaque message.
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
496
    """Given a signed vision feature layer, get the number of hidden layers
497
498
499
500
501
502
503
504
505
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
506
    return feature_layer_index
507
508
509
510
511
512
513


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Build the vision tower that matches ``hf_config.vision_config``.

    The tower is truncated to the deepest layer requested by
    ``hf_config.vision_feature_layer``.

    Raises:
        NotImplementedError: If the vision config is not CLIP, SigLIP or
            Pixtral.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    tower_by_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_by_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
553
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
554
555
556
557

    # NOTE(review): presumably maps each fused module to the checkpoint
    # modules packed into it; confirm against the weight loader.
    packed_modules_mapping = dict(
        qkv_proj=["q_proj", "k_proj", "v_proj"],
        gate_up_proj=["gate_proj", "up_proj"],
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        # They patch `config` in place, so they must run before the
        # submodules below read it.
        text_config = config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        language_model = self.language_model
        return (language_model.sampler
                if hasattr(language_model, "sampler") else get_sampler())

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every non-batch dim of ``data`` is ``(3, S, S)``, with
        ``S`` the vision config's ``image_size``; return ``data`` unchanged.

        Raises:
            ValueError: If ``data.shape[1:]`` differs from ``(3, S, S)``.
        """
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract image inputs from ``kwargs`` and check their types.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
        given. When both are given, pixel values take precedence.

        Raises:
            ValueError: If a provided field has an unexpected type.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        def pop_checked(key: str,
                        allowed: Union[type, tuple[type, ...]]) -> object:
            # Optional field: absent/None is fine, anything else must match.
            value = kwargs.pop(key, None)
            if value is not None and not isinstance(value, allowed):
                raise ValueError(f"Incorrect type of {key}. "
                                 f"Got type: {type(value)}")
            return value

        feat_is_patch = pop_checked("feat_is_patch", (torch.Tensor, list))
        embed_is_patch = pop_checked("embed_is_patch", (torch.Tensor, list))
        num_crops = pop_checked("num_crops", torch.Tensor)

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if self.config.vision_config.model_type == "pixtral":
                # Pixtral images may differ in size, so they are flattened
                # without concatenation or fixed-shape validation.
                data = flatten_bn(pixel_values)
            else:
                data = self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True))

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=data,
                feat_is_patch=feat_is_patch,
                embed_is_patch=embed_is_patch,
                num_crops=num_crops,
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
                feat_is_patch=feat_is_patch,
                embed_is_patch=embed_is_patch,
                num_crops=num_crops,
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy along dim 1.

        ``"full"`` keeps every position; ``"default"`` drops the first one.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` with ``vision_tower`` and apply the
        configured feature select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        features = vision_tower(pixel_values)
        return self._select_image_features(
            features, strategy=self.config.vision_feature_select_strategy)

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on the raw pixels held in ``inputs``."""
        assert self.vision_tower is not None
        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["data"])

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn parsed image inputs into language-model embeddings.

        Precomputed embeddings are returned unchanged. Pixel inputs are
        encoded by the vision tower and projected. When the tower returns a
        sequence of tensors rather than one tensor, they are projected in a
        single concatenated call and split back into the same chunk sizes.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        chunk_sizes = [feature.shape[0] for feature in image_features]
        projected = self.multi_modal_projector(torch.cat(image_features))
        return torch.split(projected, chunk_sizes)

    def _get_mm_embeds(
            self,
            features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
            feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
            num_crops: torch.Tensor,  # Shape: (num_images,)
            embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
    ) -> list[torch.Tensor]:
        """Scatter the patch features into a contiguous tensor that corresponds
        to the embedding tokens defined by the multimodal processor.

        Positions that are not patch tokens are filled with NaN so that
        `get_input_embeddings` can drop them later.

        Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.

        Returns:
            One ``(num_embeds, d)`` tensor per chunk of ``num_crops``.
        """

        # Insert columns of nan values according to `feat_is_patch`. This work
        # ideally should be done in `_process_image_input`, but
        # `_process_image_input` is used in both V0 and V1 path. It's safer to
        # put the logic here.
        # FIXME: Move this logic to `_process_image_input` when v0 is
        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
        feat_is_patch = feat_is_patch.view(-1)
        embed_is_patch = embed_is_patch.view(-1)
        # Allocate directly on the features' device instead of building the
        # buffer on the default device and copying it over.
        expanded_embedding = torch.full(
            (sum(num_crops), *features.shape[1:]),
            torch.nan,
            dtype=features.dtype,
            device=features.device)
        expanded_embedding[feat_is_patch] = features

        num_crops_per_image = num_crops.tolist()
        feats_per_image = expanded_embedding.split(num_crops_per_image)
        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)

        embed_dim = expanded_embedding.shape[-1]
        num_embeds = embed_is_patch.shape[0]

        embeds_in_batch: list[torch.Tensor] = []
        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
            embeds[embed_is_patch] = feats[f_is_patch]
            embeds_in_batch.append(embeds)

        return embeds_in_batch

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Compute the image embeddings for the multimodal inputs in ``kwargs``.

        Returns ``None`` when no image input is present. When ``v0_path`` is
        set in ``kwargs`` (as ``forward`` does), or when the parsed image
        input carries no ``feat_is_patch`` / ``embed_is_patch`` masks, the
        output of ``_process_image_input`` is returned as-is.

        Otherwise each image's features are scattered with ``_get_mm_embeds``
        so they line up with the processor's embedding tokens. The per-item
        results are then flattened into a single list with
        ``flatten_2d_lists``.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)

        # Only inputs that provide the patch masks can be scattered; any other
        # image input (NOTE(review): presumably the non-Pixtral vision towers
        # and precomputed embeddings) previously hit a KeyError here.
        if (kwargs.get("v0_path", False)
                or image_input.get("feat_is_patch") is None
                or image_input.get("embed_is_patch") is None):
            return vision_embeddings

        nested_emb = [
            self._get_mm_embeds(*args) for args in zip(
                vision_embeddings, image_input["feat_is_patch"],
                image_input["num_crops"], image_input["embed_is_patch"])
        ]
        return flatten_2d_lists(nested_emb)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in the multimodal embeddings, if any.

        Args:
            input_ids: Token ids to embed with the language model's embedding
                layer.
            multimodal_embeddings: Optional (nested) image embeddings. NaN
                entries are removed before merging.

        Returns:
            The text embeddings. If ``multimodal_embeddings`` is given, the
            result comes from ``merge_multimodal_embeddings``, which receives
            ``self.config.image_token_index`` as the placeholder token id.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        # Extract the patch tokens: drop the NaN entries that
        # `_get_mm_embeds` uses to pad non-patch positions.
        patch_embeddings = json_map_leaves(
            lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
            cast(JSONTree[torch.Tensor], multimodal_embeddings),
        )

        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            cast(NestedTensors, patch_embeddings),
            self.config.image_token_index,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor inserts
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Positions of the tokens in `input_ids`.
            intermediate_tensors: Intermediate tensors passed to the language
                model (presumably from a previous pipeline stage). When given,
                `inputs_embeds` is ignored.
            inputs_embeds: Precomputed input embeddings. When omitted and
                `intermediate_tensors` is None (V0 path), they are computed
                here from `input_ids` and the multimodal inputs in `kwargs`.
            **kwargs: Multimodal inputs such as `pixel_values` (the pixels in
                each input image), forwarded to `get_multimodal_embeddings`.

        Returns:
            The output of `self.language_model.model`.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            kwargs["v0_path"] = True
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            # The embeddings now carry the tokens; don't pass both.
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits by delegating to the wrapped language model.

        Args:
            hidden_states: Hidden states produced by `forward`.
            sampling_metadata: Sampling metadata forwarded unchanged.

        Returns:
            Whatever `self.language_model.compute_logits` returns.
        """
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens by delegating to the wrapped language model.

        Args:
            logits: Logits from `compute_logits`.
            sampling_metadata: Sampling metadata forwarded unchanged.

        Returns:
            Whatever `self.language_model.sample` returns.
        """
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into this model via `AutoWeightsLoader`.

        Args:
            weights: Iterable of `(name, tensor)` pairs.

        Returns:
            The set returned by `AutoWeightsLoader.load_weights`.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)


class MantisProcessingInfo(LlavaProcessingInfo):
    """LLaVA processing info whose HF processor is configured for Mantis."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor` configured for Mantis checkpoints.

        `patch_size` defaults to the vision encoder's patch size. The default
        for `vision_feature_select_strategy` depends on the installed
        transformers version (see the comments below). Values passed
        explicitly in `kwargs` take precedence because only `setdefault` is
        used.
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            kwargs.setdefault("vision_feature_select_strategy", None)
        else:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            kwargs.setdefault(
                "vision_feature_select_strategy",
                hf_config.vision_feature_select_strategy,
            )

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)


class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA multimodal processor that rewrites image placeholders the way
    Mantis' `MLlavaProcessor` does."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Process the inputs with the LLaVA pipeline, then wrap each image.

        Each run of `num_image_tokens` image tokens is replaced with
        `"(image {i}: <Image>" + "<image>" * num_image_tokens + "</Image>)"`.
        The image placeholders are then located again in the rewritten
        prompt.

        Args:
            prompt: Text prompt or token ids.
            mm_data: Multimodal data for the prompt.
            hf_processor_mm_kwargs: Extra kwargs for the HF processor.
            return_mm_hashes: Forwarded to the base `apply`. The hashes it
                produces are carried over into the returned inputs.

        Returns:
            A `MultiModalInputs` built from the rewritten prompt.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        mantis_mm_repls = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        prompt_ids, prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Locate the (LLaVA-style) image placeholders again, now inside the
        # Mantis-wrapped prompt, so the ranges refer to `prompt_ids`.
        unbound_orig_repls = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_updates(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            # Keep the hashes requested via `return_mm_hashes`; the image
            # items are unchanged by the prompt rewrite, so they stay valid.
            mm_hashes=result.get("mm_hashes"),
            mm_placeholders=mm_placeholder_ranges,
        )


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """`LlavaForConditionalGeneration` registered with the Mantis processor.

    The model body is inherited unchanged; only the multimodal processor and
    processing info registered via the decorator differ.
    """