llava.py 29.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from functools import cached_property
5
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
6
                    Protocol, Set, Tuple, TypedDict, TypeVar, Union)
7
8

import torch
9
import torch.nn as nn
10
from packaging.version import Version
11
12
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
13
                          SiglipVisionConfig)
14
from transformers import __version__ as TRANSFORMERS_VERSION
15
from transformers.models.llava import LlavaProcessor
16
from transformers.models.pixtral import PixtralProcessor
17
18

from vllm.attention import AttentionMetadata
19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.model_executor.layers.activation import get_act_fn
22
23
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
25
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
29
                                    MultiModalInputs, MultiModalKwargs,
30
                                    NestedTensors)
31
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
                                   ImageSize, MultiModalDataItems)
33
from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
35
36
                                        BaseProcessingInfo, ProcessingCache,
                                        PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
from vllm.sequence import IntermediateTensors
38

39
from .clip import CLIPVisionModel
40
from .interfaces import SupportsMultiModal, SupportsPP
41
42
43
from .pixtral import (PixtralHFVisionModel,
                      get_pixtral_hf_image_feature_grid_size)
from .siglip import SiglipVisionModel
44
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
45
                    maybe_prefix, merge_multimodal_embeddings)
46
from .vision import get_vision_encoder_info
47
48


49
50
class LlavaImagePixelInputs(TypedDict):
    """Image input variant carrying raw pixels for the vision tower."""

    # Discriminator tag identifying this variant of `LlavaImageInputs`.
    type: Literal["pixel_values"]

    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
58
59
60
61
62


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input variant carrying precomputed image embeddings."""

    # Discriminator tag identifying this variant of `LlavaImageInputs`.
    type: Literal["image_embeds"]

    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# An image input is either raw pixel values or precomputed image embeddings;
# the `type` key tells the two variants apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


72
73
class LlavaMultiModalProjector(nn.Module):
    """Two-layer projector from the vision hidden size to the text hidden size.

    Features pass through ``linear_1``, the configured activation and
    ``linear_2``, in that order.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the projector layers.

        Args:
            vision_hidden_size: Input feature size of ``linear_1``.
            text_hidden_size: Output size of ``linear_1`` and both the input
                and output size of ``linear_2``.
            projector_hidden_act: Activation name resolved via `get_act_fn`.
            multimodal_projector_bias: Whether both linear layers use a bias.
            quant_config: Optional quantization config passed to both layers.
            prefix: Name prefix; the layers are registered as
                ``{prefix}.linear_1`` and ``{prefix}.linear_2``.
        """
        super().__init__()

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=multimodal_projector_bias,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=multimodal_projector_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` and return the resulting hidden states."""
        # Both linear layers return a tuple; only the first element is used.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


102
103
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes read in this file."""

    # Config of the vision encoder.
    vision_config: Final[PretrainedConfig]
    # Token ID of the image placeholder in the prompt.
    image_token_index: Final[int]
    # Feature selection strategy; this file handles "default" and "full".
    vision_feature_select_strategy: Final[str]
    # Signed index (or indices) of the vision encoder layer(s) to use.
    vision_feature_layer: Final[Union[int, list[int]]]
107

108

109
110
111
112
class LlavaLikeProcessor(Protocol):
    """Structural type for an HF processor that exposes an image token."""
    # Text form of the image placeholder token.
    image_token: Final[str]


113
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing info for LLaVA-style models.

    Subclasses supply the concrete HF processor via `get_hf_processor`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config, requested from the context as `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the vision encoder info derived from the HF config."""
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self) -> LlavaLikeProcessor:
        """Return the HF processor; must be implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limits (`None` for "image")."""
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Return the maximum token count per image; `seq_len` is unused."""
        return {"image": self.get_max_image_tokens()}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder token count for the feature select strategy.

        Args:
            strategy: Either "default" or "full".
            encoder_num_image_tokens: Token count reported by the encoder.

        Returns:
            One fewer token for "default", the unchanged count for "full".

        Raises:
            NotImplementedError: If `strategy` is neither of the above.
        """
        if strategy == "default":
            return encoder_num_image_tokens - 1
        if strategy == "full":
            return encoder_num_image_tokens

        msg = f"Unexpected feature select strategy: {strategy!r}"
        raise NotImplementedError(msg)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size."""
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` using the encoder's image size."""
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Return the token count for the size with the most features."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

