# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/ovis/modeling_ovis.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Ovis model."""

import math
22
from collections.abc import Iterable, Mapping
23
from typing import Annotated, Literal
24
25
26
27

import torch
import torch.nn as nn
from torch import Tensor
28
from torch.nn.functional import gumbel_softmax, pad, softmax
29
from transformers import BatchFeature, PretrainedConfig
30
31

from vllm.config import VllmConfig
32
from vllm.config.multimodal import BaseDummyOptions
33
from vllm.model_executor.layers.linear import ReplicatedLinear
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
36
from vllm.model_executor.models.aimv2 import AIMv2Model
from vllm.model_executor.models.siglip import SiglipVisionModel
37
38
39
40
41
42
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
43
from vllm.multimodal import MULTIMODAL_REGISTRY
44
45
46
47
48
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
49
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
50
51
52
53
54
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
)
55
56
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
57
from vllm.transformers_utils.processors.ovis import OvisProcessor
58
from vllm.utils.tensor_schema import TensorSchema, TensorShape
59

60
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
61
62
63

# The values below are hard-coded here because they cannot be looked up in
# the HF config.

# Prompt placeholder string that stands for one image.
IMAGE_TOKEN = "<image>"

# Negative ids used for the image indicator tokens; the last
# `len(IMAGE_INDICATOR_IDS)` slots of the visual vocabulary are reserved
# for them.
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]

# Image pad token string, keyed by the text model's `model_type`.
IMAGE_PAD_TOKEN_MAP = {
    "gemma2": "<unused0>",
    "llama": "<|reserved_special_token_0|>",
    "qwen2": "<|image_pad|>",
}

# Image pad token id, keyed by the text model's `model_type`.
IMAGE_PAD_TOKEN_ID_MAP = {
    "gemma2": 7,
    "llama": 128002,
    "qwen2": 151655,
}
76

77
78
79
80
81
82
83
84
85
86
87
88

def st_argmax(y_soft: torch.Tensor, dim: int):
    """Return a one-hot encoding of the argmax of ``y_soft`` along ``dim``.

    The output is a new tensor shaped like ``y_soft`` that holds 1.0 at the
    position of the maximum along ``dim`` and 0 everywhere else.
    """
    winners = y_soft.argmax(dim, keepdim=True)
    one_hot = torch.zeros_like(
        y_soft,
        memory_format=torch.legacy_contiguous_format,
    )
    # `scatter_` writes in place and returns the same tensor.
    one_hot.scatter_(dim, winners, 1.0)
    return one_hot


