aya_vision.py 18.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
4
from collections.abc import Iterable, Mapping, Sequence
5
from typing import Annotated, Literal, Optional, Union, cast
Jennifer Zhao's avatar
Jennifer Zhao committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19

import torch
from torch import nn
from transformers import BatchFeature, GotOcr2ImageProcessor
from transformers.activations import ACT2FN
from transformers.image_processing_utils import get_size_dict
from transformers.models.aya_vision import AyaVisionConfig
from transformers.models.aya_vision.processing_aya_vision import (
    AyaVisionProcessor)
from transformers.models.got_ocr2.image_processing_got_ocr2 import (
    get_optimal_tiled_canvas)

from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
20
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
Jennifer Zhao's avatar
Jennifer Zhao committed
21
22
23
24
25
26
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo,
                                        MultiModalFieldConfig,
                                        PromptReplacement, PromptUpdate,
27
                                        PromptUpdateDetails)
28
from vllm.multimodal.profiling import BaseDummyInputsBuilder
Jennifer Zhao's avatar
Jennifer Zhao committed
29
from vllm.sequence import IntermediateTensors
30
from vllm.utils.jsontree import json_map_leaves
31
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Jennifer Zhao's avatar
Jennifer Zhao committed
32
33
34

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
35
36
37
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
Jennifer Zhao's avatar
Jennifer Zhao committed
38
39


40
class AyaVisionImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the Aya Vision image encoder.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - h: Height of each image patch
        - w: Width of each image patch
        - bn: Batch size * number of images

    The channel dimension is not symbolic: the schema below fixes it to 3.
    """

    type: Literal["pixel_values"]

    # All patches of all images, flattened along the first dimension.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", 3, "h", "w"),
    ]

    # Patch count per image; used to split the `np` dimension back into
    # per-image groups, so it is expected to sum to `np`.
    num_patches: Annotated[
        torch.Tensor,
        TensorShape("bn"),
    ]
Jennifer Zhao's avatar
Jennifer Zhao committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124


class AyaVisionMultiModalProjector(nn.Module):
    """Project vision-tower features into the text model's hidden size.

    Pipeline: pixel shuffle (fold spatial neighbourhoods into channels) ->
    LayerNorm -> Linear -> SwiGLU gate -> Linear.
    """

    def __init__(self, config: AyaVisionConfig):
        super().__init__()
        self.config = config
        self.downsample_factor = config.downsample_factor

        text_hidden = config.text_config.hidden_size
        # After pixel shuffle each token carries downsample_factor**2 vision
        # tokens concatenated along the channel axis.
        shuffled_dim = config.vision_config.hidden_size * (
            config.downsample_factor**2)

        self.alignment_intermediate_size = getattr(
            config, "alignment_intermediate_size", text_hidden)
        self.layernorm = nn.LayerNorm(shuffled_dim,
                                      eps=config.adapter_layer_norm_eps)
        self.linear_1 = nn.Linear(
            shuffled_dim,
            self.alignment_intermediate_size,
            bias=True,
        )
        # SwiGLU gating is built on SiLU.
        self.act = ACT2FN["silu"]
        # linear_1's output is halved into (value, gate) before linear_2, so
        # linear_2 only sees half of the intermediate width.
        self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2,
                                  text_hidden,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Map (B, S, D) vision features to text-hidden-sized features."""
        normed = self.layernorm(self.pixel_shuffle(image_features))
        projected = self.linear_1(normed)
        # SwiGLU: the first half is the value, the second half the gate.
        value, gate = projected.chunk(2, dim=-1)
        return self.linear_2(self.act(gate) * value)

    def pixel_shuffle(self,
                      image_features: torch.Tensor) -> torch.Tensor:  # B, S, D
        """Fold each f x f block of tokens into the channel dimension.

        Assumes S is a perfect square (the token grid side is taken as
        ``int(sqrt(S))``). With f = downsample_factor the result has shape
        (B, side / f, side / f, D * f**2).
        """
        factor = self.downsample_factor
        batch_size, seq_length, _ = image_features.shape
        side = int(seq_length**0.5)
        width = height = side

        grid = image_features.reshape(image_features.shape[0], width, height,
                                      -1)
        channels = grid.shape[-1]
        grid = grid.reshape(batch_size, width, int(height / factor),
                            int(channels * factor))
        grid = grid.permute(0, 2, 1, 3)
        grid = grid.reshape(batch_size, int(height / factor),
                            int(width / factor), -1)
        return grid.permute(0, 2, 1, 3)