174
175
176
177
178
179

# Type variable for the processing-info class that parameterizes the dummy
# inputs builder and the multimodal processors below.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy processor inputs made of image tokens and dummy images."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create dummy inputs with `mm_counts["image"]` images.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Item count per modality; a missing "image" key
                counts as zero.

        Returns:
            Inputs whose prompt repeats the processor's image token once per
            image, with dummy images at the size reported by
            `get_image_size_with_most_features`.
        """
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=mm_data,
        )


205
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info that uses the HF `LlavaProcessor`."""

    def get_hf_processor(self):
        """Return the `LlavaProcessor` obtained from the context."""
        return self.ctx.get_hf_processor(LlavaProcessor)


211
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Base multimodal processor that expands each image token in the prompt.

    Subclasses must provide `_get_mm_fields_config`.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Return the replacement that expands a single image token.

        Each image token is replaced by as many image tokens as the
        corresponding item needs. `hf_processor_mm_kwargs` and
        `out_mm_kwargs` are not used here.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                # Embedding inputs report their own feature size.
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                # For raw images the count is derived from the image size.
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


255
256
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for LLaVA models using `LlavaProcessingInfo`."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `image_embeds` as batched image fields.

        Neither argument is consulted; the mapping is fixed.
        """
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )


269
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info that uses the HF `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `PixtralProcessor` obtained from the context.

        Keyword arguments are forwarded to the context, because
        `PixtralHFMultiModalProcessor._get_prompt_replacements` calls this
        method with the processor kwargs of the request. Calling it without
        arguments behaves as before.
        """
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)

274

