llava_next.py 23.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
5
6
from collections.abc import Iterable, Mapping
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
                    Union)
7
8
9

import torch
import torch.nn as nn
10
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
11
12
13
14
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

15
from vllm.config import VllmConfig
16
from vllm.model_executor.sampling_metadata import SamplingMetadata
17
from vllm.multimodal import MULTIMODAL_REGISTRY
18
from vllm.multimodal.inputs import MultiModalFieldConfig
19
from vllm.multimodal.parse import ImageSize
20
from vllm.sequence import IntermediateTensors
21

22
from .clip import CLIPVisionModel
23
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
24
25
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
                    LlavaDummyInputsBuilder, LlavaLikeConfig,
26
                    LlavaMultiModalProjector, init_vision_tower_for_llava)
27
from .siglip import SiglipVisionModel
Cyrus Leung's avatar
Cyrus Leung committed
28
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
29
                    init_vllm_registered_model, maybe_prefix)
30
31
32
33


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
34
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
35
    """
36
37
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
38

39
40
    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
41
    """
42
43

    image_sizes: NotRequired[torch.Tensor]
44
    """
45
    Shape: `(batch_size * num_images, 2)`
46
47
48

    This should be in `(height, width)` format.
    """
49
50


51
52
53
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, used as-is instead of pixel inputs."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    The trailing `hidden_size` dimension has to equal the hidden size of the
    language model backbone.
    """


# Image inputs arrive either as raw pixel patches or as precomputed embeddings.
LlavaNextImageInputs = Union[
    LlavaNextImagePixelInputs,
    LlavaNextImageEmbeddingInputs,
]
62
63


64
65
class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
    """Structural type for configs accepted by LLaVA-NeXT-style models.

    Extends the LLaVA config protocol with the anyres grid pinpoints.
    """

    # Grid pinpoints handed to `get_anyres_image_grid_shape`; iterated in this
    # module as `(height, width)` pairs.
    image_grid_pinpoints: Final[list[list[int]]]
66

67

68
class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
    """Processing metadata for LLaVA-NeXT (anyres) models.

    Adds the image-token accounting for the anyres scheme, in which each
    image is laid out on a grid chosen from `image_grid_pinpoints`.
    """

    def get_hf_config(self) -> LlavaNextLikeConfig:
        """Return the HF config loaded as a `LlavaNextConfig`."""
        return self.ctx.get_hf_config(LlavaNextConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor, filling in `patch_size` when absent.

        Some checkpoints omit `patch_size` from `processor_config.json`
        (e.g. for E5-V: https://huggingface.co/royokong/e5-v); the value is
        then taken from the vision encoder.
        """
        processor = self.ctx.get_hf_processor(LlavaNextProcessor, **kwargs)

        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Count the image placeholder tokens for an image of the given size.

        The count is the base feature size (after the feature-select
        strategy), plus the unpadded grid features, plus one newline feature
        per row of the unpadded grid.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        base_size = self._apply_feature_select_strategy(
            config.vision_feature_select_strategy,
            encoder_tokens,
        )

        grid_rows, grid_cols = get_anyres_image_grid_shape(
            image_size=(image_height, image_width),
            grid_pinpoints=config.image_grid_pinpoints,
            patch_size=encoder_info.get_image_size(),
        )

        unpadded_size, newline_size = self._get_num_unpadded_features(
            original_height=image_height,
            original_width=image_width,
            npatches=encoder_info.get_patch_grid_length(),
            num_patch_height=grid_rows,
            num_patch_width=grid_cols,
        )

        return base_size + unpadded_size + newline_size

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Return `(unpadded_features, newline_features)` for the grid.

        The full feature map is `npatches * num_patch_height` rows by
        `npatches * num_patch_width` columns. The padding added to preserve
        the original aspect ratio is removed symmetrically from whichever
        dimension was padded, and the remaining rows x columns are counted.
        """
        rows = npatches * num_patch_height
        cols = npatches * num_patch_width

        if original_width / original_height > cols / rows:
            # The image is relatively wider than the grid: rows were padded.
            scaled_rows = int(
                round(original_height * (cols / original_width), 7))
            rows -= 2 * ((rows - scaled_rows) // 2)
        else:
            # Otherwise the columns were padded.
            scaled_cols = int(
                round(original_width * (rows / original_height), 7))
            cols -= 2 * ((cols - scaled_cols) // 2)

        return rows * cols, rows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the grid pinpoint that produces the most image tokens.

        Raises:
            ValueError: If no pinpoint yields a positive token count.
        """
        best_size, best_pinpoint = 0, None

        for height, width in self.get_hf_config().image_grid_pinpoints:
            num_tokens = self.get_num_image_tokens(image_width=width,
                                                   image_height=height)
            if num_tokens <= best_size:
                continue
            best_size = num_tokens
            best_pinpoint = ImageSize(width=width, height=height)

        if best_size == 0 or best_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_pinpoint


171
172
173
174
175
176
177
178
179
180
181
182
183
184
# Lets subclasses specialize the processor on their own processing-info type.
_I = TypeVar("_I", bound=LlavaNextProcessingInfo)


class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
    """Shared base for LLaVA-NeXT-style multimodal processors.

    Concrete subclasses have to say how each HF processor output is batched.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Map each HF processor output key to its multimodal field config."""
        raise NotImplementedError

185

