aya_vision.py 16.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
4
from collections.abc import Iterable, Mapping, Sequence
5
from typing import Annotated, Literal
Jennifer Zhao's avatar
Jennifer Zhao committed
6
7
8
9
10
11
12

import torch
from torch import nn
from transformers import BatchFeature, GotOcr2ImageProcessor
from transformers.activations import ACT2FN
from transformers.image_processing_utils import get_size_dict
from transformers.models.aya_vision import AyaVisionConfig
13
from transformers.models.aya_vision.processing_aya_vision import AyaVisionProcessor
Jennifer Zhao's avatar
Jennifer Zhao committed
14
from transformers.models.got_ocr2.image_processing_got_ocr2 import (
15
16
    get_optimal_tiled_canvas,
)
Jennifer Zhao's avatar
Jennifer Zhao committed
17
18

from vllm.config import VllmConfig
19
from vllm.config.multimodal import BaseDummyOptions
Jennifer Zhao's avatar
Jennifer Zhao committed
20
from vllm.multimodal import MULTIMODAL_REGISTRY
21
22
23
24
25
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
26
27
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
28
    BaseDummyInputsBuilder,
29
30
31
32
33
34
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
Jennifer Zhao's avatar
Jennifer Zhao committed
35
from vllm.sequence import IntermediateTensors
36
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Jennifer Zhao's avatar
Jennifer Zhao committed
37
38
39

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
40
41
42
43
44
45
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
Jennifer Zhao's avatar
Jennifer Zhao committed
46
47


48
class AyaVisionImagePixelInputs(TensorSchema):
    """Pixel-value image inputs accepted by Aya Vision.

    Dimensions:
        - np: Total number of patches, summed over every image of every
              prompt in the batch
        - c: Number of channels (fixed to 3 in the ``pixel_values`` shape)
        - h: Height of each image patch
        - w: Width of each image patch
        - bn: Batch size * number of images

    ``num_patches`` gives, for each image, how many of the ``np`` rows of
    ``pixel_values`` belong to that image.
    """

    type: Literal["pixel_values"]
    pixel_values: Annotated[torch.Tensor, TensorShape("np", 3, "h", "w")]
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
Jennifer Zhao's avatar
Jennifer Zhao committed
70
71
72
73
74
75
76
77


