llava_next.py 23.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from functools import cached_property
5
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
6
                    Protocol, Set, Tuple, TypedDict, TypeVar, Union)
7
8
9

import torch
import torch.nn as nn
10
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
11
12
13
14
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

15
from vllm.config import VllmConfig
Joe Runde's avatar
Joe Runde committed
16
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
17
from vllm.model_executor.sampling_metadata import SamplingMetadata
18
from vllm.multimodal import MULTIMODAL_REGISTRY
19
20
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import ImageSize
21
from vllm.sequence import IntermediateTensors
22

23
from .clip import CLIPVisionModel
24
from .interfaces import SupportsMultiModal, SupportsPP
25
26
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
                    LlavaDummyInputsBuilder, LlavaLikeConfig,
27
                    LlavaMultiModalProjector, init_vision_tower_for_llava)
28
from .siglip import SiglipVisionModel
Cyrus Leung's avatar
Cyrus Leung committed
29
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
30
                    init_vllm_registered_model, maybe_prefix)
31
32
33
34


class LlavaNextImagePixelInputs(TypedDict):
    """Image inputs supplied as raw pixel values, split into anyres patches."""

    type: Literal["pixel_values"]

    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Entry 0 along the patch axis is the base patch; the remaining
    `num_patches` entries come from the anyres grid. Because `num_patches`
    can differ between images, the data may arrive as a list of per-image
    tensors instead of one batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size * num_images, 2)`

    Original image sizes, given in `(height, width)` order.
    """
50
51


52
53
54
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Image inputs supplied as precomputed embeddings (no vision tower)."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    The trailing `hidden_size` has to equal the hidden size of the language
    model backbone.
    """


# Either raw pixel patches or precomputed embeddings; consumers distinguish
# the two variants by their `type` field ("pixel_values" / "image_embeds").
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]
63
64


65
66
class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
    """Structural type for HF configs usable by LLaVA-NeXT processing.

    Extends :class:`LlavaLikeConfig` with the anyres grid pinpoints.
    """
    # Candidate resolutions for the anyres grid; consumers in this file
    # unpack each entry as a `(height, width)` pair.
    image_grid_pinpoints: Final[list[list[int]]]
67

68

69
class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
    """Processing metadata for LLaVA-NeXT (anyres) checkpoints.

    Provides the HF config/processor and computes how many placeholder
    tokens an image of a given size expands into.
    """

    def get_hf_config(self) -> LlavaNextLikeConfig:
        """Return the HF config, typed as a LLaVA-NeXT-like config."""
        return self.ctx.get_hf_config(LlavaNextConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor, filling in a missing `patch_size`."""
        processor = self.ctx.get_hf_processor(LlavaNextProcessor, **kwargs)

        # Some checkpoints omit patch_size from `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v); fall back to the
        # vision encoder's patch size.
        if processor.patch_size is None:
            processor.patch_size = (
                self.get_vision_encoder_info().get_patch_size())

        return processor

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens for a single image.

        The count is the sum of three parts:
        - the unpadded features of the high-resolution anyres grid;
        - one newline feature per row of that grid;
        - the base features after the feature-select strategy is applied.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        base_tokens = self._apply_feature_select_strategy(
            config.vision_feature_select_strategy,
            encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )

        grid_rows, grid_cols = get_anyres_image_grid_shape(
            image_size=(image_height, image_width),
            grid_pinpoints=config.image_grid_pinpoints,
            patch_size=encoder_info.get_image_size(),
        )

        grid_tokens, newline_tokens = self._get_num_unpadded_features(
            original_height=image_height,
            original_width=image_width,
            npatches=encoder_info.get_patch_grid_length(),
            num_patch_height=grid_rows,
            num_patch_width=grid_cols,
        )

        return grid_tokens + newline_tokens + base_tokens

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count grid features remaining after the padding is stripped.

        Returns:
            `(unpadded_features, newline_features)`. `newline_features`
            equals the number of remaining rows.
        """
        height = npatches * num_patch_height
        width = npatches * num_patch_width

        if original_width / original_height > width / height:
            # Image is relatively wider than the grid: strip the padding
            # split evenly above and below the content.
            scaled_height = (original_height * width) // original_width
            height -= 2 * ((height - scaled_height) // 2)
        else:
            # Otherwise strip the padding split evenly left and right.
            scaled_width = (original_width * height) // original_height
            width -= 2 * ((width - scaled_width) // 2)

        return height * width, height

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the grid pinpoint that produces the most image tokens.

        Raises:
            ValueError: if no pinpoint yields a positive token count.
        """
        best_count, best_size = 0, None
        for height, width in self.get_hf_config().image_grid_pinpoints:
            count = self.get_num_image_tokens(image_width=width,
                                              image_height=height)
            # Strict comparison keeps the first pinpoint on ties.
            if count > best_count:
                best_count = count
                best_size = ImageSize(width=width, height=height)

        if best_count == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size


170
171
172
173
174
175
176
177
178
179
180
181
182
183
# Type parameter for the processing-info class, letting the multimodal
# processor classes below be specialized on a LlavaNextProcessingInfo
# (sub)class.
_I = TypeVar("_I", bound=LlavaNextProcessingInfo)


class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
    """Common base for LLaVA-NeXT-style multimodal processors.

    Re-declares `_get_mm_fields_config` as abstract, so every concrete
    subclass has to describe how its HF outputs map to multimodal fields.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Map each HF processor output key to its field config."""
        raise NotImplementedError

