llava.py 37.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
7
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
                    TypeVar, Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.model_executor.layers.activation import get_act_fn
22
23
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
25
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
29
                                    MultiModalInputs, MultiModalKwargs,
30
                                    NestedTensors)
31
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
                                   ImageSize, MultiModalDataItems)
33
from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
                                        BaseProcessingInfo, ProcessingCache,
35
                                        PromptReplacement, PromptUpdate)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
from vllm.sequence import IntermediateTensors
38
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
39

40
from .clip import CLIPVisionModel
41
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
42
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
43
from .siglip import SiglipVisionModel
44
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
45
                    maybe_prefix, merge_multimodal_embeddings)
46
from .vision import get_vision_encoder_info
47
48


49
50
class LlavaImagePixelInputs(TypedDict):
    """Pixel-value image inputs used when the vision tower is not Pixtral."""

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    This is always a single batched tensor: `_parse_and_validate_image_input`
    concatenates the inputs, and `_validate_pixel_values` rejects any tensor
    whose trailing dimensions are not `(3, image_size, image_size)`.
    """
58

59
60
61
62
63
64
65
66
67
68
69
70

class PixtralHFImagePixelInputs(TypedDict):
    """Pixel-value image inputs for the Pixtral-HF vision tower."""

    type: Literal["pixel_values_pixtral"]

    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    `height` and `width` can vary between images (and between batches);
    when they do, a list of tensors is passed in place of one batched tensor.
    """

    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
    Boolean mask marking which image features are patch tokens.

    Shape: `(batch_size, num_crops, num_patch)`
    """

    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
    Boolean mask marking which image embeddings are patch tokens.

    Shape: `(batch_size, num_embeds)`
    """

    num_crops: Union[torch.Tensor, list[torch.Tensor]]
    """Shape: `(batch_size, num_images)`"""

89
90
91
92

class LlavaImageEmbeddingInputs(TypedDict):
    """Image inputs supplied as precomputed embeddings."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


99
100
# Tagged union of every image-input form accepted by this model; the
# variants are told apart by the literal value of their "type" key.
LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
                         LlavaImageEmbeddingInputs]
101
102


103
104
class LlavaMultiModalProjector(nn.Module):
    """Linear -> activation -> linear projector from vision to text width."""

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the two projection layers and the activation between them.

        Args:
            vision_hidden_size: Input width of the first linear layer.
            text_hidden_size: Output width of both linear layers.
            projector_hidden_act: Activation name resolved via `get_act_fn`.
            multimodal_projector_bias: Whether both linear layers use a bias.
            quant_config: Quantization config forwarded to both layers.
            prefix: Parameter-name prefix for the sub-layers.
        """
        super().__init__()

        use_bias = multimodal_projector_bias

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` through both layers and return the result."""
        # Both linear layers return a 2-tuple; only the first item is used.
        projected, _ = self.linear_1(image_features)
        projected, _ = self.linear_2(self.act(projected))
        return projected


133
134
class LlavaLikeConfig(Protocol):
    """Structural type listing the config fields this module reads."""

    # Config of the vision tower.
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # Strategy name; this module handles "default" and "full".
    vision_feature_select_strategy: Final[str]
    # One signed layer index, or several of them.
    vision_feature_layer: Final[Union[int, list[int]]]
138

139

140
141
142
143
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an `image_token`."""
    # Placeholder string for one image in the prompt text.
    image_token: Final[str]


144
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing info for LLaVA-style models.

    Subclasses supply the concrete HF processor via `get_hf_processor`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config, requested from the context as `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the vision encoder info derived from the HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be provided by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Report the supported modalities: images, with a limit of `None`."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count of one image.

        `seq_len` and `mm_counts` are accepted but not used here.
        """
        max_image_tokens = self.get_max_image_tokens()
        return {"image": max_image_tokens}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder's token count for the feature select strategy.

        "full" keeps every token and "default" keeps one fewer.

        Raises:
            NotImplementedError: If `strategy` is neither of those.
        """
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            return encoder_num_image_tokens - 1

        raise NotImplementedError(
            f"Unexpected feature select strategy: {strategy!r}")

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size."""
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        select_strategy = hf_config.vision_feature_select_strategy
        encoder_tokens = vision_encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(select_strategy,
                                                   encoder_tokens)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` using the encoder's image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for `get_image_size_with_most_features`."""
        max_width, max_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=max_width,
                                         image_height=max_height)

