llava_next.py 22.9 KB
Newer Older
1
from functools import cached_property
2
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
3
                    TypedDict, Union)
4
5
6

import torch
import torch.nn as nn
7
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
8
9
10
11
12
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
13
from vllm.config import VllmConfig
Joe Runde's avatar
Joe Runde committed
14
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
15
from vllm.model_executor.sampling_metadata import SamplingMetadata
16
from vllm.multimodal import MULTIMODAL_REGISTRY
17
18
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import ImageSize
19
from vllm.sequence import IntermediateTensors
20

21
from .clip import CLIPVisionModel
22
from .interfaces import SupportsMultiModal, SupportsPP
23
24
25
from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector,
                    init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
Cyrus Leung's avatar
Cyrus Leung committed
26
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
27
                    init_vllm_registered_model, maybe_prefix)
28
29
30
31


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
32
    data: Union[torch.Tensor, List[torch.Tensor]]
33
    """
34
35
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
36

37
38
    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
39
    """
40
41

    image_sizes: NotRequired[torch.Tensor]
42
    """
43
    Shape: `(batch_size * num_images, 2)`
44
45
46

    This should be in `(height, width)` format.
    """
47
48


49
50
51
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings for LLaVA-NeXT.

    Fields:
        type: Always ``"image_embeds"``; tags this variant of the image-input
            union.
        data: Shape
            ``(batch_size * num_images, image_feature_size, hidden_size)``.
            ``hidden_size`` must match the hidden size of the language model
            backbone.
    """

    type: Literal["image_embeds"]
    data: torch.Tensor


# Either raw pixels or precomputed embeddings; the "type" key tells them apart.
LlavaNextImageInputs = Union[
    LlavaNextImagePixelInputs,
    LlavaNextImageEmbeddingInputs,
]
60
61


62
class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
    """Multimodal processor for LLaVA-NeXT.

    Extends :class:`LlavaMultiModalProcessor` with the LLaVA-NeXT "anyres"
    token accounting. The number of image placeholder tokens depends on the
    original image size and on ``image_grid_pinpoints`` from the HF config.
    """

    def _get_hf_config(self) -> LlavaNextConfig:
        """Return the HF model config typed as :class:`LlavaNextConfig`."""
        return self.ctx.get_hf_config(LlavaNextConfig)

    def _get_hf_processor(self) -> LlavaNextProcessor:
        """Return the HF processor typed as :class:`LlavaNextProcessor`."""
        return self.ctx.get_hf_processor(LlavaNextProcessor)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image-related HF output as batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_image_token(self) -> str:
        """Return the image placeholder token string of the HF processor."""
        return self._get_hf_processor().image_token

    def _get_max_image_tokens(self) -> int:
        """Return the largest per-image token count over all pinpoints."""
        largest_feature_size, _ = self._get_pinpoint_with_most_features()
        return largest_feature_size

    def _get_dummy_image_size(self) -> ImageSize:
        """Return the pinpoint size that yields the most image tokens."""
        _, pinpoint = self._get_pinpoint_with_most_features()
        return pinpoint

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
    def _get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens for one image.

        The total is the sum of three parts:

        * the base-image features, after the configured feature-select
          strategy is applied;
        * the unpadded features of the anyres patch grid;
        * one newline feature per remaining feature row.
        """
        hf_config = self._get_hf_config()
        vision_encoder_info = self._vision_encoder_info

        base_feature_size = self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )

        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            image_size=(image_height, image_width),
            grid_pinpoints=hf_config.image_grid_pinpoints,
            patch_size=vision_encoder_info.get_image_size(),
        )

        (
            unpadded_feature_size,
            newline_feature_size,
        ) = self._get_num_unpadded_features(
            original_height=image_height,
            original_width=image_width,
            npatches=vision_encoder_info.get_patch_grid_length(),
            num_patch_height=num_patch_height,
            num_patch_width=num_patch_width,
        )

        return unpadded_feature_size + newline_feature_size + base_feature_size

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count the grid features that remain after removing padding.

        Args:
            original_height: Height of the input image in pixels.
            original_width: Width of the input image in pixels.
            npatches: Feature-grid side length of one patch.
            num_patch_height: Number of patches along the height.
            num_patch_width: Number of patches along the width.

        Returns:
            A tuple ``(unpadded_features, newline_features)``. The first is
            the remaining ``height * width`` feature count. The second is the
            remaining height, i.e. one newline per feature row.
        """
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        original_aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        # The scaled size is rounded to 7 decimals before int() truncation,
        # as HF `unpad_image` does. `unpad_image` builds the actual embeddings
        # in the model. A bare int() turns float results such as
        # 49 * (1 / 49) == 0.9999999999999999 into 0 instead of 1. That
        # changes the padding, and the token count no longer matches the
        # number of embeddings produced.
        if original_aspect_ratio > current_aspect_ratio:
            scale_factor = current_width / original_width
            new_height = int(round(original_height * scale_factor, 7))
            padding = (current_height - new_height) // 2
            current_height -= 2 * padding
        else:
            scale_factor = current_height / original_height
            new_width = int(round(original_width * scale_factor, 7))
            padding = (current_width - new_width) // 2
            current_width -= 2 * padding

        unpadded_features = current_height * current_width
        newline_features = current_height

        return (unpadded_features, newline_features)

    def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:
        """
        Get the grid pinpoint with the most features and
        the corresponding feature size.

        Returns:
            ``(largest_feature_size, pinpoint)``. Pinpoints are read from
            ``image_grid_pinpoints`` as ``(height, width)`` pairs.

        Raises:
            ValueError: If no pinpoint yields a positive feature size, for
                example when the pinpoint list is empty.
        """
        hf_config = self._get_hf_config()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for (height, width) in hf_config.image_grid_pinpoints:
            feat_size = self._get_num_image_tokens(image_width=width,
                                                   image_height=height)
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width,
                                                     height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_size, largest_feature_pinpoint


@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """LLaVA-NeXT: a vision tower and projector feeding a language model.

    Pipeline:

    1. Image pixels are encoded by the vision tower and projected to the
       language model's hidden size.
    2. The projected features are merged per image with the
       ``"spatial_unpad"`` strategy.
    3. The merged features are embedded at the image-token positions of
       ``input_ids``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        vision_feature_layer = config.vision_feature_layer
        # The projector's input width depends on how many vision layers are
        # sampled: one hidden size for a single int layer, or one per entry
        # for a list/tuple of layers.
        if isinstance(vision_feature_layer, int):
            vision_hidden_size = config.vision_config.hidden_size
            self.feature_sample_layers = None
        # Used for multimodal granite models to control encoder outputs
        elif isinstance(vision_feature_layer, (list, tuple)):
            vision_hidden_size = config.vision_config.hidden_size * len(
                vision_feature_layer)
            self.feature_sample_layers = vision_feature_layer
        else:
            raise TypeError(
                f"vision_feature_layer type: {type(vision_feature_layer)}"
                " is not supported")

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        # Learned embedding appended after each row of unpadded patch
        # features (see `_merge_image_patch_embeddings`).
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Return the language model's sampler, or the default sampler."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that each image-size entry has shape ``(2,)``.

        Returns ``data`` unchanged; raises ``ValueError`` on a mismatch.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that each image's patches have shape ``(3, size, size)``.

        ``size`` is ``vision_config.image_size``. The leading patch
        dimension of each item is not checked. Returns ``data`` unchanged;
        raises ``ValueError`` on a mismatch.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Pop image kwargs and build a typed, validated image input.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is given. ``pixel_values`` takes precedence when both are present.
        Raises ``ValueError`` on unexpected input types or shapes.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the vision feature-select strategy.

        ``"default"`` drops the first token along dim 1. ``"full"`` keeps
        all tokens.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower, then apply the feature-select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(
            pixel_values, feature_sample_layers=self.feature_sample_layers)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge the projected patch embeddings of one image into a sequence.

        ``patch_embeddings[0]`` is the base (full-image) patch and the rest
        are grid patches.

        * ``"flat"`` flattens the first two dims.
        * ``"spatial*"`` strategies arrange the grid patches spatially.
        * Strategies containing ``"unpad"`` additionally remove the padding
          added by anyres resizing and append ``image_newline`` after each
          row.

        ``image_size`` holds the original ``(height, width)`` of the image.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches] \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # Append one newline embedding at the end of every row,
                    # then flatten to a (tokens, hidden) sequence.
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encode and project pixel patches.

        Returns one tensor for batched input, or a tuple of per-image
        tensors for list input. All patches are stacked so the vision tower
        runs once.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return torch.split(self.multi_modal_projector(stacked_image_features),
                           num_patches_per_batch)

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Turn a parsed image input into per-image embedding tensors.

        Precomputed embeddings are returned as-is inside a list. Pixel
        inputs are encoded and merged with ``"spatial_unpad"``. Missing
        ``image_sizes`` default to ``vision_config.image_size`` squared.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return per-image embeddings, or ``None`` if no image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, splicing in ``multimodal_embeddings``.

        When given, the image embeddings are placed at the positions of
        ``config.image_token_index``.
        """
        if multimodal_embeddings is None:
            return self.language_model.get_input_embeddings(input_ids)

        inputs_embeds = embed_multimodal(
            input_ids,
            self.config.image_token_index,
            self.language_model.model.get_input_embeddings,
            multimodal_embeddings,
        )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by
        :meth:`LlavaNextMultiModalProcessor._get_num_image_tokens`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        See also:
            :class:`LlavaNextImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights with :class:`AutoWeightsLoader`.

        Returns the set of loaded parameter names, as reported by the loader.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)