aya_vision.py 15.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
4
from collections.abc import Iterable, Mapping, Sequence
5
from typing import Annotated, Literal
Jennifer Zhao's avatar
Jennifer Zhao committed
6
7
8
9
10
11
12

import torch
from torch import nn
from transformers import BatchFeature, GotOcr2ImageProcessor
from transformers.activations import ACT2FN
from transformers.image_processing_utils import get_size_dict
from transformers.models.aya_vision import AyaVisionConfig
13
from transformers.models.aya_vision.processing_aya_vision import AyaVisionProcessor
Jennifer Zhao's avatar
Jennifer Zhao committed
14
from transformers.models.got_ocr2.image_processing_got_ocr2 import (
15
16
    get_optimal_tiled_canvas,
)
Jennifer Zhao's avatar
Jennifer Zhao committed
17
18

from vllm.config import VllmConfig
19
from vllm.config.multimodal import BaseDummyOptions
20
from vllm.inputs import MultiModalDataDict
Jennifer Zhao's avatar
Jennifer Zhao committed
21
from vllm.multimodal import MULTIMODAL_REGISTRY
22
23
24
25
from vllm.multimodal.inputs import (
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
26
27
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
28
    BaseDummyInputsBuilder,
29
30
31
32
33
34
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
Jennifer Zhao's avatar
Jennifer Zhao committed
35
from vllm.sequence import IntermediateTensors
36
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Jennifer Zhao's avatar
Jennifer Zhao committed
37
38
39

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
40
41
42
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
43
    get_layer_index,
44
45
46
    init_vllm_registered_model,
    maybe_prefix,
)
Jennifer Zhao's avatar
Jennifer Zhao committed
47
48


49
class AyaVisionImagePixelInputs(TensorSchema):
    """
    Pixel-value image inputs for Aya Vision.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - h: Height of each image patch
        - w: Width of each image patch
        - bn: Batch size * number of images

    The channel dimension is not a named dimension: it is fixed to 3 in the
    ``pixel_values`` shape below.
    """

    type: Literal["pixel_values"]

    # All patches of all images, concatenated along the first axis.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", 3, "h", "w"),
    ]

    # Patch count contributed by each image; used to split `pixel_values`
    # back into per-image groups (presumably sums to `np` — confirm).
    num_patches: Annotated[
        torch.Tensor,
        TensorShape("bn"),
    ]
Jennifer Zhao's avatar
Jennifer Zhao committed
71
72
73
74
75
76
77
78