class AyaVisionMultiModalProjector(nn.Module):
    """Project vision-tower features into the language model's hidden space.

    The features are pixel-shuffled (each ``downsample_factor`` x
    ``downsample_factor`` spatial neighbourhood is folded into the channel
    dimension), layer-normalized, and passed through a SwiGLU-style MLP.
    """

    def __init__(self, config: AyaVisionConfig):
        super().__init__()
        self.config = config
        self.downsample_factor = config.downsample_factor

        text_hidden = config.text_config.hidden_size
        # Channel width after pixel_shuffle folds factor**2 positions together.
        shuffled_dim = config.vision_config.hidden_size * (
            config.downsample_factor**2
        )
        self.alignment_intermediate_size = getattr(
            config, "alignment_intermediate_size", text_hidden
        )

        self.layernorm = nn.LayerNorm(shuffled_dim, eps=config.adapter_layer_norm_eps)
        self.linear_1 = nn.Linear(
            shuffled_dim, self.alignment_intermediate_size, bias=True
        )
        # SwiGLU gates with the SiLU activation.
        self.act = ACT2FN["silu"]
        # linear_1's output is split into (value, gate) halves, so linear_2
        # only receives half of the intermediate width.
        self.linear_2 = nn.Linear(
            self.alignment_intermediate_size // 2, text_hidden, bias=True
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Map vision features to text-embedding space.

        Args:
            image_features: Features fed to ``pixel_shuffle``; presumably
                (batch, seq, dim) with a square ``seq`` -- TODO confirm.

        Returns:
            The projected features; the last dimension is the text hidden size.
        """
        normed = self.layernorm(self.pixel_shuffle(image_features))
        value, gate = self.linear_1(normed).chunk(2, dim=-1)
        return self.linear_2(self.act(gate) * value)

    def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor:
        """Fold ``downsample_factor`` x ``downsample_factor`` neighbourhoods
        into channels.

        Input is (B, S, D) with ``S`` treated as a square grid of side
        ``int(sqrt(S))``. The result is a 4-D tensor whose two spatial axes are
        each reduced by ``downsample_factor`` and whose channels grow by
        ``downsample_factor ** 2``.
        """
        batch_size, seq_length, _ = image_features.shape
        side = int(seq_length**0.5)
        factor = self.downsample_factor
        reduced_side = int(side / factor)

        grid = image_features.reshape(batch_size, side, side, -1)
        channels = grid.shape[-1]
        # Merge `factor` consecutive positions of the second spatial axis.
        grid = grid.reshape(batch_size, side, reduced_side, int(channels * factor))
        grid = grid.permute(0, 2, 1, 3)
        # Then merge `factor` consecutive positions of the other spatial axis.
        grid = grid.reshape(batch_size, reduced_side, reduced_side, -1)
        return grid.permute(0, 2, 1, 3)


class AyaVisionProcessingInfo(BaseProcessingInfo):
    """Access to Aya Vision's HF config/processors and patch-count helpers."""

    def get_hf_config(self) -> AyaVisionConfig:
        return self.ctx.get_hf_config(AyaVisionConfig)

    def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor:
        return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> GotOcr2ImageProcessor:
        processor = self.get_hf_processor(**kwargs)
        return processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square image size of ``max_patches`` tiles per side.

        NOTE(review): the tiling canvas is chosen by aspect ratio, so a square
        image may not reach the maximum tile count; confirm this is the
        intended profiling size.
        """
        image_processor = self.get_image_processor()
        tile_size = image_processor.size
        scale = image_processor.max_patches
        return ImageSize(
            height=tile_size["height"] * scale,
            width=tile_size["width"] * scale,
        )

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        size: dict,
        min_patches: int,
        max_patches: int,
    ) -> int:
        """
        Calculate the number of patches needed for a given image based on size
        constraints. This method replicates and adjusts the logic from:
        transformers/models/got_ocr2/image_processing_got_ocr2

        Returns:
            ``columns * rows`` of the optimal tiled canvas, plus one extra
            patch whenever more than one tile is used.
        """
        tile = get_size_dict(size, default_to_square=False)
        num_columns, num_rows = get_optimal_tiled_canvas(
            (image_height, image_width),
            (tile["height"], tile["width"]),
            min_patches,
            max_patches,
        )
        num_tiles = num_columns * num_rows
        if num_tiles == 1:
            return num_tiles
        # Multi-tile images get one additional patch (presumably the global
        # thumbnail view produced by GotOcr2 -- confirm).
        return num_tiles + 1


180
class AyaVisionDummyInputsBuilder(BaseDummyInputsBuilder[AyaVisionProcessingInfo]):
    """Builds placeholder prompts and images for Aya Vision profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the image token repeated once per requested dummy image."""
        image_token = self.info.get_hf_processor().image_token
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Create ``mm_counts["image"]`` dummy images at the profiling size.

        Per-image overrides are taken from ``mm_options["image"]`` when
        ``mm_options`` is given and non-empty.
        """
        target_size = self.info.get_image_size_with_most_features()
        image_overrides = None if not mm_options else mm_options.get("image")
        dummy_images = self._get_dummy_images(
            width=target_size.width,
            height=target_size.height,
            num_images=mm_counts.get("image", 0),
            overrides=image_overrides,
        )
        return {"image": dummy_images}


210
class AyaVisionMultiModalProcessor(BaseMultiModalProcessor[AyaVisionProcessingInfo]):
    """vLLM multimodal processor wrapping the HF ``AyaVisionProcessor``."""

    def _num_patches_for_size(
        self, image_size: ImageSize, image_processor: GotOcr2ImageProcessor
    ) -> int:
        """Patch count for one image, using the image processor's settings."""
        return self.info.get_num_patches(
            image_width=image_size.width,
            image_height=image_size.height,
            size=image_processor.size,
            min_patches=image_processor.min_patches,
            max_patches=image_processor.max_patches,
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then re-attach a per-image ``num_patches``.

        The HF processor pops the ``num_patches`` kwarg, which vLLM needs, so
        it is recomputed here from the parsed image sizes.
        """
        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
        image_processor = self.info.get_hf_processor(**mm_kwargs).image_processor

        images = mm_data.get("images")
        if images is None:
            return outputs

        parsed = (
            self._get_data_parser()
            .parse_mm_data({"image": images})
            .get_items("image", ImageProcessorItems)
        )
        outputs["num_patches"] = torch.tensor(
            [
                self._num_patches_for_size(parsed.get_image_size(idx), image_processor)
                for idx in range(len(parsed))
            ]
        )
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how processed tensors are split per image.

        ``pixel_values`` is flat across images and sliced by ``num_patches``;
        the other fields carry one entry per image.
        """
        num_patches = hf_inputs.get("num_patches", torch.empty(0))
        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes("image", num_patches),
            "num_patches": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image token into the HF processor's split-image text.

        Only the ``img_patch_token`` positions of the replacement are selected
        as embedding slots.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = hf_processor.image_processor
        img_patch_token = hf_processor.img_patch_token

        def get_replacement(item_idx: int):
            image_items = mm_items.get_items("image", ImageProcessorItems)
            num_patches = self._num_patches_for_size(
                image_items.get_image_size(item_idx), image_processor
            )
            replacement = hf_processor._prompt_split_image(num_patches=num_patches)
            return PromptUpdateDetails.select_text(replacement, img_patch_token)

        return [
            PromptReplacement(
                modality="image",
                target=hf_processor.image_token,
                replacement=get_replacement,
            )
        ]


def _get_num_hidden_layers(hf_config: AyaVisionConfig) -> int:
    """Return how many vision-encoder layers must be instantiated.

    ``hf_config.vision_feature_layer`` may be a single layer index or a
    list/tuple of indices; each index is resolved with ``_get_layer_index``
    (negative indices count from the end). The result is the deepest layer
    any selected feature needs.

    Raises:
        ValueError: if ``vision_feature_layer`` is an empty list/tuple.
        TypeError: if ``vision_feature_layer`` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # A single feature layer: initialize up to that layer.
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # Multiple feature layers: initialize up to the deepest one.
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # Fail with a clear message instead of max()'s "empty sequence".
            raise ValueError("vision_feature_layer must not be empty")
        return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
Jennifer Zhao's avatar
Jennifer Zhao committed
310
311
312
313
314
315
316
317
318
319
320


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index


@MULTIMODAL_REGISTRY.register_processor(
    AyaVisionMultiModalProcessor,
    info=AyaVisionProcessingInfo,
    dummy_inputs=AyaVisionDummyInputsBuilder,
)
class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Aya Vision: SigLIP vision tower + projector + Cohere2 language model.

    Images are encoded by ``vision_tower``, mapped into the text hidden space
    by ``multi_modal_projector``, and consumed by ``language_model``.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality`` (images only).

        Raises:
            ValueError: if ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: AyaVisionConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        # Only build as many vision layers as the selected feature layer(s) need.
        num_hidden_layers = _get_num_hidden_layers(config)
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(
            config.vision_config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=maybe_prefix(prefix, "vision_model"),
        )
        self.vocab_size = config.text_config.vocab_size
        self.multi_modal_projector = AyaVisionMultiModalProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "model"),
            # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm
            architectures=["Cohere2ForCausalLM"],
        )

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming HF prefixes via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Run ``vision_tower`` on ``pixel_values`` cast to the tower's dtype."""
        return vision_tower(
            pixel_values.to(dtype=vision_tower.dtype),
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_input(
        self, image_input: AyaVisionImagePixelInputs, **kwargs
    ) -> list[torch.Tensor]:
        """Encode and project all patches, then return one tensor per image.

        Patches are grouped per image by ``num_patches``; each group's spatial
        and patch dimensions are flattened so each image yields a 2-D tensor.
        """
        assert self.vision_tower is not None
        pixel_values = image_input["pixel_values"]
        num_patches = image_input["num_patches"]
        image_features = self._image_pixels_to_features(
            self.vision_tower, pixel_values=pixel_values
        )
        image_embeds = self.multi_modal_projector(image_features)
        return [e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())]

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> AyaVisionImagePixelInputs | None:
        """Build validated pixel inputs from ``kwargs``; ``None`` if absent.

        Raises:
            ValueError: if ``image_embeds`` is supplied (unsupported here).
        """
        pixel_values = kwargs.pop("pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # Explicit raise rather than `assert`, so the check survives `python -O`
        # instead of silently ignoring the embeddings.
        if image_embeds is not None:
            raise ValueError("Aya Vision does not support image_embeds.")

        if pixel_values is None:
            return None

        return AyaVisionImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            num_patches=num_patches,
            resolve_bindings={
                "h": self.config.vision_config.image_size,
                "w": self.config.vision_config.image_size,
            },
        )

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or ``[]`` when no pixels are given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input, **kwargs)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language-model backbone.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is discarded
        and the backbone continues from those tensors instead.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped language model."""
        return self.language_model.compute_logits(hidden_states)