"vllm/model_executor/models/llava_onevision.py" did not exist on "f4fc7337bfaf5f10b8da4ba547e4009179348a26"
llava_next.py 22.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping
6
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
7
                    Union)
8
9
10

import torch
import torch.nn as nn
11
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
12
13
14
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)

15
from vllm.config import VllmConfig
16
from vllm.multimodal import MULTIMODAL_REGISTRY
17
from vllm.multimodal.inputs import MultiModalFieldConfig
18
from vllm.multimodal.parse import ImageSize
19
from vllm.sequence import IntermediateTensors
20
from vllm.utils.tensor_schema import TensorSchema, TensorShape
21

22
from .clip import CLIPVisionModel
23
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
24
25
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
                    LlavaDummyInputsBuilder, LlavaLikeConfig,
26
                    LlavaMultiModalProjector, init_vision_tower_for_llava)
27
from .siglip import SiglipVisionModel
28
29
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix)
30
31


32
class LlavaNextImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs for LLaVA-NeXT (anyres) images.

    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches + 1
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    type: Literal["pixel_values"] = "pixel_values"
    # `np` is dynamic: when it differs between images, this is a list of
    # per-image `(np, 3, h, w)` tensors instead of one batched tensor.
    # NOTE(review): the "+ 1" in `np` presumably is the base view that the
    # model later reads as `patch_embeddings[0]` — confirm.
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})]

    # This should be in `(height, width)` format.
    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
51

52

