cohere2_vision.py 17.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from vllm/model_executor/models/aya_vision.py
"""Command-A-Vision (Cohere2Vision) multimodal model implementation for vLLM."""

from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, Optional, Union

import torch
from torch import nn
from transformers import BatchFeature, PretrainedConfig
from transformers.models.cohere2_vision import Cohere2VisionConfig
13
14
from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import (  # noqa: E501
    get_optimal_tiled_canvas)
15
16
17
18
from transformers.models.cohere2_vision.processing_cohere2_vision import (
    Cohere2VisionProcessor)

from vllm.config import VllmConfig
19
from vllm.config.multimodal import BaseDummyOptions
20
21
22
23
24
25
from vllm.model_executor.layers.activation import MulAndSilu
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
27
28
29
30
31
32
33
34
35
36
37
38
39
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo,
                                        MultiModalFieldConfig,
                                        PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
40
from .utils import (AutoWeightsLoader, WeightsMapper,
41
                    init_vllm_registered_model, maybe_prefix)
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153


class Cohere2VisionImagePixelInputs(TensorSchema):
    """
    Schema describing the pixel-value image inputs accepted by the model.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - c: Number of channels (hard-coded to 3 in the `pixel_values` shape)
        - h: Height of each image patch
        - w: Width of each image patch
        - bn: Batch size * number of images
    """

    # Discriminator tag; the only accepted literal is "pixel_values".
    type: Literal["pixel_values"]

    # All patches of all images stacked along the first axis: (np, 3, h, w).
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", 3, "h", "w"),
    ]

    # One entry per image: (bn,).
    # NOTE(review): presumably the number of patches belonging to each image,
    # so that the entries sum to `np` - confirm against the processor output.
    num_patches: Annotated[
        torch.Tensor,
        TensorShape("bn"),
    ]


