# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
# adapted from https://huggingface.co/nvidia/NVLM-D-72B/blob/main/modeling_nvlm_d.py
# --------------------------------------------------------
# NVLM-D
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
9
10
from collections.abc import Mapping, Sequence
from typing import Optional
11

12
import torch
13
14
15
import torch.nn as nn
from transformers import PretrainedConfig

16
from vllm.model_executor.layers.quantization import QuantizationConfig
17
from vllm.multimodal import MULTIMODAL_REGISTRY
18
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
19
20
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   MultiModalDataItems)
21
22
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
23
24

from .intern_vit import InternVisionModel
25
26
27
28
from .internvl import (BaseInternVLDummyInputsBuilder,
                       BaseInternVLMultiModalProcessor,
                       BaseInternVLProcessingInfo, BaseInternVLProcessor,
                       InternVLChatModel)
29

30
# Placeholder token written into the prompt once per image feature position.
# NVLMProcessor repeats it after each tile tag and selects it (via
# PromptUpdateDetails.select_text) as the positions that receive image features.
IMG_PAD = "<|vision_pad|>"
31
32


33
class NVLMProcessor(BaseInternVLProcessor):
    """InternVL-based processor that renders NVLM-D's tile-tagged image prompt.

    An image becomes ``<Image>``, then for every tile a position tag followed
    by a run of ``IMG_PAD`` placeholder tokens, then ``</Image>``.
    """

    @property
    def image_token_id(self) -> int:
        """Return the id of ``IMG_PAD`` in ``self.tokenizer.get_vocab()``."""
        vocab = self.tokenizer.get_vocab()
        return vocab[IMG_PAD]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the prompt text that replaces a single image.

        Args:
            feature_size: Total image-feature budget for the image. Each tile
                receives ``feature_size // num_patches`` ``IMG_PAD`` tokens.
            num_patches: Number of patches the image was split into, or
                ``None`` when it is not known (embedding inputs).

        Returns:
            The full ``<Image>...</Image>`` text, with the ``IMG_PAD`` tokens
            selected as the feature placeholder positions.

        Raises:
            NotImplementedError: If ``num_patches`` is ``None``.
        """
        if num_patches is None:
            raise NotImplementedError("Embedding inputs are not supported")

        # Tags <tile_1> .. <tile_{num_patches - 1}>, plus the global thumbnail
        # tag when thumbnails are enabled (presumably the thumbnail occupies
        # the last of the num_patches slots).
        # NOTE(review): with use_thumbnail=False only num_patches - 1 tags are
        # produced, i.e. fewer pads than feature_size — confirm that this
        # configuration is never used with NVLM-D.
        tags = [f"<tile_{idx}>" for idx in range(1, num_patches)]
        if self.use_thumbnail:
            tags.append("<tile_global_thumbnail>")

        pad_run = IMG_PAD * (feature_size // num_patches)
        tiles_text = "".join(f"{tag}{pad_run}" for tag in tags)

        # The start and end markers are part of the replacement because
        # "<Image><tile" is tokenized as ["<Image", "><", "tile"], so looking
        # for "<tile" as a subsequence of "<Image><tile" would trip an
        # assertion.
        return PromptUpdateDetails.select_text(
            "<Image>" + tiles_text + "</Image>", IMG_PAD)


class NVLMProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info whose HF processor is an :class:`NVLMProcessor`."""

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> NVLMProcessor:
        """Create an :class:`NVLMProcessor` through ``self.ctx.init_processor``.

        The three dynamic-tiling options are forwarded only when they are not
        ``None``. Any other keyword arguments are passed through unchanged,
        together with the HF config and the tokenizer.
        """
        tiling_options = {
            "min_dynamic_patch": min_dynamic_patch,
            "max_dynamic_patch": max_dynamic_patch,
            "dynamic_image_size": dynamic_image_size,
        }
        kwargs.update((name, value) for name, value in tiling_options.items()
                      if value is not None)

        return self.ctx.init_processor(
            NVLMProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )
class NVLMDummyInputsBuilder(
        BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]):
    """Builds dummy text and image inputs for NVLM-D."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>\\n`` placeholder per requested image.

        The trailing newline is required to separate the ``>`` of one
        placeholder from the ``<`` of the next.
        """
        image_count = mm_counts.get("image", 0)
        return "<image>\n" * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return ``{"image": [...]}`` with the requested number of images.

        The images use the size reported by
        ``self.info.get_image_size_with_most_features()``. ``seq_len`` is not
        used by this implementation.
        """
        width, height = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}
class NVLMMultiModalProcessor(
        BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):
    """Multimodal processor that expands ``<image>\\n`` placeholders for NVLM-D."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement rule for every ``<image>\\n`` in the prompt.

        Each placeholder is replaced by ``hf_processor.get_image_repl(...)``
        followed by a newline. Only the ``IMG_PAD`` tokens in that text are
        selected as image-feature positions.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        # Per-image patch counts. Embedding inputs have no count (None), and
        # get_image_repl rejects None with NotImplementedError.
        if "image_num_patches" in out_mm_kwargs:
            num_patches_tensor = out_mm_kwargs["image_num_patches"]
            assert isinstance(num_patches_tensor, torch.Tensor)
            patches_per_image = num_patches_tensor.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            patches_per_image = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            patches_per_image = []

        def get_replacement_nvlm(item_idx: int):
            image_items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if not isinstance(image_items, ImageEmbeddingItems):
                size = image_items.get_image_size(item_idx)
                num_features = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )
            else:
                num_features = image_items.get_feature_size(item_idx)

            num_patches = patches_per_image[item_idx]
            assert num_patches is None or isinstance(num_patches, int)

            image_repl = hf_processor.get_image_repl(num_features, num_patches)

            # Put back the newline that the "<image>\n" target consumes.
            return PromptUpdateDetails.select_text(image_repl.full + "\n",
                                                   IMG_PAD)

        # The target includes the newline for the same reason as the dummy
        # text: it separates the ">" of one placeholder from the "<" of the
        # next.
        return [
            PromptReplacement(
                modality="image",
                target="<image>\n",
                replacement=get_replacement_nvlm,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(NVLMMultiModalProcessor,
                                        info=NVLMProcessingInfo,
                                        dummy_inputs=NVLMDummyInputsBuilder)
class NVLM_D_Model(InternVLChatModel):
    """InternVL chat model with NVLM-D's projector and vision encoder setup."""

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Build the projector from vision features to the LLM hidden size.

        The input width is the ViT hidden size times
        ``int(1 / self.downsample_ratio) ** 2``. The layers are, in order:
        LayerNorm, a bias-free Linear to the LLM intermediate size, GELU, and
        a bias-free Linear to the LLM hidden size.
        """
        scale = int(1 / self.downsample_ratio)
        in_features = config.vision_config.hidden_size * scale**2
        mid_features = config.text_config.intermediate_size
        out_features = config.text_config.hidden_size

        return nn.Sequential(
            nn.LayerNorm(in_features),
            nn.Linear(in_features, mid_features, bias=False),
            nn.GELU(),
            nn.Linear(mid_features, out_features, bias=False),
        )

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the InternViT encoder, keeping layers up to ``select_layer``.

        Raises:
            NotImplementedError: If ``is_mono`` is true, because monolith mode
                is not supported for NVLM-D.
        """
        if is_mono:
            raise NotImplementedError(
                "Monolith mode is not applicable to NVLM_D")

        feature_layer = config.select_layer
        # A negative select_layer counts back from the end (-1 keeps every
        # layer); a non-negative one is a 0-based layer index.
        if feature_layer < 0:
            layers_to_keep = (config.vision_config.num_hidden_layers +
                              feature_layer + 1)
        else:
            layers_to_keep = feature_layer + 1

        # Additional dummy heads were added to the original number of heads
        # so that the head count is divisible by 8.
        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=layers_to_keep,
            num_dummy_heads=7,
            prefix=prefix,
        )