class VisualTokenizer(torch.nn.Module):
    """Map image pixels to token weights over the visual vocabulary.

    The module runs a vision backbone, optionally merges neighbouring
    patch features (``hidden_stride``), projects them with a linear +
    LayerNorm head and converts the logits to tokens with the function
    selected by ``config.tokenize_function``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.backbone = self._init_backbone(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.backbone",
        )
        # The last `len(IMAGE_INDICATOR_IDS)` vocabulary slots are reserved
        # for the image indicators, so the head does not predict them.
        head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
        self.head = torch.nn.Sequential(
            ReplicatedLinear(
                config.backbone_config.hidden_size
                * config.hidden_stride
                * config.hidden_stride,
                head_dim,
                bias=False,
                return_bias=False,
            ),
            torch.nn.LayerNorm(head_dim),
        )

    def _init_backbone(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> nn.Module:
        """Build the vision backbone named by ``backbone_config.model_type``.

        Raises:
            ValueError: if the model type is neither ``"aimv2"`` nor
                ``"siglip_vision_model"``.
        """
        model_type = config.backbone_config.model_type
        if model_type == "aimv2":
            # No post rms_norm in Ovis2's AIMv2 ViT.
            return AIMv2Model(
                config=config.backbone_config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=prefix,
            )
        elif model_type == "siglip_vision_model":
            return SiglipVisionModel(
                config=config.backbone_config,
                quant_config=quant_config,
                prefix=prefix,
            )
        raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first parameter of the head."""
        return next(self.head.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the head."""
        return next(self.head.parameters()).device

    def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
        """Convert head logits to tokens using ``config.tokenize_function``.

        Raises:
            ValueError: if ``config.tokenize_function`` is not one of
                ``"softmax"``, ``"gumbel_argmax"`` or ``"st_argmax"``.
        """
        if self.config.tokenize_function == "softmax":
            tokens = softmax(logits, dim=-1)
        elif self.config.tokenize_function == "gumbel_argmax":
            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
        elif self.config.tokenize_function == "st_argmax":
            tokens = st_argmax(logits, dim=-1)
        else:
            # The option being validated is `tokenize_function`; the message
            # names it so the error points at the right config field.
            raise ValueError(
                "Invalid `tokenize_function`, expected softmax or "
                "gumbel_argmax or st_argmax, but got "
                f"{self.config.tokenize_function}"
            )
        return tokens

    def encode(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the backbone and merge ``hidden_stride`` x ``hidden_stride``
        neighbouring features into one token."""
        features = self.backbone(pixel_values)
        if self.config.drop_cls_token:
            features = features[:, 1:, :]

        # merge number of `hidden_stride * hidden_stride` hidden states together
        # to reduce token sequence length
        # e.g., for hidden_stride=2, this leads to a token length reduction:
        # 1024 -> 256 for aimv2
        hs = self.config.hidden_stride
        if hs > 1:
            n, num_tokens, d = features.shape
            sqrt_l = int(num_tokens**0.5)
            assert sqrt_l**2 == num_tokens, (
                "The token sequence length should be a perfect square."
            )
            features = features.reshape(n, sqrt_l, sqrt_l, d)
            # Zero-pad the bottom/right of the grid so that its side length
            # becomes a multiple of the stride.
            pl = (hs - (sqrt_l % hs)) % hs
            features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
            sqrt_l += pl
            # [n, sqrt_l/hs, hs, sqrt_l/hs, hs, d]
            features = features.reshape(
                n,
                sqrt_l // hs,
                hs,
                sqrt_l // hs,
                hs,
                d,
            )
            # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
            features = features.permute(0, 1, 3, 2, 4, 5)
            # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
            features = features.flatten(3)
            # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
            features = features.reshape(n, -1, hs * hs * d)

        return features

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
        features = self.encode(pixel_values)
        logits = self.head(features)
        tokens = self.tokenize(logits)
        # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with
        # [BatchSize, #Token, 5], after which, tokens' shape should become
        # [BatchSize, #Token, VocabSize]
        tokens = torch.nn.functional.pad(
            tokens,
            (0, len(IMAGE_INDICATOR_IDS)),
            mode="constant",
            value=0,
        )
        return tokens


217
class OvisImagePatchInputs(TensorSchema):
    """
    Dimensions:
        - bnp: Batch size * number of images * number of patches
        - h: Height of each patch
        - w: Width of each patch
        - patch_indicators: Total number of indicator tokens over all
          images (per image: number of patches + 1 when the image has more
          than one patch, otherwise number of patches + 2)
        - bn: Batch size * number of images
    """

    type: Literal["image_patches"]
    flat_data: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
    # Number of patches of each image; this is used to restore the first two
    # dimensions of `flat_data`.
    patches_per_image: Annotated[list[int], TensorShape("bn")]
232
233
234
235
236
237
238
239


class VisualEmbedding(torch.nn.Embedding):
    """Embedding table that accepts token ids or token weight vectors.

    Integer inputs are looked up like a regular ``nn.Embedding``; any other
    dtype is treated as weights over the vocabulary and multiplied with the
    embedding matrix.
    """

    # Dtypes routed to the index lookup. `torch.long` is an alias of
    # `torch.int64`, so it does not need a separate entry.
    _INDEX_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)

    def forward(self, visual_tokens: Tensor) -> Tensor:
        """Embed ``visual_tokens``.

        Args:
            visual_tokens: either integer indices into the table, or a
                tensor whose last dimension matches the number of
                embeddings (it is matrix-multiplied with ``self.weight``).
        """
        if visual_tokens.dtype in self._INDEX_DTYPES:
            return super().forward(visual_tokens)
        return torch.matmul(visual_tokens, self.weight)

    @property
    def device(self) -> torch.device:
        """Device of the embedding weight."""
        return self.weight.device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the embedding weight."""
        return self.weight.dtype


258
class OvisProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Ovis derived from the HF config/processor."""

    def get_hf_processor(self, **kwargs: object):
        """Return an `OvisProcessor` configured with the image pad token
        and the image segment length of this model."""
        return self.ctx.get_hf_processor(
            OvisProcessor,
            image_pad_token=self.get_image_pad_token(),
            image_segment_len=self.get_image_segment_len(),
            **kwargs,
        )

    def get_image_segment_len(self) -> int:
        """Compute the image segment length from the visual tokenizer config.

        The patch grid side is ``ceil(image_size / patch_size)``; it must be
        divisible by ``hidden_stride``.
        """
        visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config
        image_size = visual_tokenizer_config.backbone_config.image_size
        patch_size = visual_tokenizer_config.backbone_config.patch_size
        hidden_stride = visual_tokenizer_config.hidden_stride
        patch_grid_length = math.ceil(image_size / patch_size)
        assert patch_grid_length % hidden_stride == 0, (
            f"patch_grid_length {patch_grid_length} is not divisible by "
            f"hidden_stride {hidden_stride}"
        )
        # minus 1 for presented image token
        return (patch_grid_length // hidden_stride) ** 2 - 1

    def get_image_pad_token(self) -> str | None:
        """Return the image pad token for the text model type.

        The lookup uses ``dict.get``, so the result is ``None`` when the
        text model type has no entry in `IMAGE_PAD_TOKEN_MAP`.
        """
        hf_text_config = self.get_hf_config().get_text_config()
        text_model_type = hf_text_config.model_type
        return IMAGE_PAD_TOKEN_MAP.get(text_model_type)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the processor's image size scaled by ``hidden_stride * 9``."""
        height, width = self.get_hf_processor().get_image_size()
        hs = self.get_hf_config().visual_tokenizer_config.hidden_stride
        # NOTE(Isotr0py): 9 is `max_partition` hardcoded in original code
        # https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/modeling_ovis.py#L96
        return ImageSize(width=width * hs * 9, height=height * hs * 9)
294
295


296
class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]):
    """Builds dummy text and image inputs for Ovis."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one `IMAGE_TOKEN` per requested image."""
        num_images = mm_counts.get("image", 0)
        return IMAGE_TOKEN * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images at the size reported by
        ``get_image_size_with_most_features``.

        Args:
            seq_len: not used by this builder.
            mm_counts: number of items per modality; only ``"image"`` is
                read (default 0).
            mm_options: optional per-modality overrides; the ``"image"``
                entry is forwarded to ``_get_dummy_images``.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }


