# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/ovis/modeling_ovis.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Ovis model."""

import math
from collections.abc import Iterable, Mapping
from typing import Annotated, Literal

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import gumbel_softmax, pad, softmax
from transformers import BatchFeature, PretrainedConfig

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.aimv2 import AIMv2Model
from vllm.model_executor.models.siglip import SiglipVisionModel
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

# Placeholder text for one image in the prompt. The multimodal processor
# replaces each occurrence with that image's placeholder token sequence.
IMAGE_TOKEN = "<image>"

# Image indicator placeholder ids. These numbers cannot be found in the HF
# config, so they are hard-coded here. Indicator -301 maps to visual-vocab
# row `vocab_size - 1`, -302 to `vocab_size - 2`, ..., -305 to
# `vocab_size - 5`. The image atom id -300 is not an indicator and is
# filtered out before embedding.
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]

# Image pad token of the text tokenizer (as a string and as a token id), keyed
# by the text model's `model_type`. Both maps must have the same keys.
IMAGE_PAD_TOKEN_MAP = {
    "gemma2": "<unused0>",
    "llama": "<|reserved_special_token_0|>",
    "qwen2": "<|image_pad|>",
}
IMAGE_PAD_TOKEN_ID_MAP = {
    "gemma2": 7,
    "llama": 128002,
    "qwen2": 151655,
}


def st_argmax(y_soft: torch.Tensor, dim: int):
    """Return a hard one-hot tensor marking the maximum of ``y_soft`` along ``dim``.

    Only the hard (forward) value of a straight-through argmax is produced.
    No soft term is added back, so no gradient flows into ``y_soft``.
    The result has the same shape and dtype as ``y_soft``.
    """
    winners = torch.argmax(y_soft, dim=dim, keepdim=True)
    one_hot = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format)
    one_hot.scatter_(dim, winners, 1.0)
    return one_hot


