nvlm_d.py 7.45 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
# adapted from https://huggingface.co/nvidia/NVLM-D-72B/blob/main/modeling_nvlm_d.py
# --------------------------------------------------------
# NVLM-D
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
10
from collections.abc import Mapping, Sequence
11

12
import torch
13
14
15
import torch.nn as nn
from transformers import PretrainedConfig

16
from vllm.config.multimodal import BaseDummyOptions
17
from vllm.model_executor.layers.quantization import QuantizationConfig
18
from vllm.multimodal import MULTIMODAL_REGISTRY
19
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
20
21
22
23
24
25
26
27
28
29
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
30
31
from vllm.transformers_utils.processors.internvl import InternVLImageProcessor
from vllm.transformers_utils.processors.nvlm_d import NVLMProcessor
32
33

from .intern_vit import InternVisionModel
34
35
36
37
38
39
from .internvl import (
    BaseInternVLDummyInputsBuilder,
    BaseInternVLMultiModalProcessor,
    BaseInternVLProcessingInfo,
    InternVLChatModel,
)
40

41
42

class NVLMProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for NVLM-D.

    Builds the InternVL image processor and the NVLM HF processor from the
    model's HF config.
    """

    def get_image_processor(self, **kwargs):
        """Return an ``InternVLImageProcessor`` for this model.

        Caller kwargs are merged with the context's multimodal kwargs; any
        option that is still unset afterwards is filled in from the HF config.
        """
        hf_config = self.get_hf_config()
        vit_config = hf_config.vision_config

        merged = self.ctx.get_merged_mm_kwargs(kwargs)

        # Config-derived fallbacks; explicitly supplied values always win.
        config_defaults = {
            "image_size": vit_config.image_size,
            "min_dynamic_patch": hf_config.min_dynamic_patch,
            "max_dynamic_patch": hf_config.max_dynamic_patch,
            "dynamic_image_size": hf_config.dynamic_image_size,
            "use_thumbnail": hf_config.use_thumbnail,
        }
        for option, fallback in config_defaults.items():
            merged.setdefault(option, fallback)

        return InternVLImageProcessor(**merged)

    def get_hf_processor(self, **kwargs: object) -> NVLMProcessor:
        """Return an ``NVLMProcessor`` using this model's tokenizer.

        ``image_seq_length`` is computed as
        ``int((image_size // patch_size) ** 2 * downsample_ratio ** 2)``,
        where ``image_size`` comes from the image processor built with
        ``kwargs``.
        """
        hf_config = self.get_hf_config()
        patch_size = hf_config.vision_config.patch_size
        downsample_ratio = hf_config.downsample_ratio

        image_processor = self.get_image_processor(**kwargs)

        patches_per_side = image_processor.image_size // patch_size
        patch_grid = patches_per_side**2
        image_seq_length = int(patch_grid * (downsample_ratio**2))

        return NVLMProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=image_processor,
            image_seq_length=image_seq_length,
        )


73
class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]):
    """Dummy text and image inputs for NVLM-D."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder (with trailing newline) per image."""
        # Each placeholder is terminated by a newline so that the ">" that
        # closes one placeholder never directly abuts the "<" that opens the
        # next one.
        placeholder = "<image>\n"
        image_count = mm_counts.get("image", 0)
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images.

        Images are sized by ``self.info.get_image_size_with_most_features()``.
        Any ``"image"`` entry in ``mm_options`` is forwarded as overrides.
        ``seq_len`` is accepted for interface compatibility but not used here.
        """
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}

101

102
class NVLMMultiModalProcessor(BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):
    """Multimodal processor that expands NVLM-D image placeholders."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the single prompt replacement for the ``image`` modality.

        Each ``"<image>"``-plus-newline target is replaced by the text from
        ``hf_processor.get_image_repl(...)`` followed by a newline, wrapped
        with ``PromptUpdateDetails.select_text`` on the processor's
        ``ctx_image_token``.

        The per-item patch count comes from ``image_num_patches`` when present;
        for ``image_embeds`` inputs it is not known and ``None`` is passed.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        processed = out_mm_kwargs.get_data()

        if "image_num_patches" in processed:
            patches_tensor = processed["image_num_patches"]
            assert isinstance(patches_tensor, torch.Tensor)
            patches_per_item = patches_tensor.tolist()
        elif "image_embeds" in processed:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            patches_per_item = [None] * len(processed["image_embeds"])
        else:
            # Neither key present; presumably there are no image items to
            # replace, so this list is never indexed.
            patches_per_item = []

        def replace_image(item_idx: int):
            # Build the replacement text for the image at position item_idx.
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                num_features = images.get_feature_size(item_idx)
            else:
                dims = images.get_image_size(item_idx)
                num_features = self.info.get_num_image_tokens(
                    image_width=dims.width,
                    image_height=dims.height,
                    processor=hf_processor,
                )

            item_patches = patches_per_item[item_idx]
            assert item_patches is None or isinstance(item_patches, int)

            replacement = hf_processor.get_image_repl(
                item_patches, num_features=num_features
            )
            return PromptUpdateDetails.select_text(
                replacement.full + "\n", hf_processor.ctx_image_token
            )

        return [
            PromptReplacement(
                modality="image",
                # The target keeps its trailing newline (the same form the
                # dummy text uses) so consecutive placeholders stay separated.
                target="<image>\n",
                replacement=replace_image,
            )
        ]


158
159
160
161
162
@MULTIMODAL_REGISTRY.register_processor(
    NVLMMultiModalProcessor,
    info=NVLMProcessingInfo,
    dummy_inputs=NVLMDummyInputsBuilder,
)
class NVLM_D_Model(InternVLChatModel):
    """NVLM-D chat model.

    Registered with ``MULTIMODAL_REGISTRY`` using the NVLM processor, info
    and dummy-input builder. It overrides the InternVL projector and
    vision-model construction.
    """

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
        """Build the vision-to-text projector.

        Layout: LayerNorm -> Linear (no bias) -> GELU -> Linear (no bias).
        The input width is the ViT hidden size times
        ``int(1 / self.downsample_ratio) ** 2``. The hidden width is the text
        model's ``intermediate_size`` and the output width is its
        ``hidden_size``.
        """
        text_config = config.text_config
        width_multiplier = int(1 / self.downsample_ratio) ** 2
        projector_in = config.vision_config.hidden_size * width_multiplier

        layers = [
            nn.LayerNorm(projector_in),
            nn.Linear(projector_in, text_config.intermediate_size, bias=False),
            nn.GELU(),
            nn.Linear(
                text_config.intermediate_size, text_config.hidden_size, bias=False
            ),
        ]
        return nn.Sequential(*layers)

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the ``InternVisionModel`` encoder.

        The encoder's layer count is overridden so that it stops at
        ``config.select_layer``. A negative value counts back from
        ``config.vision_config.num_hidden_layers``.

        Raises:
            NotImplementedError: if ``is_mono`` is truthy.
        """
        if is_mono:
            raise NotImplementedError("Monolith mode is not applicable to NVLM_D")

        select_layer = config.select_layer
        kept_layers = select_layer + 1
        if select_layer < 0:
            # Negative indices are relative to the full ViT depth.
            kept_layers += config.vision_config.num_hidden_layers

        # We added additional dummy heads to the original num of heads to
        # make the number of heads divisible by 8.
        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=kept_layers,
            num_dummy_heads=7,
            prefix=prefix,
        )