324
class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
    """Multimodal processor for Ovis image inputs."""

    def image_indicators_to_visual_tokens(
        self,
        image_indicators: list[int],
    ) -> list[int]:
        """
        Filter image indicators placeholders and convert them to corresponding
        tokens in visual tokenizer.
        For example, [-301, -300, -302, -300, -303, -300, -304, -300, -305]
        should return [vocab_size-1, vocab_size-2, ..., vocab_size-5]
        """
        hf_config = self.info.get_hf_config()
        vte_vocab_size = hf_config.visual_tokenizer_config.vocab_size
        # -300 is image_atom token, filter them out
        return [vte_vocab_size + x + 300 for x in image_indicators if x < -300]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach ``indicator_tokens`` to its output.

        Text-only prompts are tokenized directly without calling the HF
        processor.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            tokenizer = self.info.get_tokenizer()
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        hf_processor = self.info.get_hf_processor()
        # One indicator list per image grid, mapped into the visual vocab.
        indicator_tokens = [
            self.image_indicators_to_visual_tokens(
                hf_processor.construct_image_indicators(grid)
            )
            for grid in processed_outputs["grids"]
        ]
        # NOTE(review): `torch.tensor` needs every per-image list to have the
        # same length -- confirm this holds when grids differ in one call.
        processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens)
        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Return the prompt tokens unchanged."""
        return prompt_tokens

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare all image fields as batched over the ``"image"`` modality."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            grids=MultiModalFieldConfig.batched("image"),
            indicator_tokens=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptReplacement]:
        """Replace each `IMAGE_TOKEN` with the placeholders built by the HF
        processor for that image's grid."""

        def get_replacement_ovis(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            grid = out_item["grids"].data

            hf_processor = self.info.get_hf_processor()
            return hf_processor.construct_image_placeholders(grid)

        return [
            PromptReplacement(
                modality="image",
                target=IMAGE_TOKEN,
                replacement=get_replacement_ovis,
            ),
        ]


411
412
413
414
415
@MULTIMODAL_REGISTRY.register_processor(
    OvisMultiModalProcessor,
    info=OvisProcessingInfo,
    dummy_inputs=OvisDummyInputsBuilder,
)
class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
    """Ovis model: a language model (`llm`) combined with a visual
    tokenizer and a visual embedding table (`vte`)."""

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``.

        Raises:
            ValueError: if the modality does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config: PretrainedConfig = config
        self.llm = init_vllm_registered_model(
            vllm_config=vllm_config.with_hf_config(config.get_text_config()),
            prefix=maybe_prefix(prefix, "llm"),
        )

        self.visual_tokenizer = VisualTokenizer(
            config=config.visual_tokenizer_config,
            quant_config=quant_config,
            prefix=f"{prefix}.visual_tokenizer",
        )

        self.vte = VisualEmbedding(
            self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size
        )

        text_model_type = self.config.get_text_config().model_type
        self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]

        self.make_empty_intermediate_tensors = (
            self.get_language_model().make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> OvisImagePatchInputs | None:
        """Build `OvisImagePatchInputs` from ``pixel_values`` and
        ``indicator_tokens`` in ``kwargs``.

        Returns ``None`` when both are absent.

        Raises:
            ValueError: if either value is neither a tensor nor a list.
            AssertionError: if exactly one of the two values is provided.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        indicator_tokens = kwargs.pop("indicator_tokens", None)

        if pixel_values is None and indicator_tokens is None:
            return None

        if pixel_values is not None and indicator_tokens is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
                )

            if not isinstance(indicator_tokens, (torch.Tensor, list)):
                # Report the type of the value that actually failed the check.
                raise ValueError(
                    "Incorrect type of indicator_tokens. "
                    f"Got type: {type(indicator_tokens)}"
                )

            return OvisImagePatchInputs(
                type="image_patches",
                flat_data=flatten_bn(pixel_values, concat=True),
                patches_per_image=[x.shape[0] for x in pixel_values],
                indicator_tokens=flatten_bn(indicator_tokens, concat=True),
            )

        # Reached when only one of the two inputs was supplied.
        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self, image_input: OvisImagePatchInputs
    ) -> MultiModalEmbeddings:
        """Embed image patches and interleave them with indicator embeddings.

        For every image, each patch's embeddings are preceded by one
        indicator embedding, and the remaining indicator embeddings are
        appended at the end. Returns one tensor per image.
        """
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]
        indicator_tokens = image_input["indicator_tokens"]

        # Number of indicator tokens that belong to each image.
        indicator_per_image = [
            num_patches + 1 if num_patches > 1 else num_patches + 2
            for num_patches in patches_per_image
        ]

        target_dtype = self.visual_tokenizer.dtype
        visual_tokens = self.visual_tokenizer(image_patches_flat.to(target_dtype))

        visual_embeds = self.vte(visual_tokens)  # 1:1 numeric eq.
        indicator_embeds = self.vte(indicator_tokens)

        indicator_embeds_per_image = indicator_embeds.split(indicator_per_image)
        visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0)

        vision_embeddings = []
        for indicator, visual in zip(
            indicator_embeds_per_image, visual_embeds_per_image
        ):
            # Use the patch count explicitly rather than the loop variable
            # after the loop, which is undefined (or stale from a previous
            # image) when an image has no patches.
            num_patches = visual.shape[0]
            vision_embeddings_per_image = [
                torch.cat([indicator[i : i + 1], visual[i]], dim=0)
                for i in range(num_patches)
            ]
            vision_embeddings_per_image.append(indicator[num_patches:])
            vision_embeddings.append(torch.cat(vision_embeddings_per_image, dim=0))

        return tuple(vision_embeddings)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or an empty list without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model.

        ``inputs_embeds`` is ignored when ``intermediate_tensors`` is given.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # up until here we have an inputs_embeds 100% numerical identity
        # between the OG HF Transformers implementation and ours
        hidden_states = self.llm(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        logits = self.llm.compute_logits(hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights with `AutoWeightsLoader` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.llm