# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://huggingface.co/nvidia/NVLM-D-72B/blob/main/modeling_nvlm_d.py
# --------------------------------------------------------
# NVLM-D
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Mapping, Sequence
11

12
import torch
13
14
15
import torch.nn as nn
from transformers import PretrainedConfig

16
from vllm.config.multimodal import BaseDummyOptions
17
from vllm.model_executor.layers.quantization import QuantizationConfig
18
from vllm.multimodal import MULTIMODAL_REGISTRY
19
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
20
21
22
23
24
25
26
27
28
29
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
30
31

from .intern_vit import InternVisionModel
32
33
34
35
36
37
38
from .internvl import (
    BaseInternVLDummyInputsBuilder,
    BaseInternVLMultiModalProcessor,
    BaseInternVLProcessingInfo,
    BaseInternVLProcessor,
    InternVLChatModel,
)
39

40
# Placeholder token repeated once per image-feature position in the prompt.
# NVLMProcessor resolves its id from the tokenizer vocabulary, and the prompt
# replacements select only these tokens as embedding positions.
IMG_PAD = "<|vision_pad|>"
41
42


43
44
45
46
class NVLMProcessor(BaseInternVLProcessor):
    """InternVL-style processor that renders NVLM-D's tile-tagged image prompt."""

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the ``IMG_PAD`` placeholder token."""
        vocab = self.tokenizer.get_vocab()
        return vocab[IMG_PAD]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: int | None,
    ) -> PromptUpdateDetails[str]:
        """Build the text that replaces a single image in the prompt.

        Args:
            feature_size: Total number of placeholder tokens for the image.
            num_patches: Number of patches the image was split into, or
                ``None`` for precomputed-embedding inputs.

        Returns:
            ``<Image>`` + (tile tag + run of ``IMG_PAD``) per tag + ``</Image>``,
            with only the ``IMG_PAD`` tokens selected as embedding positions.

        Raises:
            NotImplementedError: If ``num_patches`` is ``None``.
        """
        if num_patches is None:
            raise NotImplementedError("Embedding inputs are not supported")

        # Tags <tile_1> .. <tile_{num_patches - 1}>, plus a trailing
        # thumbnail tag when the processor uses thumbnails.
        # NOTE(review): with use_thumbnail False only num_patches - 1 tags are
        # produced; confirm that configuration is never used with NVLM-D.
        tags = [f"<tile_{idx}>" for idx in range(1, num_patches)]
        if self.use_thumbnail:
            tags.append("<tile_global_thumbnail>")

        # Every tag is followed by an equal share of the placeholder tokens
        # (integer division; presumably feature_size is a multiple of
        # num_patches — verify upstream).
        tokens_per_tag = feature_size // num_patches
        pad_run = IMG_PAD * tokens_per_tag
        body = "".join(tag + pad_run for tag in tags)

        # The <Image>/</Image> wrappers are part of the replacement because
        # "<Image><tile" tokenizes as ["<Image", "><", "tile"]; searching for
        # "<tile" alone as a subsequence would otherwise fail an assertion.
        wrapped = "<Image>" + body + "</Image>"

        return PromptUpdateDetails.select_text(wrapped, IMG_PAD)
71
72
73


class NVLMProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that builds an ``NVLMProcessor`` for NVLM-D."""

    def get_hf_processor(self, **kwargs: object) -> NVLMProcessor:
        """Instantiate ``NVLMProcessor`` through the processing context.

        The model's HF config and tokenizer are supplied automatically; any
        extra keyword arguments are forwarded to ``init_processor`` unchanged.
        """
        hf_config = self.get_hf_config()
        tokenizer = self.get_tokenizer()
        return self.ctx.init_processor(
            NVLMProcessor,
            config=hf_config,
            tokenizer=tokenizer,
            **kwargs,
        )


83
class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]):
    """Produces dummy prompt text and images for NVLM-D profiling runs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>`` placeholder, each followed by a newline, per image."""
        image_count = mm_counts.get("image", 0)

        # The trailing newline keeps the ">" that ends one placeholder apart
        # from the "<" that starts the next one.
        return image_count * "<image>\n"

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized to yield the most image features.

        ``seq_len`` and ``mm_processor_kwargs`` are accepted for interface
        compatibility but are not used here.
        """
        width, height = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)
        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}

112

113
class NVLMMultiModalProcessor(BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):
    """Multimodal processor that expands ``<image>`` placeholders for NVLM-D."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the prompt replacement for every ``<image>`` + newline target.

        Per-image patch counts come from ``image_num_patches`` when present.
        For ``image_embeds`` inputs they are unknown and recorded as ``None``.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        mm_data = out_mm_kwargs.get_data()

        if "image_num_patches" in mm_data:
            counts_tensor = mm_data["image_num_patches"]
            assert isinstance(counts_tensor, torch.Tensor)
            patch_counts = counts_tensor.tolist()
        elif "image_embeds" in mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            patch_counts = [None] * len(mm_data["image_embeds"])
        else:
            patch_counts = []

        def _replace_image(item_idx: int):
            """Build the replacement text for the image at ``item_idx``."""
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=processor,
                )

            num_patches = patch_counts[item_idx]
            assert num_patches is None or isinstance(num_patches, int)

            image_repl = processor.get_image_repl(feature_size, num_patches)

            # Re-append the newline consumed by the "<image>\n" target.
            return PromptUpdateDetails.select_text(image_repl.full + "\n", IMG_PAD)

        # The target includes the newline that separates consecutive
        # placeholders (see the dummy-text builder).
        return [
            PromptReplacement(
                modality="image",
                target="<image>\n",
                replacement=_replace_image,
            )
        ]


167
168
169
170
171
@MULTIMODAL_REGISTRY.register_processor(
    NVLMMultiModalProcessor,
    info=NVLMProcessingInfo,
    dummy_inputs=NVLMDummyInputsBuilder,
)
class NVLM_D_Model(InternVLChatModel):
    """InternVL chat model with NVLM-D's projector and vision-tower setup."""

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
        """Build the projector from vision features to the language model.

        Layout: LayerNorm -> Linear (no bias) -> GELU -> Linear (no bias).
        The input width is the ViT hidden size times
        ``int(1 / downsample_ratio) ** 2``.
        """
        vit_hidden = config.vision_config.hidden_size
        ffn_hidden = config.text_config.intermediate_size
        lm_hidden = config.text_config.hidden_size

        projector_in = vit_hidden * int(1 / self.downsample_ratio) ** 2

        # Module order (and thus parameter names 0..3) matches the checkpoint.
        layers = [
            nn.LayerNorm(projector_in),
            nn.Linear(projector_in, ffn_hidden, bias=False),
            nn.GELU(),
            nn.Linear(ffn_hidden, lm_hidden, bias=False),
        ]
        return nn.Sequential(*layers)

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the InternViT vision encoder used by NVLM-D.

        Raises:
            NotImplementedError: If ``is_mono`` is requested.
        """
        if is_mono:
            raise NotImplementedError("Monolith mode is not applicable to NVLM_D")

        # Turn select_layer (negative values count from the end) into the
        # layer count passed as num_hidden_layers_override.
        select_layer = config.select_layer
        if select_layer < 0:
            layer_count = config.vision_config.num_hidden_layers + select_layer + 1
        else:
            layer_count = select_layer + 1

        # We added additional dummy heads to the original num of heads to
        # make the number of heads divisible by 8.
        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=layer_count,
            num_dummy_heads=7,
            prefix=prefix,
        )