209
210
211
212
213
214

# Type variable for any processing-info class derived from
# BaseLlavaProcessingInfo; parametrizes the generic builders/processors below.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompt text and images for LLaVA-style models."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create dummy inputs with `mm_counts["image"]` images.

        The prompt is the processor's image token repeated once per image,
        and each dummy image has the size reported by
        `get_image_size_with_most_features`. `seq_len` is not used here.
        """
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)

        return ProcessorInputs(
            prompt_text=placeholder * image_count,
            mm_data={"image": dummy_images},
        )


240
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info that uses the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `LlavaProcessor` from the context, forwarding kwargs."""
        hf_processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
        return hf_processor
244
245


246
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Base processor that expands each image placeholder token."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the per-field config; must be provided by subclasses."""
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with one token per image feature."""
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            # The items are looked up on each call, as the callback is
            # invoked once per image item.
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                # Precomputed embeddings already know their feature count.
                return [image_token_id] * images.get_feature_size(item_idx)

            size = images.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [image_token_id] * token_count

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]


290
291
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for LLaVA models using `LlavaProcessingInfo`."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare both image fields as batched over the "image" modality."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


304
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info that uses the HF `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `PixtralProcessor` from the context, forwarding kwargs."""
        hf_processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return hf_processor
308

309

310
311
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for LLaVA models with a Pixtral-HF vision tower."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then normalize and annotate its image output.

        When `pixel_values` is present, it is reshaped to one entry per
        image, and `num_crops`, `embed_is_patch` and `feat_is_patch` are
        added to the returned features.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is None:
            # No image output to post-process.
            return processed_outputs

        # Before/after https://github.com/huggingface/transformers/pull/35122
        if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert (isinstance(pixel_values, list)
                    and len(pixel_values) == 1)
            per_image_values = pixel_values[0]
            assert (isinstance(per_image_values, list)
                    and len(per_image_values) == len(images))

            processed_outputs["pixel_values"] = per_image_values
        else:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            unpadded_values = []
            for padded, (height, width) in zip(pixel_values, image_sizes):
                unpadded_values.append(padded[:, :height, :width])
            processed_outputs["pixel_values"] = unpadded_values

        vision_config = self.info.get_hf_config().vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        # Per image: every grid row contributes `ncols` patch entries (True)
        # followed by one non-patch entry (False).
        crop_counts = []
        patch_masks = []
        for pixel_value in processed_outputs["pixel_values"]:
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=pixel_value.shape[-1],
                image_height=pixel_value.shape[-2],
            )
            crop_counts.append((ncols + 1) * nrows)
            patch_masks.append(([True] * ncols + [False]) * nrows)

        # Each image may result to masks of different sizes, so we need to
        # flatten the list and later use `num_crops` to get per-image masks.
        embed_is_patch = torch.tensor(flatten_2d_lists(patch_masks))

        processed_outputs["num_crops"] = torch.tensor(crop_counts)
        processed_outputs["embed_is_patch"] = embed_is_patch
        processed_outputs["feat_is_patch"] = embed_is_patch

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each output field is split across image items.

        The two mask fields are flat and are split by `num_crops`; the other
        fields are batched per image.
        """
        # Falls back to an empty tensor when `num_crops` is absent.
        num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)

        return {
            "feat_is_patch":
            MultiModalFieldConfig.flat_from_sizes("image", num_crops),
            "embed_is_patch":
            MultiModalFieldConfig.flat_from_sizes("image", num_crops),
            "num_crops":
            MultiModalFieldConfig.batched("image"),
            "pixel_values":
            MultiModalFieldConfig.batched("image"),
            "image_embeds":
            MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with the image's row-wise token layout."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        def get_replacement(item_idx: int):
            size = mm_items.get_items(
                "image", ImageProcessorItems).get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # Every row is `ncols` image tokens plus a break token; the very
            # last break token is swapped for the end token.
            row_tokens = [image_token_id] * ncols + [image_break_id]
            tokens = row_tokens * nrows
            tokens[-1] = image_end_id

            return tokens

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]

432

433
434
435
436
437
438
439
440
441
442
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class that matches the vision config.

    Returns `PixtralHFProcessingInfo` when the vision config is a
    `PixtralVisionConfig`, and `LlavaProcessingInfo` otherwise.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config

    uses_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if uses_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


443
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Build the multimodal processor that matches the type of `info`.

    Raises:
        NotImplementedError: If `info` is neither a `PixtralHFProcessingInfo`
            nor a `LlavaProcessingInfo`.
    """
    # The Pixtral check comes first, as in the class-selection order used
    # when the info object is built.
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
        enable_sanity_checks=enable_sanity_checks,
    )
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.
    
    Args:
        hf_config: Model config with vision feature layer(s).
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
490
    """Given a signed vision feature layer, get the number of hidden layers
491
492
493
494
495
496
497
498
499
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
500
    return feature_layer_index
501
502
503
504
505
506
507


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision tower that matches the type of the vision config.

    Args:
        hf_config: Model config providing the vision config and feature
            layer(s).
        quant_config: Quantization config forwarded to the vision model.
        require_post_norm: Forwarded unchanged to the vision model.
        prefix: Parameter-name prefix forwarded to the vision model.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    elif isinstance(vision_config, PixtralVisionConfig):
        tower_cls = PixtralHFVisionModel
    else:
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return tower_cls(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


544
545
546
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
547
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
548
549
550
551

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
552
    }
553

554
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Provides the HF config, quantization config and
                multimodal config.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()

        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.multimodal_config = model_config.multimodal_config

        text_config = hf_config.text_config
        vision_config = hf_config.vision_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (hf_config.projector_hidden_act is None
                and vision_config.hidden_act == "gelu"):
            hf_config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            hf_config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_config.hidden_size,
            text_hidden_size=text_config.hidden_size,
            projector_hidden_act=hf_config.projector_hidden_act,
            multimodal_projector_bias=hf_config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Re-export the language model's factory under the same name.
        language_model = self.language_model
        self.make_empty_intermediate_tensors = (
            language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Return the language model's sampler if it has one, else a new one."""
        language_model = self.language_model
        if not hasattr(language_model, "sampler"):
            return get_sampler()

        return language_model.sampler