184

185
186
class LlavaNextMultiModalProcessor(
        BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
    """Multimodal processor for LLaVA-NeXT checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # All image-related tensors are batched along the "image" modality;
        # each field gets its own config object.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_sizes": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }
198
199


200
201
202
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
                                        info=LlavaNextProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """LLaVA-NeXT: a vision tower and projector in front of a language model.

    Images are processed in four steps:
    1. `vision_tower` encodes the images.
    2. `multi_modal_projector` maps the features into the language model's
       hidden size.
    3. The features are merged per image, with `image_newline` separators.
    4. The merged embeddings are written into the text embeddings at the
       positions of `config.image_token_index`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        # `vision_feature_layer` selects the encoder layer(s) feeding the
        # projector. A single int gives `hidden_size` input features. A
        # list/tuple (used by multimodal Granite models to control encoder
        # outputs) gives `hidden_size * len(layers)` input features.
        vision_feature_layer = config.vision_feature_layer
        if isinstance(vision_feature_layer, int):
            vision_hidden_size = config.vision_config.hidden_size
            self.feature_sample_layers = None
        elif isinstance(vision_feature_layer, (list, tuple)):
            vision_hidden_size = config.vision_config.hidden_size * len(
                vision_feature_layer)
            self.feature_sample_layers = vision_feature_layer
        else:
            raise TypeError(
                f"vision_feature_layer type: {type(vision_feature_layer)}"
                " is not supported")

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        # Learned separator embedding, appended after every row of the
        # unpadded high-resolution feature grid when merging patches.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Use the language model's sampler if it has one, else a default."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every entry of `data` has shape `(2,)`.

        Returns `data` unchanged; raises ValueError on a mismatch.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that each image's patches have shape `(3, size, size)`.

        `size` is `config.vision_config.image_size`. The leading
        (num_patches) dimension of each entry is not checked. Returns
        `data` unchanged; raises ValueError on a mismatch.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Pop image kwargs and package them as typed image inputs.

        Consumes `pixel_values`, `image_sizes` and `image_embeds` from
        `kwargs`. Returns None when neither pixels nor embeddings are
        given. Pixel values take precedence over embeddings.

        Raises:
            ValueError: if an input has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Image sizes are required to unpad the anyres grid.
            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the vision feature-select strategy.

        "default" drops the first token along dim 1; "full" keeps every
        token. Any other strategy raises ValueError.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode stacked patches and apply the feature-select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(
            pixel_values, feature_sample_layers=self.feature_sample_layers)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge one image's projected patch embeddings into a sequence.

        Args:
            image_size: Original `(height, width)` of the image.
            patch_embeddings: Projected embeddings indexed by patch. Entry 0
                is the base patch and must hold `grid * grid` tokens, where
                `grid = image_size // patch_size` of the vision config. Any
                further entries are anyres grid patches; patches beyond the
                grid (batch padding) are dropped.
            strategy: Selects how patches are combined.
                - "flat" concatenates all patches.
                - "spatial*" lays grid patches out spatially.
                - "*unpad*" additionally strips padding and appends
                  `image_newline` after every row.

        Raises:
            ValueError: on an unknown strategy or an inconsistent base patch.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches] \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # -> (hidden, grid_rows * height, grid_cols * width)
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # Append the newline embedding at the end of every row.
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run the vision tower and projector over all patches.

        A batched 5-D tensor input returns one tensor whose leading
        dimensions are `(batch, num_patches)`. A list input (variable patch
        counts) is concatenated, encoded in one pass, and split back into
        a tuple of per-image tensors.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return torch.split(self.multi_modal_projector(stacked_image_features),
                           num_patches_per_batch)

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Turn parsed image inputs into per-image embedding tensors.

        Precomputed embeddings are returned as-is (wrapped in a list).
        Pixel inputs are encoded and then merged with the "spatial_unpad"
        strategy.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # Without explicit sizes, assume every image already has the
            # vision encoder's native square resolution.
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return per-image embeddings, or None if no image inputs given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and splice in any multimodal embeddings.

        Image embeddings are placed at the positions of
        `config.image_token_index`.
        """
        if multimodal_embeddings is None:
            return self.language_model.get_input_embeddings(input_ids)

        inputs_embeds = embed_multimodal(
            input_ids,
            self.config.image_token_index,
            self.language_model.model.get_input_embeddings,
            multimodal_embeddings,
        )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LlaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model. The input processor therefore
        expands the single image token (denoted as `32000`) in place,
        resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :meth:`LlavaNextProcessingInfo.get_num_image_tokens`.

        This way, `positions` is consistent with `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position ids matching `input_ids`.
            intermediate_tensors: Hidden states received from the previous
                pipeline-parallel stage. When given, `inputs_embeds` is
                ignored and no image processing is done here.
            inputs_embeds: Precomputed input embeddings. When None, they are
                computed here from `input_ids` and the image kwargs.
            **kwargs: Optional image inputs:
                - `pixel_values`: the pixels in each grid patch for each
                  input image;
                - `image_sizes`: the original `(height, width)` of each
                  input image;
                - `image_embeds`: precomputed image embeddings.

        See also:
            :class:`LlavaNextImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights; returns the set of loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)