aya_vision.py 16.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
4
from collections.abc import Iterable, Mapping, Sequence
5
from typing import Annotated, Literal, Optional, Union
Jennifer Zhao's avatar
Jennifer Zhao committed
6
7
8
9
10
11
12

import torch
from torch import nn
from transformers import BatchFeature, GotOcr2ImageProcessor
from transformers.activations import ACT2FN
from transformers.image_processing_utils import get_size_dict
from transformers.models.aya_vision import AyaVisionConfig
13
from transformers.models.aya_vision.processing_aya_vision import AyaVisionProcessor
Jennifer Zhao's avatar
Jennifer Zhao committed
14
from transformers.models.got_ocr2.image_processing_got_ocr2 import (
15
16
    get_optimal_tiled_canvas,
)
Jennifer Zhao's avatar
Jennifer Zhao committed
17
18

from vllm.config import VllmConfig
19
from vllm.config.multimodal import BaseDummyOptions
Jennifer Zhao's avatar
Jennifer Zhao committed
20
from vllm.multimodal import MULTIMODAL_REGISTRY
21
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
22
23
24
25
26
27
28
29
30
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalFieldConfig,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
31
from vllm.multimodal.profiling import BaseDummyInputsBuilder
Jennifer Zhao's avatar
Jennifer Zhao committed
32
from vllm.sequence import IntermediateTensors
33
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Jennifer Zhao's avatar
Jennifer Zhao committed
34
35
36

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
37
38
39
40
41
42
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
Jennifer Zhao's avatar
Jennifer Zhao committed
43
44


45
class AyaVisionImagePixelInputs(TensorSchema):
    """
    Pixel-value image inputs for Aya Vision, shape-checked by ``TensorSchema``.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - c: Number of channels (fixed to 3 by the schema below)
        - h: Height of each image patch
        - w: Width of each image patch
        - bn: Batch size * number of images
    """

    type: Literal["pixel_values"]

    # Patches of every image concatenated along the first dimension.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", 3, "h", "w"),
    ]

    # Number of patches contributed by each image; used downstream to split
    # the per-patch features back into per-image groups.
    num_patches: Annotated[
        torch.Tensor,
        TensorShape("bn"),
    ]
Jennifer Zhao's avatar
Jennifer Zhao committed
67
68
69
70
71
72
73
74