class VisualTokenizer(torch.nn.Module):
    """Map images to per-token distributions over the visual vocabulary.

    A vision backbone (AIMv2 or SigLIP) encodes the pixels. Neighbouring
    patch features are merged ``hidden_stride x hidden_stride`` at a time,
    and a linear head followed by LayerNorm produces logits. Those logits
    are turned into "visual tokens" as selected by
    ``config.tokenize_function``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.backbone = self._init_backbone(
            config=config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "backbone"),
        )
        # The last len(IMAGE_INDICATOR_IDS) vocabulary entries are reserved
        # for the image indicators. The head does not predict them; `forward`
        # re-appends them as zero columns.
        head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
        self.head = torch.nn.Sequential(
            ReplicatedLinear(
                config.backbone_config.hidden_size
                * config.hidden_stride
                * config.hidden_stride,
                head_dim,
                bias=False,
                return_bias=False,
            ),
            torch.nn.LayerNorm(head_dim),
        )

    def _init_backbone(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> nn.Module:
        """Build the vision backbone selected by ``backbone_config.model_type``.

        Raises:
            ValueError: if the backbone model type is not supported.
        """
        model_type = config.backbone_config.model_type
        if model_type == "aimv2":
            # No post rms_norm in Ovis2's AIMv2 ViT.
            return AIMv2Model(
                config=config.backbone_config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=prefix,
            )
        elif model_type == "siglip_vision_model":
            return SiglipVisionModel(
                config=config.backbone_config,
                quant_config=quant_config,
                prefix=prefix,
            )
        raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the head parameters (used to cast pixel inputs)."""
        return next(self.head.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device of the head parameters."""
        return next(self.head.parameters()).device

    def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
        """Convert head logits into visual tokens along the last dimension.

        ``softmax`` yields probabilities. ``gumbel_argmax`` and ``st_argmax``
        yield one-hot vectors.

        Raises:
            ValueError: if ``config.tokenize_function`` is not one of the above.
        """
        if self.config.tokenize_function == "softmax":
            tokens = softmax(logits, dim=-1)
        elif self.config.tokenize_function == "gumbel_argmax":
            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
        elif self.config.tokenize_function == "st_argmax":
            tokens = st_argmax(logits, dim=-1)
        else:
            raise ValueError(
                "Invalid `tokenize_function`, expected softmax or gumbel_argmax "
                f"or st_argmax, but got {self.config.tokenize_function}"
            )
        return tokens

    def encode(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the backbone and merge neighbouring patch features.

        When ``hidden_stride`` (hs) > 1, the backbone output ``[n, L, d]`` is
        treated as a ``sqrt(L) x sqrt(L)`` grid. The grid is zero-padded to a
        multiple of hs, and each hs x hs window is concatenated, giving
        ``[n, (sqrt_l/hs)**2, hs*hs*d]``.
        """
        features = self.backbone(pixel_values)
        if self.config.drop_cls_token:
            features = features[:, 1:, :]

        # Merge `hidden_stride * hidden_stride` hidden states together to
        # reduce the token sequence length. For example, hidden_stride=2
        # reduces 1024 -> 256 tokens for aimv2.
        hs = self.config.hidden_stride
        if hs > 1:
            n, L, d = features.shape
            sqrt_l = int(L**0.5)
            assert sqrt_l**2 == L, (
                "The token sequence length should be a perfect square."
            )
            features = features.reshape(n, sqrt_l, sqrt_l, d)
            # Pad the grid on the bottom/right so its side is a multiple of hs.
            pl = (hs - (sqrt_l % hs)) % hs
            features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
            sqrt_l += pl
            features = features.reshape(n, sqrt_l // hs, hs, sqrt_l // hs, hs, d)
            # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
            features = features.permute(0, 1, 3, 2, 4, 5)
            # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
            features = features.flatten(3)
            # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
            features = features.reshape(n, -1, hs * hs * d)

        return features

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
        features = self.encode(pixel_values)
        logits = self.head(features)
        tokens = self.tokenize(logits)
        # tokens' shape is [BatchSize, #Token, VocabSize-5]. Append zero
        # columns for the 5 reserved indicator entries, which gives
        # [BatchSize, #Token, VocabSize].
        return pad(
            tokens,
            (0, len(IMAGE_INDICATOR_IDS)),
            mode="constant",
            value=0,
        )


class OvisImagePatchInputs(TensorSchema):
    """
    Flattened image patches plus the visual-vocab ids of the image indicators.

    Dimensions:
        - bnp: Batch size * number of images * number of patches
        - h: Height of each patch
        - w: Width of each patch
        - patch_indicators: Total number of indicator tokens over all images
          (per image: number of patches + 1, or 3 for a single-patch image)
        - bn: Batch size * number of images
    """

    type: Literal["image_patches"]
    flat_data: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
    # Number of patches per image. This is used to restore the first two
    # dimensions of `flat_data`.
    patches_per_image: Annotated[list[int], TensorShape("bn")]


class VisualEmbedding(torch.nn.Embedding):
    """Embedding table that accepts either token ids or token distributions.

    Integer inputs are looked up like a regular ``nn.Embedding``. Any other
    input, such as the probability or one-hot vectors produced by the visual
    tokenizer, is treated as weights over the vocabulary and multiplied with
    the embedding matrix.
    """

    # Index dtypes that go through the regular embedding lookup
    # (torch.long is an alias of torch.int64).
    _INDEX_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)

    def forward(self, visual_tokens: Tensor) -> Tensor:
        if visual_tokens.dtype in self._INDEX_DTYPES:
            return super().forward(visual_tokens)
        # [..., num_embeddings] @ [num_embeddings, embedding_dim]
        #   -> [..., embedding_dim]
        return torch.matmul(visual_tokens, self.weight)

    @property
    def device(self) -> torch.device:
        """Device of the embedding weight."""
        return self.weight.device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the embedding weight."""
        return self.weight.dtype


class OvisProcessingInfo(BaseProcessingInfo):
    """Processor construction and size/limit metadata for Ovis models."""

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(
            OvisProcessor,
            image_pad_token=self.get_image_pad_token(),
            image_segment_len=self.get_image_segment_len(),
            **kwargs,
        )

    def get_default_tok_params(self) -> TokenizeParams:
        # Prompts for this model are tokenized without special tokens.
        return super().get_default_tok_params().with_kwargs(add_special_tokens=False)

    def get_image_segment_len(self) -> int:
        """Return the number of visual tokens per patch, minus one.

        Raises:
            ValueError: if the backbone's patch grid side is not divisible
                by ``hidden_stride``.
        """
        visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config
        backbone_config = visual_tokenizer_config.backbone_config
        hidden_stride = visual_tokenizer_config.hidden_stride
        patch_grid_length = math.ceil(
            backbone_config.image_size / backbone_config.patch_size
        )
        if patch_grid_length % hidden_stride != 0:
            raise ValueError(
                f"patch_grid_length {patch_grid_length} is not divisible by "
                f"hidden_stride {hidden_stride}"
            )
        # minus 1 for the presented image token
        return (patch_grid_length // hidden_stride) ** 2 - 1

    def get_image_pad_token(self) -> str | None:
        """Return the image pad token for the text model, or None if unknown."""
        hf_text_config = self.get_hf_config().get_text_config()
        text_model_type = hf_text_config.model_type
        return IMAGE_PAD_TOKEN_MAP.get(text_model_type)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No fixed limit on the number of images per prompt.
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        height, width = self.get_hf_processor().get_image_size()
        hs = self.get_hf_config().visual_tokenizer_config.hidden_stride
        # NOTE(Isotr0py): 9 is `max_partition` hardcoded in original code
        # https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/modeling_ovis.py#L96
        max_partition = 9
        return ImageSize(
            width=width * hs * max_partition,
            height=height * hs * max_partition,
        )


class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]):
    """Dummy Ovis inputs: one ``<image>`` per image, at the largest image size."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        return IMAGE_TOKEN * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        # Use the image size that yields the most visual features.
        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        mm_data = {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }
        return mm_data


class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
    """Run the Ovis HF processor and derive per-image indicator tokens."""

    def image_indicators_to_visual_tokens(
        self,
        image_indicators: list[int],
    ) -> list[int]:
        """
        Filter image indicators placeholders and convert them to corresponding
        tokens in visual tokenizer.
        For example, [-301, -300, -302, -300, -303, -300, -304, -300, -305]
        should return [vocab_size-1, vocab_size-2, ..., vocab_size-5]
        """
        hf_config = self.info.get_hf_config()
        vte_vocab_size = hf_config.visual_tokenizer_config.vocab_size
        # -300 is image_atom token, filter them out
        return [vte_vocab_size + x + 300 for x in image_indicators if x < -300]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            tokenizer = self.info.get_tokenizer()
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        hf_processor = self.info.get_hf_processor()
        # The number of indicators depends on each image's grid, so images in
        # one prompt can yield sequences of different lengths. Calling
        # `torch.tensor` on such a ragged list raises. Keep one 1-D tensor per
        # image instead; the batched field config splits per image either way.
        processed_outputs["indicator_tokens"] = [
            torch.tensor(
                self.image_indicators_to_visual_tokens(
                    hf_processor.construct_image_indicators(grid)
                ),
                dtype=torch.long,
            )
            for grid in processed_outputs["grids"]
        ]
        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        # Token-only prompts are passed through unchanged.
        return prompt_tokens

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            grids=MultiModalFieldConfig.batched("image"),
            indicator_tokens=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptReplacement]:
        def get_replacement_ovis(item_idx: int):
            # Expand one IMAGE_TOKEN into the placeholder sequence for the
            # grid of image `item_idx`.
            out_item = out_mm_kwargs["image"][item_idx]
            grid = out_item["grids"].data

            hf_processor = self.info.get_hf_processor()
            return hf_processor.construct_image_placeholders(grid)

        return [
            PromptReplacement(
                modality="image",
                target=IMAGE_TOKEN,
                replacement=get_replacement_ovis,
            ),
        ]