class Cohere2VisionMultiModalProjector(nn.Module):
    """Multimodal projector that maps vision features to text embedding space.

    Uses pixel shuffle downsampling followed by SwiGLU activation.
    """

    def __init__(self, config: Cohere2VisionConfig, prefix: str = ""):
        """Build the two projection layers.

        Args:
            config: Model config; reads `downsample_factor`,
                `alignment_intermediate_size`, `vision_config.hidden_size`
                and `text_config.hidden_size`.
            prefix: Weight-name prefix of this module.
        """
        super().__init__()
        self.downsample_factor = config.downsample_factor

        # Input dimension after pixel shuffle downsampling
        input_dim = config.vision_config.hidden_size * (
            config.downsample_factor**2)
        # MergedColumnParallelLinear expects the intermediate size to be a list
        # of sizes, so that it will load the weights as two separate linear
        # layers before applying any parallelism.
        # We need to divide the alignment intermediate size by 2 because
        # the weights are merged weights of two linear layers for SwiGLU.
        self.intermediate_size = config.alignment_intermediate_size // 2

        self.linear_1 = MergedColumnParallelLinear(
            input_dim,
            [self.intermediate_size] * 2,
            bias=True,
            return_bias=False,
            prefix=f"{prefix}.linear_1",
        )
        self.act = MulAndSilu()
        self.linear_2 = RowParallelLinear(
            self.intermediate_size,
            config.text_config.hidden_size,
            bias=True,
            return_bias=False,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Downsample the vision features and project them.

        Applies, in order: pixel shuffle, `linear_1`, the gated activation
        and `linear_2`.
        """
        image_features = self.pixel_shuffle(image_features)
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

    def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply pixel shuffle downsampling to reduce spatial dimensions.

        Every `f x f` window of the square feature grid (with
        `f = self.downsample_factor`) is folded into the channel axis.

        Args:
            image_features: Input tensor of shape [B, S, D] where S = H*W
                and H == W.

        Returns:
            Tensor of shape [B, H/f, W/f, D*f*f].

        Raises:
            ValueError: If S is not a perfect square, or the grid side is
                not divisible by the downsample factor.
        """
        batch_size, seq_len = image_features.shape[:2]
        factor = self.downsample_factor

        # Integer arithmetic throughout: `int(side * (1. / factor))` can
        # truncate to the wrong value for some factors (e.g. 49 * (1/49) is
        # 0.9999999999999999 in IEEE double precision).
        side = round(seq_len**0.5)
        if side * side != seq_len:
            raise ValueError(
                "pixel_shuffle expects a square feature grid, but got "
                f"sequence length {seq_len}")
        if side % factor != 0:
            raise ValueError(
                f"Feature grid side {side} is not divisible by "
                f"downsample factor {factor}")
        new_side = side // factor

        x = image_features.reshape(batch_size, side, side, -1)
        channels = x.shape[-1]
        # Split both spatial axes into (coarse, fine) pairs ...
        x = x.reshape(batch_size, new_side, factor, new_side, factor,
                      channels)
        # ... bring the two fine axes next to the channel axis ...
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        # ... and merge them into the channels.
        return x.reshape(batch_size, new_side, new_side, -1)


class Cohere2VisionProcessingInfo(BaseProcessingInfo):
    """Accessors for the HF config/processor plus patch-count helpers."""

    def get_hf_config(self) -> Cohere2VisionConfig:
        """Return the HF config, requested as `Cohere2VisionConfig`."""
        return self.ctx.get_hf_config(Cohere2VisionConfig)

    def get_hf_processor(self, **kwargs: object) -> Cohere2VisionProcessor:
        """Return the HF processor, forwarding `kwargs` to the context."""
        return self.ctx.get_hf_processor(Cohere2VisionProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object):
        """Return the `image_processor` attribute of the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Only the "image" modality is listed, with a limit of `None`."""
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a size of `max_patches` tiles stacked along the height."""
        image_processor = self.get_image_processor()
        tile_height = image_processor.size['height']
        tile_width = image_processor.size['width']
        stacked_height = tile_height * image_processor.max_patches
        return ImageSize(height=stacked_height, width=tile_width)

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Cohere2VisionProcessor],
    ) -> int:
        """
        Calculate the number of image patches for a given image.

        Args:
            image_width: Width of the image.
            image_height: Height of the image.
            processor: HF processor to read the settings from; when `None`,
                the default one is fetched via `get_hf_processor()`.

        Returns:
            1 when `crop_to_patches` is disabled; otherwise the number of
            tiles of the optimal canvas, plus one thumbnail whenever there
            is more than one tile.
        """
        active_processor = (self.get_hf_processor()
                            if processor is None else processor)
        image_processor = active_processor.image_processor

        # The current implementation of get_number_of_image_patches
        # is incorrect, so the count is computed here instead.
        # TODO: Revert to
        # `image_processor.get_number_of_image_patches(image_height,
        # image_width, {})` once
        # https://github.com/huggingface/transformers/pull/40312 is released.
        fewest_tiles = image_processor.min_patches
        most_tiles = image_processor.max_patches
        tile_size = image_processor.size

        if not image_processor.crop_to_patches:
            return 1

        grid_cols, grid_rows = get_optimal_tiled_canvas(
            (image_height, image_width),
            (tile_size["height"], tile_size["width"]),
            fewest_tiles,
            most_tiles,
        )
        tile_count = grid_cols * grid_rows

        # A multi-tile image carries one extra thumbnail image.
        return tile_count + 1 if tile_count > 1 else tile_count
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212


class Cohere2VisionDummyInputsBuilder(
        BaseDummyInputsBuilder[Cohere2VisionProcessingInfo]):
    """Produces placeholder prompt text and images for Cohere2Vision."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the image token repeated once per requested image."""
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Build dummy images at the size reported to yield most features.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Requested item count per modality; only "image" is
                read.
            mm_options: Optional per-modality overrides; the "image" entry,
                if any, is forwarded to `_get_dummy_images`.
        """
        image_count = mm_counts.get("image", 0)
        target_size = self.info.get_image_size_with_most_features()

        image_overrides = None
        if mm_options:
            image_overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=target_size.width,
            height=target_size.height,
            num_images=image_count,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


class Cohere2VisionMultiModalProcessor(
        BaseMultiModalProcessor[Cohere2VisionProcessingInfo]):
    """Multimodal processor: HF call, field layout and prompt expansion."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and guarantee a `num_patches` entry.

        `num_patches` is needed for proper tensor splitting. When the HF
        output lacks it and images were supplied, it is computed here from
        the size of each image.
        """
        outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
            tok_kwargs,
        )

        if "num_patches" in outputs:
            return outputs

        images = mm_data.get("images")
        if images is None:
            return outputs

        # Fallback calculation if HF processor didn't provide num_patches
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        parsed = self._get_data_parser().parse_mm_data({"image": images})
        image_items = parsed.get_items("image", ImageProcessorItems)

        patch_counts = []
        for idx in range(len(image_items)):
            size = image_items.get_image_size(idx)
            patch_counts.append(
                self.info.get_num_patches(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                ))

        outputs["num_patches"] = torch.tensor(patch_counts)
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each output field maps onto the individual images.

        `pixel_values` is flat and split by the per-image patch counts; the
        other fields are batched per image.
        """
        patch_counts = hf_inputs.get("num_patches", torch.empty(0))
        return {
            "pixel_values":
            MultiModalFieldConfig.flat_from_sizes("image", patch_counts),
            "num_patches":
            MultiModalFieldConfig.batched("image"),
            "image_embeds":
            MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with the full per-image token block.

        The block is `boi`, then one group per patch made of
        `patch_size ** 2` image tokens followed by a line-break token, then
        `eoi`.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token
        tokens_per_tile = int(hf_processor.patch_size**2)
        line_break = hf_processor.img_line_break_token
        begin_marker = hf_processor.boi_token
        end_marker = hf_processor.eoi_token

        def get_replacement(item_idx: int):
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size: ImageSize = image_items.get_image_size(item_idx)

            tile_count = self.info.get_num_patches(
                image_width=size.width,
                image_height=size.height,
                processor=hf_processor,
            )
            one_tile = image_token * tokens_per_tile + line_break
            all_tiles = one_tile * tile_count
            expanded = f"{begin_marker}{all_tiles}{end_marker}"

            return PromptUpdateDetails.select_text(expanded, image_token)

        image_update = PromptReplacement(
            modality="image",
            target=image_token,
            replacement=get_replacement,
        )
        return [image_update]


@MULTIMODAL_REGISTRY.register_processor(
    Cohere2VisionMultiModalProcessor,
    info=Cohere2VisionProcessingInfo,
    dummy_inputs=Cohere2VisionDummyInputsBuilder,
)
class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                            SupportsPP):
    """Cohere2Vision: SigLIP vision tower, projector and a language model."""

    merge_by_field_config = True

    # Checkpoint weight-name prefixes -> prefixes of this module's submodules.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the vision tower, the projector and the language model."""
        super().__init__()
        model_config = vllm_config.model_config
        config: Cohere2VisionConfig = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = model_config.multimodal_config
        # Must run before the vision tower is built with `quant_config`.
        self._patch_quant_config(config, quant_config)

        self.vision_tower = SiglipVisionModel(
            config.vision_config,
            quant_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.vocab_size = config.text_config.vocab_size
        self.multi_modal_projector = Cohere2VisionMultiModalProjector(
            config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=config.text_config.architectures,
        )

    @property
    def dtype(self):
        """dtype of the first parameter yielded by `self.parameters()`."""
        first_param = next(self.parameters())
        return first_param.dtype

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader` with the name mapper."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper)

    def _process_image_input(self, image_input: Cohere2VisionImagePixelInputs,
                             **kwargs) -> list[torch.Tensor]:
        """Process image pixels through vision tower and projector.

        Args:
            image_input: Validated image input containing pixel values and
                         patch counts

        Returns:
            List of flattened image embeddings, one per image
        """
        assert self.vision_tower is not None, "Vision tower is required"

        pixels = image_input["pixel_values"]
        patches_per_image = image_input["num_patches"]

        # Vision tower first, then projection into the text embedding space.
        projected = self.multi_modal_projector(self.vision_tower(pixels))

        # Regroup the patch axis per image and flatten each group.
        per_image = projected.split(patches_per_image.tolist())
        return [chunk.flatten(0, 2) for chunk in per_image]

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Cohere2VisionImagePixelInputs]:
        """Build the validated image input from `kwargs`.

        Returns `None` when no `pixel_values` were passed. `image_embeds`
        is not supported and must be absent.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)
        assert image_embeds is None, \
            "Cohere2Vision does not support image_embeds."

        if pixel_values is None:
            return None

        # Both patch dimensions are bound to the vision config's image size.
        patch_side = self.config.vision_config.image_size
        return Cohere2VisionImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            num_patches=num_patches,
            resolve_bindings={
                "h": patch_side,
                "w": patch_side,
            },
        )

    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: QuantizationConfig):
        """Add "vision_tower" to an empty AWQ `modules_to_not_convert`.

        Only applies to `AWQConfig`, and only when the text config carries a
        `quantization_config`.
        """
        # the awq models from OpenGVLab missing `modules_to_not_convert`
        # patch the quant_config to add `modules_to_not_convert` back
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config",
                                   None)
        if quant_config.modules_to_not_convert or llm_quant_config is None:
            return

        quant_config.modules_to_not_convert.append("vision_tower")

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return one embedding tensor per image, or `[]` without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is not None:
            return self._process_image_input(image_input, **kwargs)
        return []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the inner language model.

        `inputs_embeds` is ignored whenever `intermediate_tensors` is given.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        logits = self.language_model.compute_logits(hidden_states)
        return logits