class AyaVisionMultiModalProjector(nn.Module):
    """Project vision-tower features into the language model's hidden size.

    Pipeline: pixel shuffle (fold each ``downsample_factor`` x
    ``downsample_factor`` token neighbourhood into channels) -> LayerNorm ->
    Linear -> SwiGLU -> Linear.
    """

    def __init__(self, config: AyaVisionConfig):
        super().__init__()
        self.config = config
        self.downsample_factor = config.downsample_factor
        # Fall back to the text hidden size when the config does not define a
        # dedicated alignment width.
        self.alignment_intermediate_size = getattr(
            config, "alignment_intermediate_size", config.text_config.hidden_size
        )
        # After pixel shuffle, every token carries downsample_factor**2 vision
        # feature vectors concatenated along the channel dimension.
        shuffled_dim = config.vision_config.hidden_size * (
            config.downsample_factor**2
        )
        self.layernorm = nn.LayerNorm(
            shuffled_dim,
            eps=config.adapter_layer_norm_eps,
        )

        self.linear_1 = nn.Linear(
            shuffled_dim,
            self.alignment_intermediate_size,
            bias=True,
        )

        self.act = ACT2FN["silu"]  # SwiGLU uses SiLU activation
        # For SwiGLU, project down to half size since we split intermediate dim
        self.linear_2 = nn.Linear(
            self.alignment_intermediate_size // 2,
            config.text_config.hidden_size,
            bias=True,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Shuffle, normalize and project image features.

        Args:
            image_features: (B, S, D) vision features, where S is assumed to
                be a square number of tokens.

        Returns:
            A 4-D tensor (B, W', H', text_hidden_size), where
            W' = H' = sqrt(S) / downsample_factor.
        """
        image_features = self.pixel_shuffle(image_features)
        image_features = self.layernorm(image_features)
        hidden_states = self.linear_1(image_features)

        # Split along last dimension and apply SwiGLU: the second half gates
        # the first half.
        x, gate = hidden_states.chunk(2, dim=-1)
        hidden_states = self.act(gate) * x

        hidden_states = self.linear_2(hidden_states)
        return hidden_states

    def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor:  # B, S, D
        """Fold ``downsample_factor`` x ``downsample_factor`` token blocks
        into the channel dimension.

        Input is (B, S, D) with S assumed to be a perfect square. Output is
        (B, W', H', D * downsample_factor**2) with
        W' = H' = sqrt(S) / downsample_factor.
        """
        batch_size, seq_length, _ = image_features.shape
        # Tokens form a square grid; int(sqrt) recovers its side length.
        height = width = int(seq_length**0.5)
        factor = self.downsample_factor

        image_features = image_features.reshape(batch_size, width, height, -1)
        channels = image_features.shape[-1]
        # Merge `factor` adjacent tokens along the height axis into channels.
        image_features = image_features.reshape(
            batch_size,
            width,
            int(height / factor),
            int(channels * factor),
        )
        image_features = image_features.permute(0, 2, 1, 3)
        # Merge `factor` adjacent tokens along the width axis into channels.
        image_features = image_features.reshape(
            batch_size,
            int(height / factor),
            int(width / factor),
            -1,
        )
        image_features = image_features.permute(0, 2, 1, 3)
        return image_features


class AyaVisionProcessingInfo(BaseProcessingInfo):
    """Access to Aya Vision's HF config/processors and patch-count helpers."""

    def get_hf_config(self) -> AyaVisionConfig:
        return self.ctx.get_hf_config(AyaVisionConfig)

    def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor:
        return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> GotOcr2ImageProcessor:
        # The Aya Vision processor wraps a GOT-OCR2 image processor.
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only images are supported; no explicit per-prompt count is set here.
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a dummy image size spanning ``max_patches`` tiles per side."""
        image_processor = self.get_image_processor()
        height = image_processor.size["height"]
        width = image_processor.size["width"]
        max_patches = image_processor.max_patches
        # NOTE(review): confirm get_optimal_tiled_canvas maps this square
        # canvas to the grid with the most tiles; it may pick a smaller
        # square grid instead.
        return ImageSize(height=height * max_patches, width=width * max_patches)

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        size: dict,
        min_patches: int,
        max_patches: int,
    ) -> int:
        """
        Calculate the number of patches needed for a given image based on size
        constraints.  This method replicates and adjusts the logic from:
        transformers/models/got_ocr2/image_processing_got_ocr2

        Returns:
            ``1`` when the image fits in a single tile. Otherwise the number
            of tiles in the chosen grid plus one extra patch (presumably the
            global thumbnail view).
        """
        size = get_size_dict(size, default_to_square=False)
        num_columns, num_rows = get_optimal_tiled_canvas(
            (image_height, image_width),
            (size["height"], size["width"]),
            min_patches,
            max_patches,
        )
        num_blocks = num_columns * num_rows
        return num_blocks if num_blocks == 1 else num_blocks + 1


177
class AyaVisionDummyInputsBuilder(BaseDummyInputsBuilder[AyaVisionProcessingInfo]):
    """Build dummy text and image inputs for Aya Vision."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return ``num_images`` dummy images sized to maximize feature count.

        ``mm_options["image"]``, when given, is forwarded as overrides to the
        dummy-image generator.
        """
        num_images = mm_counts.get("image", 0)
        image_size = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=image_size.width,
                height=image_size.height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


207
class AyaVisionMultiModalProcessor(BaseMultiModalProcessor[AyaVisionProcessingInfo]):
    """Run the HF Aya Vision processor and expand image placeholders."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the HF processor and attach a per-image ``num_patches`` tensor."""
        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
            tok_kwargs,
        )
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        image_processor = hf_processor.image_processor

        # HF processor pops the `num_patches` kwarg, which is needed by vLLM
        if (images := mm_data.get("images")) is not None:
            parsed_images = (
                self._get_data_parser()
                .parse_mm_data({"image": images})
                .get_items("image", ImageProcessorItems)
            )
            image_sizes = [
                parsed_images.get_image_size(i) for i in range(len(parsed_images))
            ]

            # Recompute each image's patch count with the image processor's
            # own size limits, intended to match the patches in `pixel_values`.
            num_patches = [
                self.info.get_num_patches(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    size=image_processor.size,
                    min_patches=image_processor.min_patches,
                    max_patches=image_processor.max_patches,
                )
                for image_size in image_sizes
            ]
            processed_outputs["num_patches"] = torch.tensor(num_patches)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how processed tensors map onto individual images.

        ``pixel_values`` is flat across images and split by ``num_patches``;
        the other fields have one entry per image.
        """
        num_patches = hf_inputs.get("num_patches", torch.empty(0))
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
            num_patches=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with that image's expanded patch tokens."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token
        img_patch_token = hf_processor.img_patch_token
        image_processor = hf_processor.image_processor

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size: ImageSize = images.get_image_size(item_idx)
            num_patches = self.info.get_num_patches(
                image_width=image_size.width,
                image_height=image_size.height,
                size=image_processor.size,
                min_patches=image_processor.min_patches,
                max_patches=image_processor.max_patches,
            )
            repl = hf_processor._prompt_split_image(num_patches=num_patches)

            # Only the patch tokens inside the replacement receive embeddings.
            return PromptUpdateDetails.select_text(repl, img_patch_token)

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement,
            )
        ]


def _get_num_hidden_layers(hf_config: AyaVisionConfig) -> int:
    """Return how many vision-tower layers must be instantiated.

    Only layers up to the deepest one referenced by
    ``hf_config.vision_feature_layer`` are needed. Negative indices are
    resolved relative to ``vision_config.num_hidden_layers``.

    Raises:
        ValueError: if ``vision_feature_layer`` is an empty list/tuple.
        TypeError: if ``vision_feature_layer`` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            raise ValueError(
                "vision_feature_layer must not be an empty list or tuple"
            )
        return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