class AyaVisionProcessingInfo(BaseProcessingInfo):
    """Config/processor accessors and patch-count helpers for Aya Vision."""

    def get_hf_config(self) -> AyaVisionConfig:
        return self.ctx.get_hf_config(AyaVisionConfig)

    def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor:
        return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> GotOcr2ImageProcessor:
        """Return the image processor owned by the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None` means no per-prompt cap on the number of images.
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return an image size scaled by ``max_patches`` in both directions
        relative to the processor's tile size."""
        processor = self.get_image_processor()
        tile_height = processor.size['height']
        tile_width = processor.size['width']
        scale = processor.max_patches
        return ImageSize(height=tile_height * scale, width=tile_width * scale)

    def get_num_patches(self, *, image_width: int, image_height: int,
                        size: dict, min_patches: int, max_patches: int) -> int:
        """
        Calculate the number of patches needed for a given image based on size
        constraints.  This method replicates and adjusts the logic from:
        transformers/models/got_ocr2/image_processing_got_ocr2
        """
        normalized = get_size_dict(size, default_to_square=False)
        num_columns, num_rows = get_optimal_tiled_canvas(
            (image_height, image_width),
            (normalized["height"], normalized["width"]),
            min_patches,
            max_patches,
        )
        tiles = num_columns * num_rows
        if tiles == 1:
            return tiles
        # Multi-tile images get one extra patch (presumably a global
        # thumbnail, as in the HF GotOcr2 processor — confirm upstream).
        return tiles + 1


class AyaVisionDummyInputsBuilder(
        BaseDummyInputsBuilder[AyaVisionProcessingInfo]):
    """Produces dummy prompts and images for Aya Vision profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        count = mm_counts.get("image", 0)
        image_token = self.info.get_hf_processor().image_token
        return image_token * count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized to yield the maximum feature count."""
        count = mm_counts.get("image", 0)
        target = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(width=target.width,
                                              height=target.height,
                                              num_images=count)
        return {"image": dummy_images}