275
276
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for HF-format Pixtral models."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and strip the outer list of `pixel_values`.

        When the parent output contains `pixel_values`, it is expected to be
        a one-element list wrapping a list with one entry per input image;
        the inner list replaces it in the returned output.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
            assert (isinstance(pixel_values[0], list)
                    and len(pixel_values[0]) == len(images))

            processed_outputs["pixel_values"] = pixel_values[0]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `image_embeds` as batched image fields.

        Neither argument is consulted; the mapping is fixed.
        """
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Return the replacement that expands one image token into a grid.

        Every image token becomes `nrows` rows of `ncols` image tokens, each
        row followed by the image-break token, except that the very last
        token is the image-end token. `out_mm_kwargs` is not used here.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
                vision_config,
                image_width=image_size.width,
                image_height=image_size.height,
            )

            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            # The final row ends the image instead of starting a new row.
            tokens[-1] = image_end_id

            return tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

356

357
358
359
360
361
362
363
364
365
366
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Create the processing info that matches the model's vision config.

    Returns a `PixtralHFProcessingInfo` when the vision config is a
    `PixtralVisionConfig`, and a `LlavaProcessingInfo` otherwise.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    uses_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if uses_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


367
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor that matches the type of `info`.

    Args:
        info: Processing info selecting the processor class.
        dummy_inputs: Dummy inputs builder handed to the processor.
        cache: Optional processing cache handed to the processor.
        enable_sanity_checks: Forwarded to the processor unchanged.

    Returns:
        A `PixtralHFMultiModalProcessor` for `PixtralHFProcessingInfo`, or a
        `LlavaMultiModalProcessor` for `LlavaProcessingInfo`.

    Raises:
        NotImplementedError: If `info` is of neither type.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        return PixtralHFMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
            enable_sanity_checks=enable_sanity_checks,
        )

    if isinstance(info, LlavaProcessingInfo):
        return LlavaMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
            enable_sanity_checks=enable_sanity_checks,
        )

    raise NotImplementedError(type(info))
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer count needed for the single feature layer, or for the
        deepest one when several feature layers are configured.

    Raises:
        TypeError: If `vision_feature_layer` is not an int, list or tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    # The message names the actual config attribute so users can find it.
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given an signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1
425
426
427
428
429
430
431


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Create the vision tower that matches `hf_config.vision_config`.

    Args:
        hf_config: Model config providing the vision config and the vision
            feature layer(s).
        quant_config: Quantization config passed to the vision model.
        require_post_norm: Passed to the vision model unchanged.
        prefix: Name prefix passed to the vision model.

    Returns:
        A `CLIPVisionModel`, `SiglipVisionModel` or `PixtralHFVisionModel`,
        depending on the type of the vision config.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


468
469
470
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model: a vision tower, a projector and a language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: Source of the HF config, the quantization config
                and the multimodal config.
            prefix: Name prefix applied to the submodules via `maybe_prefix`.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Return the language model's sampler, or a new one if it has none."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that `data` has trailing dims `(3, image_size, image_size)`.

        Returns:
            `data` unchanged when the shape matches.

        Raises:
            ValueError: If the dimensions after the first one do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Build the image input from `pixel_values` or `image_embeds`.

        `pixel_values` takes precedence when both keys are present.

        Returns:
            `None` when neither key is given, otherwise the matching
            `LlavaImageInputs` variant.

        Raises:
            ValueError: If the given value is neither a tensor nor a list,
                or if non-Pixtral pixel values have an unexpected shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Pixtral inputs are neither concatenated nor shape-validated.
            if self.config.vision_config.model_type == "pixtral":
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=flatten_bn(pixel_values),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy to `image_features`.

        "default" drops index 0 along the second dimension; "full" keeps
        everything.

        Raises:
            ValueError: If `strategy` is neither "default" nor "full".
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the configured select strategy."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Turn pixel inputs into image features using the vision tower."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Return the embeddings for `image_input`.

        Precomputed embeddings are returned as-is; pixel inputs go through
        the vision tower and the projector.
        """

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return the image embeddings for `kwargs`, or `None` without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge in the multimodal embeddings, if any.

        The merge targets the positions of `config.image_token_index`.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model.
            kv_caches: Passed through to the language model.
            attn_metadata: Passed through to the language model.
            intermediate_tensors: When given, `inputs_embeds` is discarded
                and no embeddings are computed here.
            inputs_embeds: Precomputed input embeddings; when omitted (and
                no intermediate tensors are given) they are computed here.
            **kwargs: Multimodal inputs. `pixel_values` (the pixels in each
                input image) or `image_embeds` are read from here.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `weights` with an `AutoWeightsLoader` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
724
725


726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, which configures `LlavaProcessor` itself."""

    def get_hf_processor(self):
        """Return a `LlavaProcessor` with patch size and select strategy set.

        The select strategy passed to the processor depends on the installed
        Transformers version (see the comments below).
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        has_token_count_bug = Version(TRANSFORMERS_VERSION) < Version("4.48")
        if not has_token_count_bug:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            select_strategy = hf_config.vision_feature_select_strategy
        else:
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            select_strategy = None

        processor_kwargs = dict(
            patch_size=vision_info.get_patch_size(),
            vision_feature_select_strategy=select_strategy,
        )
        return self.ctx.get_hf_processor(LlavaProcessor, **processor_kwargs)


747
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that additionally wraps images in Mantis markup."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Process the prompt, then rewrite each image span in Mantis format.

        The parent processor runs first. Every run of `num_image_tokens`
        image tokens in its output is then replaced by
        ``(image N: <Image>`` + image tokens + ``</Image>)``, and the
        placeholder ranges are recomputed on the rewritten token IDs.

        Returns:
            Multimodal inputs with the rewritten prompt and token IDs, the
            parent's `mm_kwargs` and the recomputed placeholder ranges.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        mantis_mm_repls = self._bind_and_group_repls([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        prompt_ids, prompt, _ = self._apply_prompt_replacements(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Locate the image placeholders again, since the rewrite above moved
        # them relative to the parent's output.
        unbound_orig_repls = self._get_prompt_replacements(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_repls(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )

        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholder_ranges,
        )
820
821
822
823


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis processor and processing info."""
    pass