nvlm_d.py 7.44 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
# adapted from https://huggingface.co/nvidia/NVLM-D-72B/blob/main/modeling_nvlm_d.py
# --------------------------------------------------------
# NVLM-D
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
10
11
from collections.abc import Mapping, Sequence
from typing import Optional
12

13
import torch
14
15
16
import torch.nn as nn
from transformers import PretrainedConfig

17
from vllm.model_executor.layers.quantization import QuantizationConfig
18
from vllm.multimodal import MULTIMODAL_REGISTRY
19
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
20
21
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   MultiModalDataItems)
22
23
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
24
25

from .intern_vit import InternVisionModel
26
27
28
29
from .internvl import (BaseInternVLDummyInputsBuilder,
                       BaseInternVLMultiModalProcessor,
                       BaseInternVLProcessingInfo, BaseInternVLProcessor,
                       InternVLChatModel)
30

31
IMG_PAD = "<|vision_pad|>"
32
33


34
class NVLMProcessor(BaseInternVLProcessor):
    """InternVL-style processor that renders NVLM-D image prompt text.

    An image becomes ``<Image>...</Image>``, where every tile is introduced
    by a positional tag (``<tile_1>``, ``<tile_2>``, ..., and optionally
    ``<tile_global_thumbnail>``) followed by a run of ``IMG_PAD`` tokens.
    """

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the ``IMG_PAD`` placeholder token."""
        vocab = self.tokenizer.get_vocab()
        return vocab[IMG_PAD]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the prompt text that stands in for a single image.

        Args:
            feature_size: Image feature count; floor-divided by
                ``num_patches`` to get the number of ``IMG_PAD`` tokens
                emitted after each tile tag.
            num_patches: Number of patches the image was split into.
                ``None`` is rejected.

        Returns:
            The full replacement text, with ``IMG_PAD`` selected as the
            token that receives image features.

        Raises:
            NotImplementedError: If ``num_patches`` is ``None``.
        """
        if num_patches is None:
            raise NotImplementedError("Embedding inputs are not supported")

        # Local tiles are tagged <tile_1> ... <tile_{num_patches - 1}>; the
        # global thumbnail tag is appended when use_thumbnail is enabled.
        # NOTE(review): with use_thumbnail False this produces only
        # num_patches - 1 tags — confirm that configuration is never used.
        tags = [f"<tile_{idx}>" for idx in range(1, num_patches)]
        if self.use_thumbnail:
            tags.append("<tile_global_thumbnail>")

        tokens_per_tile = feature_size // num_patches
        pad_run = IMG_PAD * tokens_per_tile
        body = "".join(f"{tag}{pad_run}" for tag in tags)

        # The <Image> / </Image> markers are part of the replacement because
        # "<Image><tile" tokenizes as ["<Image", "><", "tile"], so "<tile"
        # on its own could not be located as a token subsequence.
        return PromptUpdateDetails.select_text(f"<Image>{body}</Image>",
                                               IMG_PAD)


class NVLMProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for NVLM-D; hands out :class:`NVLMProcessor`."""

    def get_hf_processor(self, **kwargs: object) -> NVLMProcessor:
        """Create an :class:`NVLMProcessor` through the processing context.

        The HF config and tokenizer are filled in from this info object;
        any extra ``kwargs`` are forwarded to ``ctx.init_processor``.
        """
        hf_config = self.get_hf_config()
        tokenizer = self.get_tokenizer()
        return self.ctx.init_processor(NVLMProcessor,
                                       config=hf_config,
                                       tokenizer=tokenizer,
                                       **kwargs)


75
76
class NVLMDummyInputsBuilder(
        BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]):
    """Builds dummy prompt text and dummy images for NVLM-D."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>`` placeholder line per requested image."""
        image_count = mm_counts.get("image", 0)
        # Each placeholder ends with a newline so the closing ">" of one
        # placeholder is never directly followed by the "<" of the next.
        placeholder = "<image>\n"
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images (0 if absent).

        Every image uses the size returned by
        ``info.get_image_size_with_most_features()``. ``seq_len`` is unused.
        """
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
        )
        return {"image": dummy_images}

101

102
103
class NVLMMultiModalProcessor(
        BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):
    """Multimodal processor that expands NVLM-D ``<image>`` placeholders."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement rule for image placeholders.

        Every ``<image>`` followed by a newline is replaced with the tile
        text from ``NVLMProcessor.get_image_repl`` plus a trailing newline.
        The per-image feature count comes from the embedding item itself
        for embedding inputs, and otherwise from
        ``info.get_num_image_tokens`` for the image's width and height.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        mm_data = out_mm_kwargs.get_data()

        # Per-image patch counts; embedding inputs currently carry none
        # (None), which get_image_repl rejects.
        if "image_num_patches" in mm_data:
            patch_counts = mm_data["image_num_patches"]
            assert isinstance(patch_counts, torch.Tensor)
            patch_counts = patch_counts.tolist()
        elif "image_embeds" in mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            patch_counts = [None] * len(mm_data["image_embeds"])
        else:
            patch_counts = []

        def get_replacement_nvlm(item_idx: int):
            # Looked up lazily so prompts without images never touch it.
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_tokens = images.get_feature_size(item_idx)
            else:
                size = images.get_image_size(item_idx)
                num_tokens = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )

            item_patches = patch_counts[item_idx]
            assert item_patches is None or isinstance(item_patches, int)

            text = hf_processor.get_image_repl(num_tokens, item_patches).full
            return PromptUpdateDetails.select_text(text + "\n", IMG_PAD)

        # Both target and replacement end with a newline so that the ">"
        # closing one image block is not tokenized together with the "<"
        # opening the next one (e.g. as "><").
        return [
            PromptReplacement(
                modality="image",
                target="<image>\n",
                replacement=get_replacement_nvlm,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(NVLMMultiModalProcessor,
                                        info=NVLMProcessingInfo,
                                        dummy_inputs=NVLMDummyInputsBuilder)
class NVLM_D_Model(InternVLChatModel):
    """InternVL chat model with NVLM-D's projector and vision tower.

    Registered with ``MULTIMODAL_REGISTRY`` using the NVLM multimodal
    processor, processing info, and dummy-input builder.
    """

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Build the vision-to-language projector.

        Input width is the ViT hidden size times
        ``int(1 / downsample_ratio) ** 2``. Layers: LayerNorm, bias-free
        Linear to the text model's intermediate size, GELU, and bias-free
        Linear to the text model's hidden size.
        """
        vit_width = config.vision_config.hidden_size
        mid_width = config.text_config.intermediate_size
        out_width = config.text_config.hidden_size
        in_width = vit_width * int(1 / self.downsample_ratio)**2

        layers = [
            nn.LayerNorm(in_width),
            nn.Linear(in_width, mid_width, bias=False),
            nn.GELU(),
            nn.Linear(mid_width, out_width, bias=False),
        ]
        return nn.Sequential(*layers)

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the InternViT vision tower ending at ``config.select_layer``.

        A negative ``select_layer`` counts back from the last layer (-1 keeps
        all ``vision_config.num_hidden_layers``); a non-negative value keeps
        ``select_layer + 1`` layers.

        Raises:
            NotImplementedError: If ``is_mono`` is truthy (monolith mode).
        """
        if is_mono:
            raise NotImplementedError(
                "Monolith mode is not applicable to NVLM_D")

        select_layer = config.select_layer
        if select_layer >= 0:
            kept_layers = select_layer + 1
        else:
            kept_layers = (config.vision_config.num_hidden_layers +
                           select_layer + 1)

        # Dummy heads are added to the original head count so that the
        # number of heads becomes divisible by 8.
        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=kept_layers,
            num_dummy_heads=7,
            prefix=prefix,
        )