class AyaVisionMultiModalProcessor(
        BaseMultiModalProcessor[AyaVisionProcessingInfo]):
    """vLLM multimodal processor layered on the HF Aya Vision processor."""

    def _num_patches_for(self, image_size: ImageSize, image_processor) -> int:
        """Count the patches `image_processor` produces for one image."""
        return self.info.get_num_patches(
            image_width=image_size.width,
            image_height=image_size.height,
            size=image_processor.size,
            min_patches=image_processor.min_patches,
            max_patches=image_processor.max_patches,
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then re-attach per-image `num_patches`."""
        outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
            tok_kwargs,
        )
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        image_processor = hf_processor.image_processor

        images = mm_data.get("images")
        if images is None:
            return outputs

        # HF processor pops the `num_patches` kwarg, which is needed by vLLM
        parsed = self._get_data_parser().parse_mm_data({
            "image": images
        }).get_items("image", ImageProcessorItems)
        sizes = [parsed.get_image_size(idx) for idx in range(len(parsed))]
        outputs["num_patches"] = torch.tensor(
            [self._num_patches_for(dims, image_processor) for dims in sizes])
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each HF output field is batched per image."""
        patch_counts = hf_inputs.get("num_patches", torch.empty(0))
        return {
            # pixel_values is flat over patches; num_patches slices it.
            "pixel_values":
            MultiModalFieldConfig.flat_from_sizes("image", patch_counts),
            "num_patches": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with the HF per-image patch prompt."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token
        img_patch_token = hf_processor.img_patch_token
        image_processor = hf_processor.image_processor

        def get_replacement(item_idx: int):
            # Items are looked up on each call rather than up front.
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size: ImageSize = images.get_image_size(item_idx)
            patch_count = self._num_patches_for(image_size, image_processor)
            expansion = hf_processor._prompt_split_image(
                num_patches=patch_count)
            # Only the patch tokens are treated as embedding positions.
            return PromptUpdateDetails.select_text(expansion, img_patch_token)

        replacement = PromptReplacement(
            modality="image",
            target=image_token,
            replacement=get_replacement,
        )
        return [replacement]


def _get_num_hidden_layers(hf_config: AyaVisionConfig) -> int:
    """Return how many vision-tower layers need to be instantiated.

    ``hf_config.vision_feature_layer`` is either a single layer index or a
    list/tuple of indices (negative values count from the end). Only layers
    up to the deepest requested one are needed.

    Raises:
        ValueError: if ``vision_feature_layer`` is an empty list/tuple.
        TypeError: if ``vision_feature_layer`` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # max() would otherwise fail with an opaque "empty sequence".
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index


@MULTIMODAL_REGISTRY.register_processor(
    AyaVisionMultiModalProcessor,
    info=AyaVisionProcessingInfo,
    dummy_inputs=AyaVisionDummyInputsBuilder)
class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """Aya Vision: a SigLIP vision tower and projector feeding a
    Cohere2ForCausalLM language model."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for `modality` (images only)."""
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: AyaVisionConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        num_hidden_layers = _get_num_hidden_layers(config)
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config

        # The vision tower is built only up to the deepest feature layer used.
        self.vision_tower = SiglipVisionModel(
            config.vision_config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=maybe_prefix(prefix, "vision_model"))
        self.vocab_size = config.text_config.vocab_size
        self.multi_modal_projector = AyaVisionMultiModalProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "model"),
            # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm
            architectures=["Cohere2ForCausalLM"])

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping HF prefixes to vLLM ones."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                  pixel_values: torch.Tensor,
                                  **kwargs) -> torch.Tensor:
        """Run the vision tower and apply the configured feature selection
        to every tensor leaf of its output."""
        # Cast inputs to the tower's weight dtype to avoid dtype mismatches.
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype),
                                      **kwargs)

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        return cast(
            Union[torch.Tensor, tuple[torch.Tensor, ...]],
            json_map_leaves(select_features, image_features),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Select tokens per strategy: "default" drops the first token along
        dim 1, "full" keeps everything.

        Raises:
            ValueError: for any other strategy.
        """
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _process_image_input(self, image_input: AyaVisionImagePixelInputs,
                             **kwargs) -> list[torch.Tensor]:
        """Encode and project all patches, then regroup them per image.

        Returns one tensor per image; the first three dims of each group
        (patches and the shuffled token grid) are flattened into a single
        token axis.
        """
        assert self.vision_tower is not None
        pixel_values = image_input["pixel_values"]
        num_patches = image_input["num_patches"]
        image_features = self._image_pixels_to_features(
            self.vision_tower, pixel_values=pixel_values)
        image_embeds = self.multi_modal_projector(image_features)
        return [
            e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())
        ]

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
        """Pop image kwargs and build validated pixel inputs.

        Returns None when no `pixel_values` were supplied.

        Raises:
            ValueError: if `image_embeds` is supplied (unsupported).
        """
        pixel_values = kwargs.pop("pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # An explicit check rather than `assert`, so it still runs under -O.
        if image_embeds is not None:
            raise ValueError("Aya Vision does not support image_embeds.")

        if pixel_values is None:
            return None

        return AyaVisionImagePixelInputs(
            type="pixel_values",
            pixel_values=flatten_bn(pixel_values, concat=True),
            num_patches=flatten_bn(num_patches, concat=True),
            resolve_bindings={
                "h": self.config.vision_config.image_size,
                "w": self.config.vision_config.image_size,
            })

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or [] when no images are given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input, **kwargs)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and splice image embeddings in at image-token
        positions when any are provided."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None \
            and len(multimodal_embeddings) != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                multimodal_embeddings=multimodal_embeddings,
                placeholder_token_id=self.config.image_token_index,
            )

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model, building input embeddings from images
        first when neither embeddings nor intermediate tensors are given."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)