Jennifer Zhao's avatar
Jennifer Zhao committed
307
308
309
310
311
312
313
314
315
316
317


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index


@MULTIMODAL_REGISTRY.register_processor(
    AyaVisionMultiModalProcessor,
    info=AyaVisionProcessingInfo,
    dummy_inputs=AyaVisionDummyInputsBuilder,
)
class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Aya Vision model.

    Combines a SigLIP vision tower, an ``AyaVisionMultiModalProjector`` and a
    Cohere2 causal language model.
    """

    merge_by_field_config = True

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for an image; reject other modalities."""
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: AyaVisionConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        # Instantiate only the vision layers up to the deepest feature layer.
        num_hidden_layers = _get_num_hidden_layers(config)
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(
            config.vision_config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=maybe_prefix(prefix, "vision_model"),
        )
        self.vocab_size = config.text_config.vocab_size
        self.multi_modal_projector = AyaVisionMultiModalProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "model"),
            # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm
            architectures=["Cohere2ForCausalLM"],
        )

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping HF names via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run the vision tower on pixel values cast to its dtype."""
        return vision_tower(
            pixel_values.to(dtype=vision_tower.dtype),
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_input(
        self, image_input: AyaVisionImagePixelInputs, **kwargs
    ) -> list[torch.Tensor]:
        """Encode and project patches, then regroup them per image.

        Returns one 2-D tensor per image, with that image's patch tokens
        flattened along the first dimension.
        """
        assert self.vision_tower is not None
        pixel_values = image_input["pixel_values"]
        num_patches = image_input["num_patches"]
        image_features = self._image_pixels_to_features(
            self.vision_tower, pixel_values=pixel_values
        )
        image_embeds = self.multi_modal_projector(image_features)
        # Split the flat patch batch by per-image patch counts, then merge
        # (patches, W', H') into a single token axis.
        return [e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())]

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[AyaVisionImagePixelInputs]:
        """Build validated pixel inputs from kwargs, or ``None`` if absent.

        Raises:
            ValueError: if ``image_embeds`` is supplied (unsupported).
        """
        pixel_values = kwargs.pop("pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # An explicit check (not `assert`) so it still runs under `python -O`.
        if image_embeds is not None:
            raise ValueError("Aya Vision does not support image_embeds.")

        if pixel_values is None:
            return None

        return AyaVisionImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            num_patches=num_patches,
            resolve_bindings={
                "h": self.config.vision_config.image_size,
                "w": self.config.vision_config.image_size,
            },
        )

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or an empty list when no images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input, **kwargs)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model backbone.

        When ``intermediate_tensors`` is given (presumably a non-first
        pipeline-parallel stage), ``inputs_embeds`` is ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)