class AyaVisionMultiModalProjector(nn.Module):
    """Project vision-tower patch features into the text model's hidden size.

    Pipeline: pixel shuffle (folds each ``downsample_factor`` x
    ``downsample_factor`` block of tokens into the channel dim) -> LayerNorm
    -> Linear -> SwiGLU gating -> Linear.
    """

    def __init__(self, config: AyaVisionConfig):
        super().__init__()
        self.config = config
        self.downsample_factor = config.downsample_factor

        text_hidden = config.text_config.hidden_size
        # Pixel shuffle multiplies the vision channel count by factor**2.
        shuffled_dim = config.vision_config.hidden_size * (
            config.downsample_factor**2
        )
        self.alignment_intermediate_size = getattr(
            config, "alignment_intermediate_size", text_hidden
        )

        self.layernorm = nn.LayerNorm(
            shuffled_dim,
            eps=config.adapter_layer_norm_eps,
        )
        self.linear_1 = nn.Linear(
            shuffled_dim,
            self.alignment_intermediate_size,
            bias=True,
        )
        # SiLU is the gate activation of SwiGLU.
        self.act = ACT2FN["silu"]
        # linear_1's output is split into value/gate halves, so the output
        # projection only consumes half of the intermediate width.
        self.linear_2 = nn.Linear(
            self.alignment_intermediate_size // 2,
            text_hidden,
            bias=True,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Map vision-tower features to language-model embeddings."""
        normed = self.layernorm(self.pixel_shuffle(image_features))
        # First half of the last dim is the value, second half the gate.
        value, gate = self.linear_1(normed).chunk(2, dim=-1)
        return self.linear_2(self.act(gate) * value)

    def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor:  # B, S, D
        """Fold each ``factor`` x ``factor`` neighbourhood of tokens into channels.

        Input is ``(B, S, D)``. ``S`` is treated as a square grid of side
        ``int(sqrt(S))``; the reshapes fail if it is not. Output is
        ``(B, side / f, side / f, D * f * f)`` with ``f = downsample_factor``.
        """
        factor = self.downsample_factor
        batch_size, seq_length, _ = image_features.shape
        side = int(seq_length**0.5)
        reduced = int(side / factor)

        grid = image_features.reshape(batch_size, side, side, -1)
        channels = grid.shape[-1]
        # Merge `factor` neighbouring tokens along the last spatial axis.
        grid = grid.reshape(batch_size, side, reduced, int(channels * factor))
        grid = grid.permute(0, 2, 1, 3)
        # Merge `factor` neighbouring tokens along the other spatial axis.
        grid = grid.reshape(batch_size, reduced, reduced, -1)
        return grid.permute(0, 2, 1, 3)


class AyaVisionProcessingInfo(BaseProcessingInfo):
    """Access to the Aya Vision HF config/processors plus patch-count math."""

    def get_hf_config(self) -> AyaVisionConfig:
        return self.ctx.get_hf_config(AyaVisionConfig)

    def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor:
        return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> GotOcr2ImageProcessor:
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Only images are supported; this model declares no count of its own.
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size used as the largest-feature dummy input.

        It is one tile (``size["height"]`` x ``size["width"]``) scaled by
        ``max_patches`` along both axes.
        """
        image_processor = self.get_image_processor()
        tile = image_processor.size
        scale = image_processor.max_patches
        # NOTE(review): with square tiles this canvas is square too. If
        # get_optimal_tiled_canvas picks the grid by aspect ratio (as the
        # GOT-OCR2 helper appears to), a square canvas may be split into a
        # square grid with fewer than `max_patches` tiles. Confirm this size
        # really maximises get_num_patches before relying on it for profiling.
        return ImageSize(
            height=tile["height"] * scale,
            width=tile["width"] * scale,
        )

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        size: dict,
        min_patches: int,
        max_patches: int,
    ) -> int:
        """
        Calculate the number of patches needed for a given image based on size
        constraints. Mirrors (with adjustments) the tiling logic in
        transformers/models/got_ocr2/image_processing_got_ocr2.

        Returns the tile count of the chosen grid, plus one extra patch
        whenever that grid has more than one tile.
        """
        tile = get_size_dict(size, default_to_square=False)
        cols, rows = get_optimal_tiled_canvas(
            (image_height, image_width),
            (tile["height"], tile["width"]),
            min_patches,
            max_patches,
        )
        tiles = cols * rows
        # A lone tile is used as-is; multi-tile grids get one additional patch
        # (presumably a global thumbnail — confirm against the HF processor).
        return tiles + 1 if tiles != 1 else tiles


181
class AyaVisionDummyInputsBuilder(BaseDummyInputsBuilder[AyaVisionProcessingInfo]):
    """Builds placeholder prompts and images for Aya Vision."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per requested image."""
        image_token = self.info.get_hf_processor().image_token
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images at the largest-feature size.

        ``seq_len`` is accepted for interface compatibility but not used.
        """
        target = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(
            width=target.width,
            height=target.height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


211
class AyaVisionMultiModalProcessor(BaseMultiModalProcessor[AyaVisionProcessingInfo]):
    """Runs the HF Aya Vision processor and expands image tokens in prompts."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the HF processor, then re-attach per-image ``num_patches``.

        The HF processor pops the ``num_patches`` kwarg, which vLLM needs in
        order to split ``pixel_values`` per image. It is therefore recomputed
        here from the original image sizes.
        """
        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
            tok_kwargs,
        )
        image_processor = self.info.get_hf_processor(**mm_kwargs).image_processor

        images = mm_data.get("images")
        if images is None:
            return processed_outputs

        parsed = self.info.parse_mm_data({"image": images}, validate=False)
        image_items = parsed.get_items("image", ImageProcessorItems)
        sizes = [image_items.get_image_size(idx) for idx in range(len(image_items))]

        counts = [
            self.info.get_num_patches(
                image_width=dims.width,
                image_height=dims.height,
                size=image_processor.size,
                min_patches=image_processor.min_patches,
                max_patches=image_processor.max_patches,
            )
            for dims in sizes
        ]
        processed_outputs["num_patches"] = torch.tensor(counts)
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is grouped per image."""
        patches_per_image = hf_inputs.get("num_patches", torch.empty(0))
        return {
            # Patches of all images are concatenated; split back by count.
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            "num_patches": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with the processor's tiled patch string."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token
        patch_token = hf_processor.img_patch_token
        image_processor = hf_processor.image_processor

        def expand_image(item_idx: int):
            # Fetched per call, so no lookup happens unless an image is
            # actually being replaced.
            image_items = mm_items.get_items("image", ImageProcessorItems)
            dims: ImageSize = image_items.get_image_size(item_idx)
            n_patches = self.info.get_num_patches(
                image_width=dims.width,
                image_height=dims.height,
                size=image_processor.size,
                min_patches=image_processor.min_patches,
                max_patches=image_processor.max_patches,
            )
            replacement = hf_processor._prompt_split_image(num_patches=n_patches)
            # Mark the patch-token positions inside the replacement
            # (presumably the slots that receive image embeddings).
            return PromptUpdateDetails.select_text(replacement, patch_token)

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=expand_image,
            )
        ]


