llava.py 29.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from functools import cached_property
5
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
6
                    Protocol, Set, Tuple, TypedDict, TypeVar, Union)
7
8

import torch
9
import torch.nn as nn
10
from packaging.version import Version
11
12
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
13
                          SiglipVisionConfig)
14
from transformers import __version__ as TRANSFORMERS_VERSION
15
from transformers.models.llava import LlavaProcessor
16
from transformers.models.pixtral import PixtralProcessor
17
18

from vllm.attention import AttentionMetadata
19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.model_executor.layers.activation import get_act_fn
22
23
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
25
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
29
                                    MultiModalInputs, MultiModalKwargs,
30
                                    NestedTensors)
31
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
                                   ImageSize, MultiModalDataItems)
33
from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
35
36
                                        BaseProcessingInfo, ProcessingCache,
                                        PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
from vllm.sequence import IntermediateTensors
38

39
from .clip import CLIPVisionModel
40
from .interfaces import SupportsMultiModal, SupportsPP
41
42
43
from .pixtral import (PixtralHFVisionModel,
                      get_pixtral_hf_image_feature_grid_size)
from .siglip import SiglipVisionModel
44
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
45
                    maybe_prefix, merge_multimodal_embeddings)
46
from .vision import get_vision_encoder_info
47
48


49
50
class LlavaImagePixelInputs(TypedDict):
    """Raw image pixels passed to the vision tower."""

    # Discriminator tag for this input variant.
    type: Literal["pixel_values"]
    # Shape: `(batch_size * num_images, num_channels, height, width)`.
    # `height` or `width` may differ per batch and image; in that case the
    # data is passed as a list of tensors instead of one batched tensor.
    data: Union[torch.Tensor, List[torch.Tensor]]


class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that bypass the vision tower."""

    # Discriminator tag for this input variant.
    type: Literal["image_embeds"]
    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
    # `hidden_size` must match the hidden size of the language model backbone.
    data: torch.Tensor


# Image inputs accepted by the LLaVA models: raw pixels or embeddings.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    ``linear_1`` (column-parallel) maps ``vision_hidden_size`` to
    ``text_hidden_size``; the configured activation follows; ``linear_2``
    (row-parallel) maps ``text_hidden_size`` to ``text_hidden_size``.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # The attribute names are part of the weight names (see `prefix`),
        # so they must stay "linear_1" / "act" / "linear_2".
        first_prefix = f"{prefix}.linear_1"
        second_prefix = f"{prefix}.linear_2"

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=first_prefix,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=second_prefix,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` into the language model's space."""
        # The parallel linear layers return a tuple; only its first element
        # (the layer output) is used here.
        projected = self.linear_1(image_features)[0]
        activated = self.act(projected)
        return self.linear_2(activated)[0]
class LlavaLikeConfig(Protocol):
    """Structural type of the HF configs consumed by the LLaVA code here."""

    # Config of the vision encoder (e.g. CLIP, SigLIP or Pixtral).
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # "default" drops the first encoder token; "full" keeps all tokens.
    vision_feature_select_strategy: Final[str]
    # Encoder layer(s) used for image features; negative values count back
    # from the last layer.
    vision_feature_layer: Final[Union[int, list[int]]]