@MULTIMODAL_REGISTRY.register_processor(
    OvisMultiModalProcessor,
    info=OvisProcessingInfo,
    dummy_inputs=OvisDummyInputsBuilder,
)
class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
    """Ovis: a visual tokenizer and visual embedding table (`vte`) that
    produce image embeddings for a wrapped text LLM (`self.llm`)."""

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return IMAGE_TOKEN

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config: PretrainedConfig = config

        with self._mark_language_model(vllm_config):
            self.llm = init_vllm_registered_model(
                vllm_config=vllm_config.with_hf_config(config.get_text_config()),
                prefix=maybe_prefix(prefix, "llm"),
            )

        with self._mark_tower_model(vllm_config, "image"):
            self.visual_tokenizer = VisualTokenizer(
                config=config.visual_tokenizer_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual_tokenizer"),
            )
            self.vte = VisualEmbedding(
                self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size
            )

        text_model_type = self.config.get_text_config().model_type
        self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]

        self.make_empty_intermediate_tensors = (
            self.get_language_model().make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> OvisImagePatchInputs | None:
        """Pop image kwargs and pack them into `OvisImagePatchInputs`.

        Returns None when neither ``pixel_values`` nor ``indicator_tokens``
        is given.

        Raises:
            ValueError: if only one of the two is given, or either one has
                an unexpected type.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        indicator_tokens = kwargs.pop("indicator_tokens", None)

        if pixel_values is None and indicator_tokens is None:
            return None

        if pixel_values is None or indicator_tokens is None:
            raise ValueError(
                "`pixel_values` and `indicator_tokens` must be provided together."
            )

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
            )

        if not isinstance(indicator_tokens, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of indicator_tokens. "
                f"Got type: {type(indicator_tokens)}"
            )

        return OvisImagePatchInputs(
            type="image_patches",
            flat_data=flatten_bn(pixel_values, concat=True),
            patches_per_image=[x.shape[0] for x in pixel_values],
            indicator_tokens=flatten_bn(indicator_tokens, concat=True),
        )

    def _process_image_input(
        self, image_input: OvisImagePatchInputs
    ) -> MultiModalEmbeddings:
        """Embed patches and indicators, then interleave them per image.

        For each image the output is
        ``[ind_0, patch_0, ind_1, patch_1, ..., patch_{k-1}, ind_k, ...]``,
        where any indicators left over after the last patch are appended.
        """
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]
        indicator_tokens = image_input["indicator_tokens"]

        # Indicators per image: patches + 1, or patches + 2 for one patch.
        indicator_per_image = [
            n + 1 if n > 1 else n + 2 for n in patches_per_image
        ]

        target_dtype = self.visual_tokenizer.dtype
        visual_tokens = self.visual_tokenizer(image_patches_flat.to(target_dtype))
        visual_embeds = self.vte(visual_tokens)  # 1:1 numeric eq.

        indicator_embeds = self.vte(indicator_tokens)
        indicator_embeds_per_image = indicator_embeds.split(indicator_per_image)

        visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0)
        vision_embeddings = []
        for indicator, visual in zip(
            indicator_embeds_per_image, visual_embeds_per_image, strict=True
        ):
            num_patches = visual.shape[0]
            parts = []
            for i in range(num_patches):
                parts.append(indicator[i : i + 1])
                parts.append(visual[i])
            # Remaining indicators follow the last patch. An explicit count is
            # used instead of the loop variable, so zero patches are handled.
            parts.append(indicator[num_patches:])
            vision_embeddings.append(torch.cat(parts, dim=0))

        return tuple(vision_embeddings)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        image_features = self._process_image_input(image_input)

        return image_features

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        # Non-first pipeline stages receive intermediate tensors instead of
        # embeddings.
        if intermediate_tensors is not None:
            inputs_embeds = None

        # up until here we have an inputs_embeds 100% numerical identity
        # between the OG HF Transformers implementation and ours
        hidden_states = self.llm(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.llm.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)