def _get_num_hidden_layers(hf_config: AyaVisionConfig) -> int:
    """Return how many vision-tower layers need to be instantiated.

    Only layers up to the deepest one referenced by
    ``hf_config.vision_feature_layer`` are required, so the tower can be
    built with this many layers.

    Args:
        hf_config: Aya Vision config. ``vision_feature_layer`` may be an int
            or a list/tuple of ints. Each index is resolved with
            ``get_layer_index`` against ``vision_config.num_hidden_layers``
            (presumably mapping negative indices to absolute ones — confirm).

    Raises:
        TypeError: if ``vision_feature_layer`` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # A single feature layer: initialize up to that layer.
    if isinstance(feature_layers, int):
        return get_layer_index(feature_layers, num_hidden_layers)
    # Multiple feature layers: initialize up to the deepest one.
    if isinstance(feature_layers, (list, tuple)):
        return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
Jennifer Zhao's avatar
Jennifer Zhao committed
308
309
310
311
312


@MULTIMODAL_REGISTRY.register_processor(
    AyaVisionMultiModalProcessor,
    info=AyaVisionProcessingInfo,
    dummy_inputs=AyaVisionDummyInputsBuilder,
)
class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Aya Vision: a SigLIP vision tower and a pixel-shuffle/SwiGLU projector
    feeding a Cohere2 causal language model."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an item of ``modality``.

        Raises:
            ValueError: for any modality other than image.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: AyaVisionConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        # Build only as many vision layers as the selected feature layers need.
        num_hidden_layers = _get_num_hidden_layers(config)
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config

        with self._mark_tower_model(vllm_config, "image"):
            # NOTE(review): the prefix is "vision_model" while the attribute
            # and the checkpoint mapping use "vision_tower"; left unchanged in
            # case quantization configs key on it — confirm.
            self.vision_tower = SiglipVisionModel(
                config.vision_config,
                quant_config,
                num_hidden_layers_override=num_hidden_layers,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.multi_modal_projector = AyaVisionMultiModalProjector(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "model"),
                # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm
                architectures=["Cohere2ForCausalLM"],
            )

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming HF prefixes via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Run ``vision_tower`` on ``pixel_values`` cast to the tower's dtype."""
        return vision_tower(
            pixel_values.to(dtype=vision_tower.dtype),
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_input(
        self, image_input: AyaVisionImagePixelInputs, **kwargs
    ) -> list[torch.Tensor]:
        """Encode all patches and return one embedding sequence per image.

        Patches are encoded and projected together, split into per-image
        groups using ``num_patches``, and each group's first three dims
        (patch, then the two spatial axes left after pixel shuffle) are
        flattened into a single token axis.
        """
        pixel_values = image_input["pixel_values"]
        num_patches = image_input["num_patches"]
        image_features = self._image_pixels_to_features(
            self.vision_tower, pixel_values=pixel_values
        )
        image_embeds = self.multi_modal_projector(image_features)
        return [e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())]

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> AyaVisionImagePixelInputs | None:
        """Pop image kwargs and wrap them in a validated schema.

        Returns:
            ``None`` when no ``pixel_values`` were supplied.

        Raises:
            ValueError: if ``image_embeds`` is supplied (unsupported).
        """
        pixel_values = kwargs.pop("pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # Explicit check instead of `assert`, which `python -O` strips.
        if image_embeds is not None:
            raise ValueError("Aya Vision does not support image_embeds.")

        if pixel_values is None:
            return None

        return AyaVisionImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            num_patches=num_patches,
            resolve_bindings={
                "h": self.config.vision_config.image_size,
                "w": self.config.vision_config.image_size,
            },
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or ``[]`` when there are no images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input, **kwargs)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language-model backbone; extra ``kwargs`` are ignored."""
        # When intermediate tensors are passed in (presumably on non-first
        # pipeline-parallel ranks), any provided embeddings are discarded.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)