# llava_next.py
1
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
2
3
4

import torch
import torch.nn as nn
5
from PIL import Image
6
from transformers import CLIPVisionConfig, LlavaNextConfig
7
8
9
10
11
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
12
from vllm.config import CacheConfig, MultiModalConfig
13
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
14
15
16
17
18
19
20
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
21
from vllm.model_executor.models.clip import CLIPVisionModel
22
23
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
24
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
25
from vllm.sequence import IntermediateTensors, SamplerOutput
26

27
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
28
                   get_clip_patch_grid_length, input_processor_for_clip)
29
from .interfaces import SupportsVision
30
31
from .llava import LlavaMultiModalProjector
from .utils import merge_vision_embeddings
32
33
34
35
36
37
38
39
40
41
42

logger = init_logger(__name__)

# Substring renames applied to checkpoint weight names in `load_weights`:
# a name containing a key has that key replaced by the associated value.
_KEYS_TO_MODIFY_MAPPING = {
    # The LM head is a top-level attribute of this module.
    "language_model.lm_head": "lm_head",
    # The decoder (`LlamaModel`) is stored directly as `language_model`.
    "language_model.model": "language_model",
}


class LlavaNextImagePixelInputs(TypedDict):
    """Pixel-value image inputs accepted by the LLaVA-NeXT model."""

    type: Literal["pixel_values"]

    data: BatchedTensors
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each image in the batch,
    in which case `data` is a list of per-image tensors.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


LlavaNextImageInputs = LlavaNextImagePixelInputs
59
60