602

603
604
605
606
607
608
609
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
610
            raise ValueError(
611
612
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")
613
614
615
616

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Turn raw keyword inputs into one typed image-input dict.

        Returns `None` when neither `pixel_values` nor `image_embeds` is
        given. `pixel_values` takes precedence when both are present.

        Raises:
            ValueError: If an input is neither a tensor nor a list, if
                pixel values have an unexpected shape, or if `image_embeds`
                is given for a Pixtral vision config.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        is_pixtral = self.config.vision_config.model_type == "pixtral"

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not is_pixtral:
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    pixel_values=self._validate_pixel_values(
                        flatten_bn(pixel_values, concat=True)),
                )

            # The Pixtral path needs three extra fields; each one is popped
            # and type-checked in turn before the next one is read.
            extra_fields = {}
            for field in ("feat_is_patch", "embed_is_patch", "num_crops"):
                value = kwargs.pop(field)
                if not isinstance(value, (torch.Tensor, list)):
                    raise ValueError(f"Incorrect type of {field}. "
                                     f"Got type: {type(value)}")
                extra_fields[field] = value

            return PixtralHFImagePixelInputs(
                type="pixel_values_pixtral",
                pixel_values=flatten_bn(pixel_values),
                feat_is_patch=extra_fields["feat_is_patch"],
                embed_is_patch=extra_fields["embed_is_patch"],
                num_crops=extra_fields["num_crops"],
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if is_pixtral:
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")
673
674
675
676
677
678
679
680
681
682
683

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

684
685
    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> torch.Tensor:
        """Run the vision tower, then apply the feature select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        tower_output = vision_tower(pixel_values)

        select_strategy = self.config.vision_feature_select_strategy
        return self._select_image_features(tower_output,
                                           strategy=select_strategy)

700
701
702
703
    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> torch.Tensor:
        """Compute image features from the `pixel_values` of `inputs`."""
        assert self.vision_tower is not None

        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["pixel_values"])

710
711
712
713
    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Convert an image input into embeddings.

        Precomputed embeddings are returned as-is. Pixel inputs go through
        the vision tower and the projector; when the tower yields several
        tensors, they are projected together and split back afterwards.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Remember each entry's leading size so the single projected tensor
        # can be split back into per-entry pieces.
        sizes = [feature.shape[0] for feature in image_features]
        projected = self.multi_modal_projector(torch.cat(image_features))
        return torch.split(projected, sizes)

    def _get_mm_embeds(
            self,
            features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
            feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
            num_crops: torch.Tensor,  # Shape: (num_images,)
            embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
    ) -> list[torch.Tensor]:
        """Scatter the patch features into a contiguous tensor that corresponds
        to the embedding tokens defined by the multimodal processor.

        Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
        """

        # Insert columns of nan values according to `feat_is_patch`. This work
        # ideally should be done in `_process_image_input`, but
        # `_process_image_input` is used in both V0 and V1 path. It's safer to
        # put the logic here.
        # FIXME: Move this logic to `_process_image_input` when v0 is
        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
        feat_is_patch = feat_is_patch.view(-1)
        embed_is_patch = embed_is_patch.view(-1)
        expanded_embedding = torch.full(
            (sum(num_crops), *features.shape[1:]),
            torch.nan,
            dtype=features.dtype).to(features.device)
        expanded_embedding[feat_is_patch] = features

        num_crops_per_image = num_crops.tolist()
        feats_per_image = expanded_embedding.split(num_crops_per_image)
        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)

        embed_dim = expanded_embedding.shape[-1]
        num_embeds = embed_is_patch.shape[0]

        embeds_in_batch = list[torch.Tensor]()
        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
            embeds[embed_is_patch] = feats[f_is_patch]
            embeds_in_batch.append(embeds)

        return embeds_in_batch

773
774
    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute the image embeddings for the multimodal inputs in `kwargs`.

        Args:
            **kwargs: Multimodal keyword inputs, parsed by
                `_parse_and_validate_image_input`. A truthy `v0_path` entry
                selects the un-scattered output (set by `forward`).

        Returns:
            `None` when no image input is found. Otherwise the output of
            `_process_image_input`, returned as-is on the V0 path or for any
            input type other than `"pixel_values_pixtral"`; for Pixtral inputs
            on the V1 path, the flattened result of `_get_mm_embeds`.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)

        if (kwargs.get("v0_path", False)
                or image_input["type"] != "pixel_values_pixtral"):
            # The path is used for pixtral (V0 only) and llava (V0/V1)
            return vision_embeddings

        # Pixtral on V1: scatter the features of each zipped entry into the
        # embedding layout described by the processor's masks.
        per_entry_args = zip(
            vision_embeddings,
            image_input["feat_is_patch"],
            image_input["num_crops"],
            image_input["embed_is_patch"],
        )
        nested_emb = [self._get_mm_embeds(*args) for args in per_entry_args]
        return flatten_2d_lists(nested_emb)
792
793
794
795

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge in the multimodal embeddings, if any.

        Args:
            input_ids: Token ids passed to the language model's embedding
                lookup.
            multimodal_embeddings: Optional (possibly nested) tensors returned
                by `get_multimodal_embeddings`. NaN entries are treated as
                non-patch filler and removed before merging.

        Returns:
            The text embeddings, with the positions of
            `self.config.image_token_index` filled from the multimodal
            embeddings when those are provided.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        def _drop_nan(x: torch.Tensor) -> torch.Tensor:
            # Keep only the non-NaN elements and restore the trailing
            # dimensions of the leaf tensor.
            return x[~x.isnan()].view(-1, *x.shape[1:])

        # Extract the patch tokens
        patch_embeddings = json_map_leaves(
            _drop_nan,
            cast(JSONTree[torch.Tensor], multimodal_embeddings),
        )

        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            cast(NestedTensors, patch_embeddings),
            self.config.image_token_index,
        )

812
813
814
815
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are passed to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens passed to the language
        model, i.e. the number of image tokens produced by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position tensor forwarded unchanged to the language
                model.
            intermediate_tensors: Forwarded to the language model. When it is
                not None, no embeddings are computed here and `inputs_embeds`
                is reset to None.
            inputs_embeds: Precomputed input embeddings. When None (and
                `intermediate_tensors` is None), they are computed here from
                `input_ids` and the multimodal inputs in `kwargs`.
            **kwargs: Multimodal inputs forwarded to
                `get_multimodal_embeddings`, e.g. `pixel_values`: the pixels
                in each input image.

        Returns:
            The output of `self.language_model.model`.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            # Tell `get_multimodal_embeddings` to return the raw vision
            # embeddings (V0 path).
            kwargs["v0_path"] = True
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            # The embeddings replace the token ids as the model input.
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

874
875
876
877
878
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits by delegating to the wrapped language model.

        Args:
            hidden_states: Hidden states to turn into logits.
            sampling_metadata: Passed through unchanged.

        Returns:
            Whatever `self.language_model.compute_logits` returns.
        """
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)
881
882
883
884
885
886

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample from `logits` by delegating to the wrapped language model.

        Args:
            logits: Logits to sample from.
            sampling_metadata: Passed through unchanged.

        Returns:
            Whatever `self.language_model.sample` returns.
        """
        return self.language_model.sample(logits, sampling_metadata)
888

889
890
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `weights` into this module via `AutoWeightsLoader`.

        Args:
            weights: Iterable of `(name, tensor)` pairs.

        Returns:
            The set returned by `AutoWeightsLoader.load_weights`
            (presumably the names of the loaded parameters -- TODO confirm).
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
893
894


895
896
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis; differs from LLaVA only in how the
    HuggingFace processor is constructed."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor` with Mantis-specific defaults.

        Defaults are applied with `setdefault`, so values supplied by the
        caller in `kwargs` take precedence.
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            kwargs.setdefault("vision_feature_select_strategy", None)
        else:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            kwargs.setdefault(
                "vision_feature_select_strategy",
                hf_config.vision_feature_select_strategy,
            )

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
915
916


917
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA multimodal processor that additionally wraps every image
    placeholder run in the Mantis `(image N: <Image>...</Image>)` markup."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Process the prompt like LLaVA, then rewrite it in Mantis format.

        Args:
            prompt: Prompt text or token ids, forwarded to the parent `apply`.
            mm_data: Multimodal data, forwarded to the parent `apply`.
            hf_processor_mm_kwargs: Keyword arguments for the HF processor.
            return_mm_hashes: Forwarded to the parent `apply`.

        Returns:
            Multimodal inputs whose prompt, token ids and placeholder ranges
            reflect the Mantis markup. The `mm_hashes` entry of the parent
            result is carried over when present.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        # Replace each run of image tokens produced by the parent processor
        # with the Mantis-wrapped equivalent.
        mantis_mm_repls = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        prompt_ids, prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Locate the placeholders again in the rewritten prompt, using the
        # original (LLaVA) prompt updates.
        unbound_orig_repls = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_updates(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        mm_inputs = MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholder_ranges,
        )

        # `return_mm_hashes` was forwarded to the parent; keep whatever hashes
        # it produced instead of silently dropping them when rebuilding the
        # inputs.
        if "mm_hashes" in result:
            mm_inputs["mm_hashes"] = result["mm_hashes"]

        return mm_inputs
991
992
993
994


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """`LlavaForConditionalGeneration` registered with the Mantis processor
    and processing info; the model itself adds no behavior."""