nvlm_d.py 7.24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
# adapted from https://huggingface.co/nvidia/NVLM-D-72B/blob/main/modeling_nvlm_d.py
# --------------------------------------------------------
# NVLM-D
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
10
from collections.abc import Mapping
11

12
import torch
13
14
15
import torch.nn as nn
from transformers import PretrainedConfig

16
from vllm.config.multimodal import BaseDummyOptions
17
from vllm.inputs import MultiModalDataDict
18
from vllm.model_executor.layers.quantization import QuantizationConfig
19
from vllm.multimodal import MULTIMODAL_REGISTRY
20
from vllm.multimodal.inputs import BatchedTensorInputs
21
22
23
24
25
26
27
28
29
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    PromptReplacement,
    PromptUpdateDetails,
)
30
31
from vllm.transformers_utils.processors.internvl import InternVLImageProcessor
from vllm.transformers_utils.processors.nvlm_d import NVLMProcessor
32
33

from .intern_vit import InternVisionModel
34
35
36
37
38
39
from .internvl import (
    BaseInternVLDummyInputsBuilder,
    BaseInternVLMultiModalProcessor,
    BaseInternVLProcessingInfo,
    InternVLChatModel,
)
40

41
42

class NVLMProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for NVLM-D.

    Builds the InternVL-style image processor and the ``NVLMProcessor`` from
    the model's HF config.
    """

    def get_image_processor(self, **kwargs):
        """Build an ``InternVLImageProcessor`` for this model.

        Caller kwargs are first merged through ``self.ctx.get_merged_mm_kwargs``.
        Any of the tiling options below that is still absent afterwards falls
        back to the value stored in the HF config.
        """
        hf_config = self.get_hf_config()
        vision_cfg = hf_config.vision_config

        merged = self.ctx.get_merged_mm_kwargs(kwargs)

        # Config-provided fallbacks; explicit kwargs always take precedence.
        config_defaults = {
            "image_size": vision_cfg.image_size,
            "min_dynamic_patch": hf_config.min_dynamic_patch,
            "max_dynamic_patch": hf_config.max_dynamic_patch,
            "dynamic_image_size": hf_config.dynamic_image_size,
            "use_thumbnail": hf_config.use_thumbnail,
        }
        for option, fallback in config_defaults.items():
            merged.setdefault(option, fallback)

        return InternVLImageProcessor(**merged)

    def get_hf_processor(self, **kwargs: object) -> NVLMProcessor:
        """Build the ``NVLMProcessor`` for this model.

        The per-tile image sequence length is
        ``(image_size // patch_size) ** 2 * downsample_ratio ** 2``, truncated
        to an int. ``image_size`` is taken from the image processor (so kwarg
        overrides apply) and the other two values from the HF config.
        """
        hf_config = self.get_hf_config()
        vision_cfg = hf_config.vision_config

        image_processor = self.get_image_processor(**kwargs)

        patches_per_side = image_processor.image_size // vision_cfg.patch_size
        tokens_per_tile = int(patches_per_side**2 * hf_config.downsample_ratio**2)

        return NVLMProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=image_processor,
            image_seq_length=tokens_per_tile,
        )


73
class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]):
    """Builds the dummy prompt text and images used for NVLM-D profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>`` placeholder, each followed by a newline,
        per requested image."""
        image_count = mm_counts.get("image", 0)

        # Each placeholder ends with a newline so that the ">" closing one
        # item is never directly followed by the "<" opening the next one.
        placeholder = "<image>\n"
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy images sized to produce the most image features.

        ``seq_len`` is part of the builder interface but is not used here.
        Per-image overrides, if any, are read from ``mm_options["image"]``.
        """
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}

101

