"vllm/vscode:/vscode.git/clone" did not exist on "4dc24bc8bb9f21422297824f14dc367279684dc5"
llava_next.py 22.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping
6
from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union
7
8
9

import torch
import torch.nn as nn
10
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
11
from transformers.models.llava_next.modeling_llava_next import (
12
13
14
    get_anyres_image_grid_shape,
    unpad_image,
)
15

16
from vllm.config import VllmConfig
17
from vllm.multimodal import MULTIMODAL_REGISTRY
18
from vllm.multimodal.inputs import MultiModalFieldConfig
19
from vllm.multimodal.parse import ImageSize
20
from vllm.sequence import IntermediateTensors
21
from vllm.utils.tensor_schema import TensorSchema, TensorShape
22

23
from .clip import CLIPVisionModel
24
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
25
26
27
28
29
30
31
32
from .llava import (
    BaseLlavaMultiModalProcessor,
    BaseLlavaProcessingInfo,
    LlavaDummyInputsBuilder,
    LlavaLikeConfig,
    LlavaMultiModalProjector,
    init_vision_tower_for_llava,
)
33
from .siglip import SiglipVisionModel
34
35
36
37
38
39
40
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
41
from .vision import get_num_selected_vision_tokens
42
43


44
class LlavaNextImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for LLaVA-NeXT (anyres tiling).

    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches + 1 (index 0 along this axis is the base
          patch; the remaining entries are the anyres grid patches)
        - c: Number of channels (3)
        - h: Height (bound to `vision_config.image_size` when parsed)
        - w: Width (bound to `vision_config.image_size` when parsed)

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}),
    ]

    # Original image size per image, in `(height, width)` format.
    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
65

66

67
class LlavaNextImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image-embedding inputs that bypass the vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
77
78


79
# Image input accepted by the model: raw pixel patches or precomputed embeddings.
LlavaNextImageInputs = Union[
    LlavaNextImagePixelInputs,
    LlavaNextImageEmbeddingInputs,
]
80
81


82
83
class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
    """Structural type for LLaVA-like configs that also support anyres tiling."""

    # Candidate resolutions for anyres tiling; each entry is read as
    # `(height, width)` by the processing info.
    image_grid_pinpoints: Final[list[list[int]]]
84

85

86
87
class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for LLaVA-NeXT: config/processor access and token counts."""

    def get_hf_config(self) -> LlavaNextLikeConfig:
        """Return the model's HF config, loaded as a `LlavaNextConfig`."""
        return self.ctx.get_hf_config(LlavaNextConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `LlavaNextProcessor`, backfilling `patch_size` if unset.

        Args:
            **kwargs: Forwarded to the context's processor factory.
        """
        hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor, **kwargs)

        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        if hf_processor.patch_size is None:
            patch_size = self.get_vision_encoder_info().get_patch_size()
            hf_processor.patch_size = patch_size

        return hf_processor

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many placeholder tokens one image expands to.

        The count is the sum of three parts:
        - the base feature size (a single vision-encoder tile, after the
          configured feature-selection strategy),
        - the unpadded anyres grid features, and
        - one newline feature per row of the unpadded feature map.

        Args:
            image_width: Original image width in pixels.
            image_height: Original image height in pixels.
        """
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        base_feature_size = get_num_selected_vision_tokens(
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
            hf_config.vision_feature_select_strategy,
        )

        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            image_size=(image_height, image_width),
            grid_pinpoints=hf_config.image_grid_pinpoints,
            patch_size=vision_encoder_info.get_image_size(),
        )

        (
            unpadded_feature_size,
            newline_feature_size,
        ) = self._get_num_unpadded_features(
            original_height=image_height,
            original_width=image_width,
            npatches=vision_encoder_info.get_patch_grid_length(),
            num_patch_height=num_patch_height,
            num_patch_width=num_patch_width,
        )

        return unpadded_feature_size + newline_feature_size + base_feature_size

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count the grid features left after removing aspect-ratio padding.

        The full grid is `(npatches * num_patch_height) x
        (npatches * num_patch_width)` feature cells. The padded side is
        trimmed symmetrically so the remaining area matches the original
        aspect ratio.

        Returns:
            `(unpadded_features, newline_features)`, where `newline_features`
            equals the number of remaining rows.
        """
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        if aspect_ratio > current_aspect_ratio:
            # Image is relatively wider than the grid: padding sits on the
            # top/bottom. Rounding to 7 decimals before truncation presumably
            # absorbs floating-point noise.
            new_height = int(
                round(original_height * (current_width / original_width), 7)
            )
            padding = (current_height - new_height) // 2
            current_height = current_height - (2 * padding)
        else:
            # Image is relatively taller (or equal): padding sits on the sides.
            new_width = int(
                round(original_width * (current_height / original_height), 7)
            )
            padding = (current_width - new_width) // 2
            current_width = current_width - (2 * padding)

        unpadded_features = current_height * current_width
        newline_features = current_height

        return (unpadded_features, newline_features)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the grid pinpoint that yields the most image tokens.

        Each entry of `image_grid_pinpoints` is read as `(height, width)`.

        Raises:
            ValueError: if no pinpoint yields a positive token count.
        """
        hf_config = self.get_hf_config()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for height, width in hf_config.image_grid_pinpoints:
            feat_size = self.get_num_image_tokens(
                image_width=width, image_height=height
            )
            # Strict `>` keeps the first pinpoint on ties.
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width, height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint


190
191
192
193
194
195
196
197
198
199
200
201
202
# Processing-info type that a LLaVA-NeXT style processor is parameterized on.
_I = TypeVar("_I", bound=LlavaNextProcessingInfo)


class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
    """Common base for LLaVA-NeXT style multimodal processors.

    Concrete subclasses are expected to describe how each HF processor
    output is batched by overriding `_get_mm_fields_config`.
    """

    # Re-declared abstract here, mirroring BaseMultiModalProcessor.
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Map HF processor output names to their multimodal field configs."""
        raise NotImplementedError

203

204
class LlavaNextMultiModalProcessor(
    BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]
):
    """Multimodal processor for LLaVA-NeXT checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image-related HF output as batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )
217
218


219
220
221
222
223
224
@MULTIMODAL_REGISTRY.register_processor(
    LlavaNextMultiModalProcessor,
    info=LlavaNextProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-NeXT: vision tower + projector feeding a registered language model.

    Images are encoded by the vision tower and projected to the language
    model's hidden size. They are then merged spatially, with an
    `image_newline` embedding appended after each row of the unpadded grid,
    before being handed to the language model as multimodal embeddings.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for item `i` of `modality`.

        Raises:
            ValueError: if `modality` is not an image modality.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        # The projector's input width is the vision hidden size times the
        # number of selected vision layers (one when an int is given).
        vision_feature_layer = config.vision_feature_layer
        if isinstance(vision_feature_layer, int):
            vision_hidden_size = config.vision_config.hidden_size
            self.select_layers = None
        # Used for multimodal granite models to control encoder outputs
        elif isinstance(vision_feature_layer, (list, tuple)):
            vision_hidden_size = config.vision_config.hidden_size * len(
                vision_feature_layer
            )
            self.select_layers = vision_feature_layer
        else:
            raise TypeError(
                f"vision_feature_layer type: {type(vision_feature_layer)}"
                " is not supported"
            )

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        # Separator embedding (loaded from the checkpoint) appended after each
        # row of unpadded patch features in `_merge_image_patch_embeddings`.
        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[LlavaNextImageInputs]:
        """Build a validated image-input schema from multimodal kwargs.

        Reads `pixel_values`, `image_sizes` and `image_embeds`. Pixel values
        take precedence when both pixels and embeddings are supplied, and
        they require `image_sizes`.

        Returns:
            The parsed inputs, or `None` when no image data is present.

        Raises:
            ValueError: if an input has an unexpected container type.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
                )

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
                )

            # Every patch is expected at the vision encoder's native resolution.
            expected_h = expected_w = self.config.vision_config.image_size

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values),
                image_sizes=flatten_bn(image_sizes, concat=True),
                resolve_bindings={
                    "h": expected_h,
                    "w": expected_w,
                },
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError(
                    f"Incorrect type of image embeds. Got type: {type(image_embeds)}"
                )

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower on a stack of patches and select features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        return vision_tower(
            pixel_values,
            select_layers=self.select_layers,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(
        self, image_size: torch.Tensor, patch_embeddings: torch.Tensor, *, strategy: str
    ) -> torch.Tensor:
        """Merge one image's per-patch embeddings into a token sequence.

        Args:
            image_size: Original `(height, width)` of the image.
            patch_embeddings: Projected embeddings; index 0 is the base patch,
                indices 1.. are the anyres grid patches (possibly padded).
            strategy: `"flat"`, or a name starting with `"spatial"`; when it
                contains `"unpad"`, aspect-ratio padding is removed and
                `image_newline` is appended after every row.

        Raises:
            ValueError: on an unknown strategy or a base patch whose token
                count does not match `(image_size // patch_size) ** 2`.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = (
                self.config.vision_config.image_size
                // self.config.vision_config.patch_size
            )

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the image size."
                )

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing.
                # -> (num_patch_height, num_patch_width, height, width, hidden)
                other_patch_embeds = other_patch_embeds[:num_patches].view(
                    num_patch_height, num_patch_width, height, width, -1
                )

                if "unpad" in strategy:
                    # -> (hidden, num_patch_height * height, num_patch_width * width)
                    other_patch_embeds = (
                        other_patch_embeds.permute(4, 0, 2, 1, 3)
                        .contiguous()
                        .flatten(1, 2)
                        .flatten(2, 3)
                    )
                    other_patch_embeds = unpad_image(
                        other_patch_embeds, (orig_height, orig_width)
                    )
                    # Append one newline embedding at the end of every row:
                    # (hidden, H, W) -> (hidden, H, W + 1)
                    other_patch_embeds = torch.cat(
                        (
                            other_patch_embeds,
                            self.image_newline[:, None, None]
                            .expand(*other_patch_embeds.shape[:-1], 1)
                            .to(other_patch_embeds.device),
                        ),
                        dim=-1,
                    )
                    # -> (H * (W + 1), hidden), row-major token order
                    other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose(
                        0, 1
                    )
                else:
                    # -> (num_patch_height * height * num_patch_width * width, hidden)
                    other_patch_embeds = (
                        other_patch_embeds.permute(0, 2, 1, 3, 4)
                        .contiguous()
                        .flatten(0, 3)
                    )

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0
                )
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (
                            base_patch_embeds,
                            self.image_newline[None].to(base_patch_embeds.device),
                        ),
                        dim=0,
                    )
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encode and project all patches of all images.

        Returns:
            A `(b, num_patches, ...)` tensor when pixel values are batched,
            otherwise a tuple with one tensor per image (patch counts differ).
        """
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        if isinstance(pixel_values, torch.Tensor):
            # Fold patches into the batch axis so the tower runs once.
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values
            )
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features
            )

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:]
            )

        # Ragged input: concatenate, encode once, then split back per image.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values
        )

        return torch.split(
            self.multi_modal_projector(stacked_image_features), num_patches_per_batch
        )

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Turn parsed image inputs into per-image embedding sequences."""
        if image_input["type"] == "image_embeds":
            # NOTE(review): all precomputed embeddings are returned as a single
            # list item; confirm this matches the expected per-image count.
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # Fall back to the encoder's native resolution for every image.
            # Pixel inputs have no `data` field, so count the images from the
            # per-image patch embeddings that are iterated below.
            batch_size = len(patch_embeddings)
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor(
                [[default_height, default_width] for _ in range(batch_size)]
            )

        return [
            self._merge_image_patch_embeddings(
                image_sizes[i], patch_features_batch, strategy="spatial_unpad"
            )
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute image embeddings from multimodal kwargs (`[]` if none)."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging in multimodal embeddings when given."""
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by [`LlavaNextProcessingInfo.get_num_image_tokens`][vllm.\
model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaNextImageInputs`][vllm.model_executor.models.llava_next.LlavaNextImageInputs]
        """
        # Non-first pipeline stages consume intermediate tensors, not embeddings.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping post-v4.52 HF parameter names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)