cohere2_vision.py 15.8 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from vllm/model_executor/models/aya_vision.py
"""Command-A-Vision (Cohere2Vision) multimodal model implementation for vLLM."""

from collections.abc import Iterable, Mapping, Sequence
7
from typing import Annotated, Literal
8
9
10
11
12

import torch
from torch import nn
from transformers import BatchFeature, PretrainedConfig
from transformers.models.cohere2_vision import Cohere2VisionConfig
13
from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import (  # noqa: E501
14
    Cohere2VisionImageProcessorFast,
15
)
16
from transformers.models.cohere2_vision.processing_cohere2_vision import (
17
18
    Cohere2VisionProcessor,
)
19
20

from vllm.config import VllmConfig
21
from vllm.config.multimodal import BaseDummyOptions
22
from vllm.inputs import MultiModalDataDict
23
from vllm.model_executor.layers.activation import MulAndSilu
24
25
26
27
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    RowParallelLinear,
)
28
29
30
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
31
32
33
34
from vllm.multimodal.inputs import (
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
35
36
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
37
    BaseDummyInputsBuilder,
38
39
40
41
42
43
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
44
45
46
47
48
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
49
50
51
52
53
54
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82


class Cohere2VisionImagePixelInputs(TensorSchema):
    """
    Pixel-value image inputs for Command-A-Vision.

    Dimensions:
        - np: Total number of image patches (tiles), summed over every image
              of every prompt in the batch
        - h: Height of a single patch, in pixels
        - w: Width of a single patch, in pixels
        - bn: Batch size * number of images

    The channel dimension of ``pixel_values`` is fixed to 3.
    """

    type: Literal["pixel_values"]

    # Patch pixels of all images stacked along the first axis; they are split
    # back into per-image groups using ``num_patches``.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", 3, "h", "w")]

    # Number of patches contributed by each image.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]


class Cohere2VisionMultiModalProjector(nn.Module):
    """Multimodal projector that maps vision features to text embedding space.

    The vision-token grid is first downsampled with pixel shuffle: each
    ``downsample_factor x downsample_factor`` neighbourhood of tokens is folded
    into the channel dimension. The result then goes through a SwiGLU MLP
    (``linear_1`` -> ``MulAndSilu`` -> ``linear_2``).
    """

    def __init__(self, config: Cohere2VisionConfig, prefix: str = ""):
        super().__init__()
        self.downsample_factor = config.downsample_factor

        # Input dimension after pixel shuffle downsampling: every output token
        # concatenates downsample_factor**2 vision tokens.
        input_dim = config.vision_config.hidden_size * (config.downsample_factor**2)

        # MergedColumnParallelLinear expects the intermediate size to be a list
        # of sizes, so that it will load the weights as two separate linear
        # layers before applying any parallelism.
        # We need to divide the alignment intermediate size by 2 because
        # the weights are merged weights of two linear layers for SwiGLU.
        self.intermediate_size = config.alignment_intermediate_size // 2

        self.linear_1 = MergedColumnParallelLinear(
            input_dim,
            [self.intermediate_size] * 2,
            bias=True,
            return_bias=False,
            # Same prefix helper the rest of this file uses for submodules.
            prefix=maybe_prefix(prefix, "linear_1"),
        )
        self.act = MulAndSilu()
        self.linear_2 = RowParallelLinear(
            self.intermediate_size,
            config.text_config.hidden_size,
            bias=True,
            return_bias=False,
            prefix=maybe_prefix(prefix, "linear_2"),
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project vision-tower features into the text embedding space.

        Args:
            image_features: Vision features of shape [B, S, D] with S a
                perfect square (see ``pixel_shuffle``).

        Returns:
            Tensor of shape [B, S_side/f, S_side/f, text_hidden_size], where
            ``f`` is ``downsample_factor`` and ``S_side = sqrt(S)``.
        """
        image_features = self.pixel_shuffle(image_features)
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

    def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply pixel shuffle downsampling to reduce spatial dimensions.

        Args:
            image_features: Input tensor of shape [B, S, D] where S = H*W and
                the token grid is square (H == W).

        Returns:
            Tensor of shape [B, H/f, W/f, D*f*f], where ``f`` is
            ``self.downsample_factor``.

        Raises:
            ValueError: If S is not a perfect square, or if the grid side is
                not divisible by the downsample factor.
        """
        factor = self.downsample_factor
        batch, seq_len = image_features.shape[:2]

        # Round-and-verify instead of truncating the float square root, so
        # a non-square token count is reported instead of silently reshaped.
        side = round(seq_len**0.5)
        if side * side != seq_len:
            raise ValueError(
                f"Expected a square grid of vision tokens, got {seq_len} tokens."
            )
        if side % factor != 0:
            raise ValueError(
                f"Vision token grid side {side} is not divisible by "
                f"downsample_factor={factor}."
            )
        # Exact integer division; `int(side * (1.0 / factor))` can truncate
        # to one less than the true quotient due to float rounding.
        new_side = side // factor

        x = image_features.reshape(batch, side, side, -1)
        channels = x.shape[-1]
        # Split each spatial axis into (new_side, factor) blocks ...
        x = x.reshape(batch, new_side, factor, new_side, factor, channels)
        # ... bring the two intra-block axes together, then fold them into
        # the channel dimension.
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        return x.reshape(batch, new_side, new_side, -1)


class Cohere2VisionProcessingInfo(BaseProcessingInfo):
    """Config/processor accessors and patch-count helpers for Command-A-Vision."""

    def get_hf_config(self) -> Cohere2VisionConfig:
        """Return the model's HF config as a ``Cohere2VisionConfig``."""
        return self.ctx.get_hf_config(Cohere2VisionConfig)

    def get_hf_processor(self, **kwargs: object) -> Cohere2VisionProcessor:
        """Return the HF ``Cohere2VisionProcessor``, forwarding ``kwargs``."""
        return self.ctx.get_hf_processor(Cohere2VisionProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object):
        """Return the image processor held by the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No explicit cap on the number of images per prompt.
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size assumed to produce the most image features.

        Height is one tile's height times ``max_patches``; width is one
        tile's width.
        """
        image_processor = self.get_image_processor()
        tile = image_processor.size
        return ImageSize(
            height=tile["height"] * image_processor.max_patches,
            width=tile["width"],
        )

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Cohere2VisionProcessor,
        mm_kwargs: Mapping[str, object],
    ) -> int:
        """Return how many patches the HF image processor yields for an image.

        ``mm_kwargs`` is passed through ``self.ctx.get_merged_mm_kwargs``
        before it is handed to the image processor.
        """
        image_processor: Cohere2VisionImageProcessorFast = processor.image_processor
        merged_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
        return image_processor.get_number_of_image_patches(
            image_height, image_width, merged_kwargs
        )


class Cohere2VisionDummyInputsBuilder(
    BaseDummyInputsBuilder[Cohere2VisionProcessingInfo]
):
    """Builds placeholder prompt text and images for Command-A-Vision."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        image_token = self.info.get_hf_processor().image_token
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images of the largest size.

        ``seq_len`` is not used; image overrides come from
        ``mm_options["image"]`` when present.
        """
        largest = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(
            width=largest.width,
            height=largest.height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


class Cohere2VisionMultiModalProcessor(
    BaseMultiModalProcessor[Cohere2VisionProcessingInfo]
):
    """Runs the HF processor and expands image placeholders into tile tokens."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, back-filling ``num_patches`` when it is absent.

        ``num_patches`` is needed to split the flat ``pixel_values`` tensor
        back into per-image chunks.
        """
        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)

        if "num_patches" in outputs:
            return outputs
        images = mm_data.get("images")
        if images is None:
            return outputs

        # Fallback: recompute the per-image patch counts ourselves.
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        parsed = self.info.parse_mm_data({"image": images}, validate=False)
        image_items = parsed.get_items("image", ImageProcessorItems)

        sizes = [image_items.get_image_size(i) for i in range(len(image_items))]
        counts = [
            self.info.get_num_patches(
                image_width=size.width,
                image_height=size.height,
                processor=hf_processor,
                mm_kwargs=mm_kwargs,
            )
            for size in sizes
        ]
        outputs["num_patches"] = torch.tensor(counts)
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto individual images.

        ``pixel_values`` is flat across images and sliced by ``num_patches``;
        the other fields carry one entry per image.
        """
        patch_counts = hf_inputs.get("num_patches", torch.empty(0))
        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes("image", patch_counts),
            "num_patches": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder with its full tile token sequence.

        An image becomes ``boi_token``, then for every tile
        ``image_token * patch_size**2`` followed by the line-break token, and
        finally ``eoi_token``. ``PromptUpdateDetails.select_text`` is given
        ``image_token`` to select the image-token positions.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token
        tokens_per_tile = int(hf_processor.patch_size**2)
        tile_text = image_token * tokens_per_tile + hf_processor.img_line_break_token
        boi, eoi = hf_processor.boi_token, hf_processor.eoi_token

        def get_replacement(item_idx: int):
            # Looked up lazily: only needed when an image is actually present.
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)
            n_tiles = self.info.get_num_patches(
                image_width=size.width,
                image_height=size.height,
                processor=hf_processor,
                mm_kwargs=hf_processor_mm_kwargs,
            )
            replacement = f"{boi}{tile_text * n_tiles}{eoi}"
            return PromptUpdateDetails.select_text(replacement, image_token)

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    Cohere2VisionMultiModalProcessor,
    info=Cohere2VisionProcessingInfo,
    dummy_inputs=Cohere2VisionDummyInputsBuilder,
)
class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Command-A-Vision model.

    A SigLIP vision tower plus a pixel-shuffle projector produce image
    embeddings, and a vLLM-registered text model (built from
    ``config.text_config``) does the language modelling.
    """

    # Maps HF checkpoint weight prefixes onto this module's attribute layout.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Cohere2VisionConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config
        # Patched before the submodules below are built, presumably because
        # quant_config is consulted at layer construction time.
        self._patch_quant_config(config, quant_config)

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = SiglipVisionModel(
                config.vision_config,
                quant_config,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = Cohere2VisionMultiModalProjector(
                config, prefix=maybe_prefix(prefix, "multi_modal_projector")
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=config.text_config.architectures,
            )

    @property
    def dtype(self):
        """Dtype of the first registered parameter of this module."""
        return next(self.parameters()).dtype

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load HF-named weights, renaming them via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def _process_image_input(
        self, image_input: Cohere2VisionImagePixelInputs, **kwargs
    ) -> list[torch.Tensor]:
        """Process image pixels through vision tower and projector.

        Args:
            image_input: Validated image input containing pixel values and
                         patch counts
            **kwargs: Ignored.

        Returns:
            List of flattened image embeddings, one per image
        """
        pixel_values = image_input["pixel_values"]
        num_patches = image_input["num_patches"]

        # Extract visual features
        image_features = self.vision_tower(pixel_values)

        # Project to text embedding space
        image_embeds = self.multi_modal_projector(image_features)

        # Split by per-image patch counts, then flatten each image's
        # (patches, rows, cols) axes into a single token axis.
        return [e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())]

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Cohere2VisionImagePixelInputs | None:
        """Build validated pixel inputs from ``kwargs``.

        Returns:
            ``None`` when no ``pixel_values`` were provided.

        Raises:
            ValueError: If ``image_embeds`` is supplied; only pixel inputs are
                supported.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # Explicit check rather than `assert`, which is stripped under -O.
        if image_embeds is not None:
            raise ValueError("Cohere2Vision does not support image_embeds.")

        if pixel_values is None:
            return None

        return Cohere2VisionImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            num_patches=num_patches,
            resolve_bindings={
                "h": self.config.vision_config.image_size,
                "w": self.config.vision_config.image_size,
            },
        )

    def _patch_quant_config(
        self, config: PretrainedConfig, quant_config: QuantizationConfig | None
    ):
        """Keep the vision tower unquantized for some AWQ checkpoints.

        If the AWQ config lists no ``modules_to_not_convert`` while the text
        config carries its own ``quantization_config``, ``"vision_tower"`` is
        appended to ``modules_to_not_convert`` in place.
        NOTE(review): this treats a text-level ``quantization_config`` as a
        sign that only the language model was quantized; confirm against the
        published checkpoints.
        """
        if isinstance(quant_config, AWQConfig):
            text_config = config.text_config
            llm_quant_config = getattr(text_config, "quantization_config", None)
            if (not quant_config.modules_to_not_convert) and (
                llm_quant_config is not None
            ):
                quant_config.modules_to_not_convert.append("vision_tower")

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return one embedding tensor per image, or ``[]`` without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input, **kwargs)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the text model's backbone (``self.language_model.model``).

        ``inputs_embeds`` is dropped whenever ``intermediate_tensors`` is given
        (presumably a non-first pipeline-parallel rank). Extra ``kwargs`` are
        ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped text model."""
        return self.language_model.compute_logits(hidden_states)