102
class NVLMMultiModalProcessor(BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):
    """Multimodal processor that expands NVLM-D image placeholders."""

    def _get_prompt_repl_image(
        self,
        mm_items: MultiModalDataItems,
        hf_processor: NVLMProcessor,
        out_mm_data: BatchedTensorInputs,
    ):
        """Build the prompt replacement for each ``<image>`` placeholder
        followed by a newline.

        Args:
            mm_items: Parsed multimodal items. The "image" entry holds either
                raw images or precomputed image embeddings.
            hf_processor: Processor used to render the replacement text.
            out_mm_data: Processor outputs. When ``image_num_patches`` (a
                tensor) is present it supplies the per-image tile counts. When
                only ``image_embeds`` is present the tile count is unknown and
                ``None`` is used for every item.

        Returns:
            A ``PromptReplacement`` for the "image" modality.
        """
        patch_counts: list = []
        if "image_num_patches" in out_mm_data:
            raw_counts = out_mm_data["image_num_patches"]
            assert isinstance(raw_counts, torch.Tensor)
            patch_counts = raw_counts.tolist()
        elif "image_embeds" in out_mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            patch_counts = [None] * len(out_mm_data["image_embeds"])

        def get_replacement_nvlm(item_idx: int):
            # Looked up lazily: this closure is only invoked for image items.
            image_items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(image_items, ImageEmbeddingItems):
                num_features = image_items.get_feature_size(item_idx)
            else:
                size = image_items.get_image_size(item_idx)
                num_features = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )

            tiles = patch_counts[item_idx]
            assert tiles is None or isinstance(tiles, int)

            rendered = hf_processor.get_image_repl(tiles, num_features=num_features)

            # Re-append the newline consumed by the placeholder target.
            return PromptUpdateDetails.select_text(
                rendered.full + "\n", hf_processor.ctx_image_token
            )

        # The target includes the trailing newline emitted by the dummy text
        # builder, which separates consecutive placeholders.
        return PromptReplacement(
            modality="image",
            target="<image>\n",
            replacement=get_replacement_nvlm,
        )
151
152


153
154
155
156
157
@MULTIMODAL_REGISTRY.register_processor(
    NVLMMultiModalProcessor,
    info=NVLMProcessingInfo,
    dummy_inputs=NVLMDummyInputsBuilder,
)
class NVLM_D_Model(InternVLChatModel):
    """NVLM-D chat model.

    It is an InternVL chat model with an NVLM-specific vision-to-language
    projector and an InternViT vision tower.
    """

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
        """Build the vision-to-language projector.

        The input width is the ViT hidden size times
        ``int(1 / downsample_ratio) ** 2``. The stack is LayerNorm, then a
        bias-free Linear to the LLM intermediate size, then GELU, then a
        bias-free Linear to the LLM hidden size.
        """
        scale = int(1 / self.downsample_ratio)
        projector_in = config.vision_config.hidden_size * scale**2
        intermediate = config.text_config.intermediate_size
        hidden = config.text_config.hidden_size

        return nn.Sequential(
            nn.LayerNorm(projector_in),
            nn.Linear(projector_in, intermediate, bias=False),
            nn.GELU(),
            nn.Linear(intermediate, hidden, bias=False),
        )

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the InternViT vision tower for NVLM-D.

        ``num_hidden_layers_override`` is ``config.select_layer + 1``. A
        negative ``select_layer`` is offset by the ViT's
        ``num_hidden_layers``, presumably so the tower ends at the selected
        feature layer.

        Raises:
            NotImplementedError: if ``is_mono`` is true; monolith mode is not
                supported for NVLM-D.
        """
        if is_mono:
            raise NotImplementedError("Monolith mode is not applicable to NVLM_D")

        feature_layer = config.select_layer
        kept_layers = feature_layer + 1
        if feature_layer < 0:
            # Negative indices count back from the end of the ViT stack.
            kept_layers += config.vision_config.num_hidden_layers

        # We added additional dummy heads to the original num of heads to
        # make the number of heads divisible by 8.
        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=kept_layers,
            num_dummy_heads=7,
            prefix=prefix,
        )