53
class LlavaNextImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings for LLaVA-NeXT; these skip the vision
    tower and projector entirely.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """
    type: Literal["image_embeds"] = "image_embeds"
    # One `(ifs, hs)` embedding sequence per image.
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
62
63
64
65


# Either kind of image input LLaVA-NeXT accepts; the `type` field tells the
# two apart at runtime.
LlavaNextImageInputs = Union[
    LlavaNextImagePixelInputs,
    LlavaNextImageEmbeddingInputs,
]
66
67


68
69
class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
    """Structural type for HF configs usable by LLaVA-NeXT-style models.

    Extends `LlavaLikeConfig` with the anyres grid pinpoints.
    """
    # Candidate anyres resolutions; consumers in this module iterate them as
    # `(height, width)` pairs.
    image_grid_pinpoints: Final[list[list[int]]]
70

71

72
class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
    """Processing metadata for LLaVA-NeXT (anyres) models.

    Adds anyres-aware token counting: a base view of the image plus a grid of
    tiles, cropped back to the original aspect ratio, plus one newline
    feature per remaining row.
    """

    def get_hf_config(self) -> LlavaNextLikeConfig:
        """Return the model's HF config, loaded as a `LlavaNextConfig`."""
        return self.ctx.get_hf_config(LlavaNextConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `LlavaNextProcessor`, filling in a missing patch size.

        Some checkpoints omit `patch_size` from `processor_config.json`
        (e.g. E5-V: https://huggingface.co/royokong/e5-v). In that case the
        vision encoder's patch size is used instead.
        """
        processor = self.ctx.get_hf_processor(LlavaNextProcessor, **kwargs)

        if processor.patch_size is None:
            fallback = self.get_vision_encoder_info().get_patch_size()
            processor.patch_size = fallback

        return processor

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many placeholder tokens one image of this size needs.

        The count is the base-view feature count (after the configured
        feature-select strategy), plus the cropped anyres grid features, plus
        one newline feature per row of that cropped grid.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        tokens_per_view = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        base_tokens = self._apply_feature_select_strategy(
            config.vision_feature_select_strategy, tokens_per_view)

        grid_rows, grid_cols = get_anyres_image_grid_shape(
            image_size=(image_height, image_width),
            grid_pinpoints=config.image_grid_pinpoints,
            patch_size=encoder_info.get_image_size(),
        )

        grid_tokens, newline_tokens = self._get_num_unpadded_features(
            original_height=image_height,
            original_width=image_width,
            npatches=encoder_info.get_patch_grid_length(),
            num_patch_height=grid_rows,
            num_patch_width=grid_cols,
        )

        return base_tokens + grid_tokens + newline_tokens

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count the grid features left after cropping away padding.

        The tiled grid spans `npatches * num_patch_*` features per side. The
        side that was padded to fit the original aspect ratio is shrunk by
        twice the (floored) per-edge padding.

        Returns:
            `(unpadded_features, newline_features)`, where the second value
            is the number of rows remaining after cropping.
        """
        grid_height = npatches * num_patch_height
        grid_width = npatches * num_patch_width

        if original_width / original_height > grid_width / grid_height:
            # Image is wider than the grid: rows were padded top and bottom.
            scaled_height = int(
                round(original_height * (grid_width / original_width), 7))
            grid_height -= 2 * ((grid_height - scaled_height) // 2)
        else:
            # Image is taller (or equal): columns were padded left and right.
            scaled_width = int(
                round(original_width * (grid_height / original_height), 7))
            grid_width -= 2 * ((grid_width - scaled_width) // 2)

        return grid_height * grid_width, grid_height

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the grid pinpoint that yields the most image tokens.

        Ties keep the earliest pinpoint.

        Raises:
            ValueError: If no pinpoint produces a positive token count.
        """
        best_count, best_size = 0, None
        for height, width in self.get_hf_config().image_grid_pinpoints:
            count = self.get_num_image_tokens(image_width=width,
                                              image_height=height)
            if count > best_count:
                best_count = count
                best_size = ImageSize(width=width, height=height)

        if best_count == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size


175
176
177
178
179
180
181
182
183
184
185
186
187
188
# Type parameter for multimodal processors whose processing-info class is
# LlavaNextProcessingInfo or a subclass of it.
_I = TypeVar("_I", bound=LlavaNextProcessingInfo)


class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
    """Base multimodal processor for LLaVA-NeXT-style models.

    Concrete subclasses must implement `_get_mm_fields_config`.
    """

    # Copied from BaseMultiModalProcessor
    # NOTE(review): presumably re-declared here so that it stays abstract
    # even if an intermediate base class provides an implementation — confirm.
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Map each HF processor output key to its multimodal field config.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

189

190
191
class LlavaNextMultiModalProcessor(
        BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
    """Multimodal processor for LLaVA-NeXT checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image-related HF output as batched per image."""
        per_image_keys = ("pixel_values", "image_sizes", "image_embeds")
        # A fresh config object is built for each key.
        return {
            key: MultiModalFieldConfig.batched("image")
            for key in per_image_keys
        }
203
204


205
206
207
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
                                        info=LlavaNextProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """LLaVA-NeXT: a vision tower and projector feeding a language model.

    Each image is encoded as a base view plus anyres grid tiles. The tiles
    are projected into the language model's hidden space and merged
    spatially; when unpadding, a learned `image_newline` embedding is
    appended after every row.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for `modality` (images only).

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector, image-newline parameter and LM.

        Raises:
            TypeError: If `config.vision_feature_layer` is neither an int nor
                a list/tuple.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        vision_feature_layer = config.vision_feature_layer
        # Work out the projector's input width and which encoder layers the
        # vision tower should return.
        if isinstance(vision_feature_layer, int):
            vision_hidden_size = config.vision_config.hidden_size
            self.select_layers = None
        # Used for multimodal granite models to control encoder outputs
        elif isinstance(vision_feature_layer, (list, tuple)):
            # One hidden_size-wide slice per selected layer (presumably the
            # vision tower concatenates them along the feature dim).
            vision_hidden_size = config.vision_config.hidden_size * len(
                vision_feature_layer)
            self.select_layers = vision_feature_layer
        else:
            raise TypeError(
                f"vision_feature_layer type: {type(vision_feature_layer)}"
                " is not supported")

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        # Learned row separator, sized to the language model's hidden size.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Pop image kwargs and wrap them in the matching input schema.

        `pixel_values` takes precedence over `image_embeds`. Returns None
        when neither is given.

        Raises:
            ValueError: If the popped values have unexpected types.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            # Every patch is fed to the vision tower at its native square
            # resolution.
            expected_h = expected_w = self.config.vision_config.image_size

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values),
                image_sizes=flatten_bn(image_sizes, concat=True),
                resolve_bindings={
                    "h": expected_h,
                    "w": expected_w,
                })

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower on a stack of patch images."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        return vision_tower(
            pixel_values,
            select_layers=self.select_layers,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge one image's projected patch embeddings into one sequence.

        Args:
            image_size: Original image size as a 2-element tensor in
                `(height, width)` order.
            patch_embeddings: Projected embeddings for this image. Index 0 is
                the base view; the remaining entries are anyres tiles, each
                holding `height * width` feature positions.
            strategy: "flat" concatenates all patches unchanged. Strategies
                starting with "spatial" lay the tiles out on their 2-D grid.
                If the strategy also contains "unpad", padding is cropped
                away and `image_newline` is appended to every row.

        Returns:
            The merged embeddings, presumably `(num_tokens, hidden)`.

        Raises:
            ValueError: If the strategy is unknown, or if the base patch does
                not have `height * width` positions.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            # Feature positions per side of one vision-tower patch.
            height = width = (self.config.vision_config.image_size //
                              self.config.vision_config.patch_size)

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                # -> (grid_rows, grid_cols, height, width, d)
                other_patch_embeds = other_patch_embeds[:num_patches].view(
                    num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # -> (d, grid_rows * height, grid_cols * width)
                    other_patch_embeds = (other_patch_embeds.permute(
                        4, 0, 2, 1, 3).contiguous().flatten(1, 2).flatten(
                            2, 3))
                    # Crop the padding introduced when fitting the original
                    # aspect ratio onto the tile grid.
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # Append one newline embedding at the end of each row.
                    newline_column = self.image_newline[:, None, None].expand(
                        *other_patch_embeds.shape[:-1],
                        1).to(other_patch_embeds.device)
                    other_patch_embeds = torch.cat(
                        (other_patch_embeds, newline_column), dim=-1)
                    # -> (rows * (cols + 1), d), row-major.
                    other_patch_embeds = other_patch_embeds.flatten(
                        1, 2).transpose(0, 1)
                else:
                    # Row-major over the full (padded) grid -> (N, d).
                    other_patch_embeds = (other_patch_embeds.permute(
                        0, 2, 1, 3, 4).contiguous().flatten(0, 3))

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None].to(base_patch_embeds.device)),
                        dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encode and project every patch of every image.

        Returns:
            A `(b, num_patches, ...)` tensor when `pixel_values` is batched;
            otherwise a tuple with one tensor per image, split by that
            image's patch count.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Ragged input: run all patches through in one pass, then split them
        # back per image.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return torch.split(self.multi_modal_projector(stacked_image_features),
                           num_patches_per_batch)

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Turn parsed image inputs into one embedding tensor per image."""
        if image_input["type"] == "image_embeds":
            # `data` is `(bn, ifs, hs)`; split it so that each image gets its
            # own `(ifs, hs)` entry, as on the pixel path below. The previous
            # `[data]` packed every image into a single entry.
            return list(image_input["data"].unbind(0))

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # Fall back to the vision tower's native resolution. The pixel
            # schema has no "data" key, so count the images from the computed
            # per-image patch embeddings instead.
            batch_size = len(patch_embeddings)
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Compute image embeddings from kwargs; `[]` if there are no images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging multimodal embeddings when provided."""
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by [`LlavaNextProcessingInfo.get_num_image_tokens`][vllm.\
model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].

        This way, the `positions` and the attention metadata are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
            kwargs: Ignored by this method.

        Info:
            [`LlavaNextImageInputs`][vllm.model_executor.models.llava_next.LlavaNextImageInputs]
        """
        # Intermediate tensors (e.g. from a previous pipeline stage) take
        # precedence over any provided input embeddings.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping post-v4.52 HF parameter names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)