61
62
63
# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
# NOTE: the padding stripped by `unpad_image` is computed with a floordiv,
# `(current - new) // 2` on each side; we mirror that arithmetic here so the
# counts agree: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
def _get_llava_next_num_unpadded_features(
    height: int,
    width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    """Count the grid features that remain after LLaVA-NeXT unpadding.

    The high-resolution tiles form a feature map of
    ``npatches * num_patch_height`` rows by ``npatches * num_patch_width``
    columns. Along the axis where the original image is relatively shorter,
    the symmetric padding removed by ``unpad_image`` is subtracted.

    Args:
        height: Height of the original image in pixels.
        width: Width of the original image in pixels.
        npatches: Number of vision-encoder patches along one side of a tile.
        num_patch_height: Number of tiles stacked vertically.
        num_patch_width: Number of tiles placed horizontally.

    Returns:
        ``(unpadded_features, newline_features)``: the number of grid feature
        positions kept after unpadding, and one newline feature per kept row.
    """
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = width / height
    current_aspect_ratio: float = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        # Image is relatively wider than the grid: rows are trimmed.
        new_height = (height * current_width) // width
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        # Image is relatively taller (or equal): columns are trimmed.
        new_width = (width * current_height) // height
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding
    # For an even grid side this equals rounding an odd `new_*` up to the next
    # even number (the previous behaviour); it is also correct for odd sides.
    # NOTE(review): `unpad_image` computes `new_*` with float arithmetic
    # (`int(size * scale)`); in rare cases float rounding may differ from the
    # exact integer floordiv used here — confirm against the runtime path.

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


92
93
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return the number of image tokens LLaVA-NeXT produces for one image.

    The total is the base tile's features, plus the unpadded features of the
    high-resolution tile grid, plus one newline feature per remaining row.

    Args:
        hf_config: The Hugging Face LLaVA-NeXT configuration.
        input_height: Height of the input image in pixels.
        input_width: Width of the input image in pixels.

    Returns:
        The number of image feature tokens.

    Raises:
        NotImplementedError: If the vision tower config is not CLIP.
    """
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        # Patches along one side of a single tile.
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = num_patches * num_patches

        # Note: We follow the "wrong" width/height order
        # [ref: PR huggingface/transformers#31588]
        num_patch_width, num_patch_height = get_anyres_image_grid_shape(
            image_size=(input_height, input_width),
            grid_pinpoints=hf_config.image_grid_pinpoints,
            patch_size=vision_config.image_size,
        )

        (
            unpadded_feature_size,
            newline_feature_size,
        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                                  num_patches,
                                                  num_patch_height,
                                                  num_patch_width)

        return unpadded_feature_size + newline_feature_size + base_feature_size

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


130
131
132
133
134
135
136
137
138
139
140
def get_max_llava_next_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens a single image can yield.

    The bound is evaluated on a square 448x448 dummy image which, per the
    original note, results in the max possible feature size (a 2x2 grid of
    336x336px tiles).
    """
    side_px = 448
    llava_config = ctx.get_hf_config(LlavaNextConfig)
    return get_llava_next_image_feature_size(llava_config,
                                             input_height=side_px,
                                             input_width=side_px)


141
142
143
144
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
    """Build dummy token and image data for LLaVA-NeXT.

    A 448x448 dummy image is used (a 2x2 grid of 336x336px tiles), and the
    dummy sequence reserves that image's full feature size of image tokens.

    Args:
        ctx: Input context that provides the Hugging Face config.
        seq_len: Length of the dummy token sequence.

    Returns:
        The ``(seq_data, mm_data)`` pair built by the CLIP dummy-data helpers.

    Raises:
        NotImplementedError: If the vision tower config is not CLIP.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    # Result in the max possible feature size (2x2 grid of 336x336px tiles)
    dummy_height = dummy_width = 448
    image_feature_size = get_llava_next_image_feature_size(
        hf_config,
        input_height=dummy_height,
        input_width=dummy_width,
    )

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            image_width_override=dummy_width,
            image_height_override=dummy_height,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


173
174
175
176
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    """Prepare LLaVA-NeXT inputs containing an image.

    Computes the image's feature size from its resolution and delegates to
    ``input_processor_for_clip`` with the image token id and that size.
    Inputs without image data are returned unchanged.

    Args:
        ctx: Input context that provides the model and Hugging Face configs.
        llm_inputs: The inputs to process.

    Returns:
        The processed inputs, or ``llm_inputs`` if it has no image.

    Raises:
        NotImplementedError: If the image is given as a tensor (embeddings),
            or the vision tower config is not CLIP.
        TypeError: If the image is neither a PIL image nor a tensor.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        # PIL reports size as (width, height).
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
209
210


211
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA-NeXT: a CLIP vision tower and projector feeding a Llama model.

    Each image is given as a base tile followed by a grid of high-resolution
    tiles. The projected tile features are merged (with unpadding and
    per-row newline embeddings) and written into the text embeddings at the
    positions of the image placeholder tokens.
    """

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = CLIPVisionModel(config=config.vision_config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        # Take the vocab size from the text config, consistent with
        # `unpadded_vocab_size`, instead of a top-level attribute of the
        # multimodal config.
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

        # Embedding appended after every row of unpadded grid features.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` has shape ``(batch_size, 2)`` and return it.

        Raises:
            ValueError: If the trailing dimensions are not ``[2]``.
        """
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check the shape of batched or per-image pixel values.

        ``data`` is either one 5-D tensor ``(batch, patches, 3, H, W)`` or a
        list of 4-D ``(patches, 3, H, W)`` tensors, where ``H`` and ``W``
        equal the vision config's ``image_size``.

        Returns:
            ``data``, unchanged.

        Raises:
            ValueError: If any tensor has an unexpected shape.
        """
        height = width = self.config.vision_config.image_size
        expected_patch_shape = [3, height, width]

        def _validate_shape(tensor: torch.Tensor) -> None:
            dim = tensor.dim()
            # All 4d image tensors have the same number of patches,
            # so data is a 5d batch of these tensors
            if dim == 5:
                if list(tensor.shape)[2:] != expected_patch_shape:
                    raise ValueError(
                        "Expected pixel value tensor in shape of: (batch size, "
                        f"patch number, 3, {height}, {width}), "
                        f"got {tensor.shape}")

            # 4d image tensors have different number of patches,
            # so data is each individual tensor.
            elif dim == 4:
                if list(tensor.shape)[1:] != expected_patch_shape:
                    raise ValueError(
                        "Expected pixel value tensor in shape of: (patch "
                        f"number, 3, {height}, {width}), got {tensor.shape}")
            else:
                raise ValueError(
                    f"Invalid pixel value tensor of shape {tensor.shape}")

        if isinstance(data, torch.Tensor):
            _validate_shape(data)
        else:
            for tensor in data:
                _validate_shape(tensor)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
        """Extract and validate the image inputs from the forward kwargs.

        Returns:
            ``None`` if no ``pixel_values`` were passed; otherwise the
            validated pixel values and image sizes.

        Raises:
            ValueError: If ``pixel_values`` is not a tensor or list, if
                ``image_sizes`` is not a tensor, or if either has a bad shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")

        return LlavaNextImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
            image_sizes=self._validate_image_sizes(image_sizes),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Select vision features per ``strategy``.

        ``"default"`` drops the first position along dim 1 (the CLS token
        in CLIP outputs); ``"full"`` keeps everything.

        Raises:
            ValueError: On an unknown strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision tower on ``pixel_values`` and select features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values,
                                      self.config.vision_feature_layer)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge the tile embeddings of one image into a token sequence.

        Args:
            image_size: Original ``(height, width)`` of the image.
            patch_embeddings: Projected features of shape
                ``(num_tiles, features_per_tile, hidden_size)``; tile 0 is
                the base tile, the remaining tiles form the anyres grid.
            strategy: ``"flat"`` concatenates all tiles' features. A strategy
                starting with ``"spatial"`` lays the grid tiles out
                spatially; if it also contains ``"unpad"``, padding is removed
                with ``unpad_image`` and ``image_newline`` is appended after
                every row.

        Returns:
            A tensor of shape ``(num_tokens, hidden_size)``.

        Raises:
            ValueError: On an unknown strategy, or if the base tile's feature
                count does not match the vision config.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # image_aspect_ratio == "anyres"
                # Note: We follow the "wrong" width/height order
                # [ref: PR huggingface/transformers#31588]
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_size,
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                other_patch_embeds = other_patch_embeds.view(
                    num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # (grid_h, grid_w, h, w, C) -> (C, grid_h*h, grid_w*w)
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     image_size)
                    # Append the newline embedding as one extra column.
                    newline = self.image_newline[:, None, None].expand(
                        *other_patch_embeds.shape[:-1], 1).to(
                            other_patch_embeds.device)
                    other_patch_embeds = torch.cat(
                        (other_patch_embeds, newline), dim=-1)
                    # (C, rows, cols + 1) -> (rows * (cols + 1), C)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    newline = self.image_newline[None].to(
                        base_patch_embeds.device)
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds, newline), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> BatchedTensors:
        """Encode and project every tile of every image.

        Returns:
            A tensor ``(batch, num_patches, ...)`` when the input is a single
            5-D tensor, otherwise a list with one tensor per image.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Images with differing tile counts: encode them in one pass, then
        # split the features back per image.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
            self, image_input: LlavaNextImageInputs) -> BatchedTensors:
        """Compute the merged image embeddings for each image.

        If ``image_sizes`` is absent, every image is assumed to be
        ``vision_config.image_size`` on both sides.
        """
        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.
        Concretely, consider a text prompt:
        "<image>\\nUSER: What's the content of the image?\\nASSISTANT:".
        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            intermediate_tensors: Passed through to the language model.
            pixel_values: The pixels in each grid patch for each input image.
                Expects a batch with shape `[1, num_patches, 3, h, w]`.
            image_sizes: The original `(height, width)` for each input image.
                Expects a batch with shape `[1, 2]`.

        See also:
            Each input maps to huggingface implementation, as follows:

            - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L690
            - `image_sizes`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L691
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The embeddings now carry both text and image content.
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this module.

        Names are rewritten with ``_KEYS_TO_MODIFY_MAPPING``. Non-vision
        q/k/v and gate/up projections are loaded as shards of the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters; everything else uses the
        parameter's ``weight_loader`` (or ``default_weight_loader``).
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)