class LlavaLikeProcessor(Protocol):
    """Structural type of the HF processors used by the LLaVA code here."""

    # Prompt-text placeholder that marks where an image is inserted.
    image_token: Final[str]
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by the LLaVA-family models in this file.

    Subclasses supply the HF processor via :meth:`get_hf_processor`; image
    token counts are derived from the HF config and the vision encoder info.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor for this model.

        Args:
            **kwargs: Extra processor arguments. Multimodal processors in this
                file call this with ``**hf_processor_mm_kwargs``, so
                implementations should accept and forward them instead of
                rejecting them.
        """
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None`: no model-specific cap on the number of images (presumably
        # interpreted as "unlimited" by BaseProcessingInfo -- confirm).
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": self.get_max_image_tokens()}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder's token count for the feature-select strategy.

        ``"default"`` drops one token (matching the ``[:, 1:]`` slice that the
        model applies to the features); ``"full"`` keeps all of them.

        Raises:
            NotImplementedError: For any other strategy.
        """
        if strategy == "default":
            return encoder_num_image_tokens - 1
        if strategy == "full":
            return encoder_num_image_tokens

        msg = f"Unexpected feature select strategy: {strategy!r}"
        raise NotImplementedError(msg)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of placeholder tokens an image of the given size needs."""
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Square image at the vision encoder's configured image size."""
        vision_encoder_info = self.get_vision_encoder_info()

        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Token count for the size from `get_image_size_with_most_features`."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds profiling inputs made of image placeholders and dummy images."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return a prompt of ``mm_counts["image"]`` image tokens plus images.

        ``seq_len`` is not used. The dummy images use the size from
        ``get_image_size_with_most_features`` -- the same size that
        ``get_max_image_tokens`` is computed from.
        """
        image_count = mm_counts.get("image", 0)

        placeholder = self.info.get_hf_processor().image_token
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)

        return ProcessorInputs(
            prompt_text=placeholder * image_count,
            mm_data={"image": dummy_images},
        )
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints (HF ``LlavaProcessor``)."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``LlavaProcessor``.

        Args:
            **kwargs: Processor arguments supplied by the caller (e.g. the
                multimodal processor's per-request kwargs). They are forwarded
                instead of causing a TypeError.
        """
        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared multimodal processing for LLaVA-style models.

    Each image placeholder token in the prompt is expanded into as many
    image tokens as the image contributes features.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each HF output field maps onto multimodal items."""
        raise NotImplementedError

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Replace each image token with ``N`` copies of itself.

        ``N`` is the feature length for precomputed embeddings, or the
        size-dependent token count from ``self.info`` for raw images.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                token_count = images.get_feature_size(item_idx)

            return [image_token_id] * token_count

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixel values and precomputed embeddings are both batched over the
        # "image" modality; each field gets its own config object.
        return {
            field_name: MultiModalFieldConfig.batched("image")
            for field_name in ("pixel_values", "image_embeds")
        }
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in the HF (LLaVA) format."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``PixtralProcessor``.

        Args:
            **kwargs: Processor arguments supplied by the caller.
                ``PixtralHFMultiModalProcessor._get_prompt_replacements``
                passes ``**hf_processor_mm_kwargs`` here, which previously
                raised TypeError whenever they were non-empty.
        """
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral checkpoints in the HF format.

    Each image becomes a grid of image tokens: every row ends with a break
    token, and the final break token is replaced by an end token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and remove the extra outer list level from
        ``pixel_values``."""
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        images = mm_data["images"]
        assert isinstance(images, list)

        # Original output: (1, num_images, C, H, W)
        # New output: (num_images, C, H, W)
        assert isinstance(pixel_values, list) and len(pixel_values) == 1
        per_image_pixels = pixel_values[0]
        assert (isinstance(per_image_pixels, list)
                and len(per_image_pixels) == len(images))

        outputs["pixel_values"] = per_image_pixels
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixel values and precomputed embeddings are both batched over the
        # "image" modality; each field gets its own config object.
        return {
            field_name: MultiModalFieldConfig.batched("image")
            for field_name in ("pixel_values", "image_embeds")
        }

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Replace each image token with the image's row-by-row token grid."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)

        def get_replacement(item_idx: int):
            image_size = mm_items.get_items(
                "image", ImageProcessorItems).get_image_size(item_idx)
            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
                vision_config,
                image_width=image_size.width,
                image_height=image_size.height,
            )

            # Each row: `ncols` image tokens followed by one break token.
            tokens: list[int] = []
            for _ in range(nrows):
                tokens.extend([image_token_id] * ncols)
                tokens.append(image_break_id)
            # The final break token becomes the end-of-image token.
            tokens[-1] = image_end_id
            return tokens

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class that matches the vision tower type.

    Pixtral vision configs get ``PixtralHFProcessingInfo``; every other
    config gets ``LlavaProcessingInfo``.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    info_cls = (PixtralHFProcessingInfo
                if isinstance(vision_config, PixtralVisionConfig) else
                LlavaProcessingInfo)
    return info_cls(ctx)
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Instantiate the multimodal processor that matches ``info``'s type.

    Raises:
        NotImplementedError: If ``info`` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    processor_cls: type[BaseMultiModalProcessor]
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
        enable_sanity_checks=enable_sanity_checks,
    )


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed to compute the deepest requested
        feature layer.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple of ints.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        # `max()` of an empty sequence gives an unhelpful error; be explicit.
        if not feature_layers:
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given an signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1
424
425
426
427
428
429
430


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Build the vision encoder described by ``hf_config.vision_config``.

    The encoder is truncated to the deepest layer referenced by
    ``hf_config.vision_feature_layer``.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Checked in order; the first matching config type wins.
    tower_classes = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_classes:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style vision-language model.

    A vision tower encodes images. ``LlavaMultiModalProjector`` maps the
    features to the text hidden size. The resulting embeddings are merged
    into the language model's input embeddings at the image token positions.
    """

    # Fused module name -> names of the projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format,
        # filling in fields its config may leave unset:
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Delegate intermediate-tensor creation to the language model.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        language_model = self.language_model
        if not hasattr(language_model, "sampler"):
            return get_sampler()
        return language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every image in ``data`` is ``(3, size, size)``."""
        image_size = self.config.vision_config.image_size
        expected_dims = (3, image_size, image_size)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", ) + tuple(map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract image inputs from ``kwargs``.

        ``pixel_values`` takes priority over ``image_embeds``. Returns None
        if neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Pixtral pixel values are flattened without concatenation and
            # skip the fixed-size shape check; other towers get both.
            if self.config.vision_config.model_type == "pixtral":
                data = flatten_bn(pixel_values)
            else:
                data = self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True))
            return LlavaImagePixelInputs(type="pixel_values", data=data)

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        return None

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            # Drop the first token along dim 1 (presumably the CLS token).
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        return self._select_image_features(
            vision_tower(pixel_values),
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        vision_tower = self.vision_tower
        assert vision_tower is not None
        return self._image_pixels_to_features(vision_tower, inputs["data"])

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        # Precomputed embeddings already match the language model's hidden
        # size, so they skip both the vision tower and the projector.
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        return self.multi_modal_projector(
            self._process_image_pixels(image_input))

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        return (None if image_input is None else
                self._process_image_input(image_input))

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds
        # Write the image embeddings into the image token positions.
        return merge_multimodal_embeddings(input_ids, inputs_embeds,
                                           multimodal_embeddings,
                                           self.config.image_token_index)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        `input_ids` already reserves one position for every image embedding
        that will be inserted.

        For example, the text prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
        tokenizes to

        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, the input processor expands the
        image token (`32000`) into a run of placeholder tokens before the
        model runs, resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        575 tokens are inserted, so together with the original image token
        there are 576 (24 * 24) image tokens: the number of image tokens the
        visual encoder outputs for the language model.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Positions matching `input_ids`.
            kv_caches: KV cache tensors passed to the language model.
            attn_metadata: Attention metadata passed to the language model.
            intermediate_tensors: If given, `inputs_embeds` is ignored.
            inputs_embeds: Precomputed input embeddings, if any.
            **kwargs: Multimodal inputs such as `pixel_values` (the pixels
                in each input image) or `image_embeds`.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this branch is for v0 compatibility.
            inputs_embeds = self.get_input_embeddings(
                input_ids, self.get_multimodal_embeddings(**kwargs))
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         kv_caches,
                                         attn_metadata,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        language_model = self.language_model
        return language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        return AutoWeightsLoader(self).load_weights(weights)
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints (HF ``LlavaProcessor``)."""

    def get_hf_processor(self, **kwargs: object):
        """Return a ``LlavaProcessor`` configured for Mantis.

        Args:
            **kwargs: Processor arguments supplied by the caller. They take
                precedence over the defaults filled in here, so an explicit
                ``patch_size`` or ``vision_feature_select_strategy`` no longer
                raises "got multiple values for keyword argument".
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        # `kwargs` is a fresh dict per call, so filling defaults is safe.
        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            kwargs.setdefault("vision_feature_select_strategy", None)
        else:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            kwargs.setdefault(
                "vision_feature_select_strategy",
                hf_config.vision_feature_select_strategy,
            )

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processing followed by Mantis-style image placeholder wrapping."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Process inputs as LLaVA does, then wrap each image's placeholder
        run as ``(image N: <Image>...</Image>)`` and re-locate the
        placeholders in the rewritten prompt.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        tokens_per_image = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = base_inputs["mm_kwargs"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            header = f"(image {item_idx+1}: <Image>"  # 7 tokens
            footer = "</Image>)"  # 3 tokens
            return header + "<image>" * tokens_per_image + footer

        mantis_repls = self._bind_and_group_repls([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * tokens_per_image,
                replacement=get_replacement_mantis,
            )
        ])
        new_prompt_ids, new_prompt, _ = self._apply_prompt_replacements(
            base_inputs["prompt_token_ids"],
            mantis_repls,
            mm_item_counts,
        )

        # The wrapper still contains `tokens_per_image` copies of `<image>`,
        # so the regular LLaVA replacements are used to find placeholders.
        orig_repls = self._bind_and_group_repls(
            self._get_prompt_replacements(
                mm_items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))
        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            new_prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        placeholder_ranges = {
            modality: [placeholder.to_range() for placeholder in found]
            for modality, found in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=new_prompt,
            prompt_token_ids=new_prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=placeholder_ranges,
        )


# To use this model, pass
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis multimodal processor.

    The model itself is unchanged from ``LlavaForConditionalGeneration``;
    only the registered processor and processing info differ.
    """