186
187
class LlavaNextMultiModalProcessor(
        BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
    """Multimodal processor for LLaVA-NeXT checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Every image-related output is batched per image.
        image_fields = ("pixel_values", "image_sizes", "image_embeds")
        return {
            field: MultiModalFieldConfig.batched("image")
            for field in image_fields
        }
199
200


201
202
203
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
                                        info=LlavaNextProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """LLaVA-NeXT model: vision tower + projector + language model.

    Pixel inputs are encoded by the vision tower and mapped to the language
    model's hidden size by `multi_modal_projector`. They are then merged into
    one embedding sequence per image (see `_merge_image_patch_embeddings`)
    and combined with the text embeddings via `embed_multimodal` using
    `config.image_token_index`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        vision_feature_layer = config.vision_feature_layer
        # Determine the layer up to which we will initialize the vision tower
        if isinstance(vision_feature_layer, int):
            vision_hidden_size = config.vision_config.hidden_size
            self.feature_sample_layers = None
        # Used for multimodal granite models to control encoder outputs.
        # The projector input width scales with the number of sampled layers
        # (presumably their features are concatenated along the hidden dim).
        elif isinstance(vision_feature_layer, (list, tuple)):
            vision_hidden_size = config.vision_config.hidden_size * len(
                vision_feature_layer)
            self.feature_sample_layers = vision_feature_layer
        else:
            raise TypeError(
                f"vision_feature_layer type: {type(vision_feature_layer)}"
                " is not supported")

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        # Separator embedding appended after each row of unpadded grid
        # features in `_merge_image_patch_embeddings`.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every per-image size entry has shape `(2,)`.

        Raises:
            ValueError: If any entry has a different shape.
        """
        expected_dims = (2, )

        for d in data:
            actual_dims = tuple(d.shape)
            if actual_dims != expected_dims:
                raise ValueError(
                    "The expected shape of image sizes per image per batch "
                    f"is {expected_dims}. You supplied {actual_dims}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, list[torch.Tensor]]
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Check each image's patches have shape `(num_patches, 3, H, W)`.

        `H` and `W` both equal `config.vision_config.image_size`.

        Raises:
            ValueError: If any image's trailing dimensions differ.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        for d in data:
            actual_dims = tuple(d.shape[1:])
            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Extract image inputs from `kwargs` (consuming the keys).

        Returns:
            Pixel inputs if `pixel_values` is given, else embedding inputs if
            `image_embeds` is given, else `None`.

        Raises:
            ValueError: If an input has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(
                    flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the vision feature-select strategy.

        `"default"` drops the first token along dim 1; `"full"` keeps
        everything.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode stacked patches with the vision tower and select features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(
            pixel_values, feature_sample_layers=self.feature_sample_layers)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge one image's per-patch embeddings into a single sequence.

        `patch_embeddings[0]` holds the base patch features and the
        remaining entries hold the anyres grid patches.

        Args:
            image_size: Original `(height, width)` of the image.
            patch_embeddings: Projected features, one entry per patch.
            strategy: `"flat"` concatenates all patches as-is. Strategies
                starting with `"spatial"` reassemble the grid spatially.
                Those containing `"unpad"` also remove padding and append
                `image_newline` after each feature row.

        Raises:
            ValueError: If the base patch does not match the vision config's
                patch grid, or if `strategy` is not recognized.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing.
                # -> (grid_h, grid_w, height, width, dim)
                other_patch_embeds = other_patch_embeds[:num_patches] \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # -> (dim, grid_h * height, grid_w * width)
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # Append one newline embedding at the end of every row.
                    newline = self.image_newline[:, None, None].expand(
                        *other_patch_embeds.shape[:-1], 1)
                    other_patch_embeds = torch.cat(
                        (other_patch_embeds,
                         newline.to(other_patch_embeds.device)),
                        dim=-1)
                    # (dim, rows, cols + 1) -> (rows * (cols + 1), dim)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    # -> (grid_h * height * grid_w * width, dim)
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None].to(base_patch_embeds.device)),
                        dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encode and project the pixel patches of every image.

        Returns:
            For a batched tensor input, a tensor shaped
            `(b, num_patches, ...)`. For a list input, a tuple with one
            tensor per image (split by that image's patch count).
        """
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Variable patch counts: run all patches through in one pass, then
        # split the result back per image.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)

        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return torch.split(self.multi_modal_projector(stacked_image_features),
                           num_patches_per_batch)

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Turn parsed image inputs into per-image embedding sequences."""
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # No original sizes supplied: assume each image is at the vision
            # tower's native resolution. Pixel inputs store their data under
            # "pixel_values" (there is no "data" key for this input type).
            batch_size = len(image_input["pixel_values"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute image embeddings from `kwargs`, or `None` if no images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging in image embeddings when provided."""
        if multimodal_embeddings is None:
            return self.language_model.get_input_embeddings(input_ids)

        inputs_embeds = embed_multimodal(
            input_ids,
            self.config.image_token_index,
            self.language_model.model.get_input_embeddings,
            multimodal_embeddings,
        )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor expands
        the single image token (denoted as `32000`) into repeated image
        tokens, resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by `LlavaNextProcessingInfo.get_num_image_tokens`.

        This way, the `positions` and the attention metadata are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Positions of the tokens in `input_ids`.
            intermediate_tensors: Intermediate tensors for pipeline
                parallelism; when given, `inputs_embeds` is ignored.
            inputs_embeds: Precomputed input embeddings. When omitted (and
                `intermediate_tensors` is `None`), they are computed here
                from `input_ids` and any image inputs in `kwargs`.
            **kwargs: Image inputs: `pixel_values` (the pixels in each grid
                patch for each input image) with `image_sizes` (the original
                `(height, width)` for each input image), or `image_embeds`.

        Info:
            [LlavaNextImageInputs][]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; returns the names of loaded parameters."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)