visual.py 12.2 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#
# This code is inspired by Hugging Face's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

luopl's avatar
luopl committed
18
from dataclasses import dataclass
chenych's avatar
chenych committed
19
from typing import TYPE_CHECKING, Optional
chenych's avatar
chenych committed
20
21

import torch
luopl's avatar
luopl committed
22
import transformers
chenych's avatar
chenych committed
23
24
25
import transformers.models
from transformers.activations import ACT2FN

luopl's avatar
luopl committed
26
from ...extras import logging
chenych's avatar
chenych committed
27
from ...extras.packages import is_transformers_version_greater_than
chenych's avatar
chenych committed
28
29
30


if TYPE_CHECKING:
chenych's avatar
chenych committed
31
    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
chenych's avatar
chenych committed
32

luopl's avatar
luopl committed
33
    from ...hparams import FinetuningArguments, ModelArguments
chenych's avatar
chenych committed
34
35


luopl's avatar
luopl committed
36
37
# Module logger from the project's ``extras.logging`` helper; its ``info_rank0`` method is used throughout this file.
logger = logging.get_logger(__name__)
# transformers' own logger, used for ``warning_once`` in the Yi-VL projector's forward pass.
transformers_logger = transformers.utils.logging.get_logger(__name__)
chenych's avatar
chenych committed
38
39


luopl's avatar
luopl committed
40
41
42
43
@dataclass
class CompositeModel:
    r"""Registry entry describing how a multimodal model is split into sub-modules.

    Attributes:
        model_type: ``config.model_type`` value this entry applies to.
        projector_key: Dotted attribute path from the model to its multimodal projector.
        vision_model_keys: Keys naming the vision (or audio) encoder modules.
        language_model_keys: Keys naming the language-model modules.
        lora_conflict_keys: Keys of modules that are excluded from LoRA targets.
    """

    model_type: str
    projector_key: str
    vision_model_keys: list[str]
    language_model_keys: list[str]
    lora_conflict_keys: list[str]

    def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
        r"""Resolve ``projector_key`` on ``module``, following one attribute per dot-separated part.

        Raises:
            AttributeError: If any component of the dotted path is missing.
        """
        target = module
        for attr_name in self.projector_key.split("."):
            target = getattr(target, attr_name)

        return target


chenych's avatar
chenych committed
55
# Registry of multimodal models, keyed by ``config.model_type``.
# Filled by the ``_register_composite_model`` calls at the bottom of this module.
COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
luopl's avatar
luopl committed
56
57
58
59
60


def _register_composite_model(
    model_type: str,
    projector_key: Optional[str] = None,
    vision_model_keys: Optional[list[str]] = None,
    language_model_keys: Optional[list[str]] = None,
    lora_conflict_keys: Optional[list[str]] = None,
) -> None:
    r"""Register a new composite model.

    Args:
        model_type: The ``config.model_type`` value used as the registry key.
        projector_key: Dotted attribute path to the multimodal projector.
            Defaults to ``"multi_modal_projector"``.
        vision_model_keys: Keys matched against module names of the vision/audio encoders.
            Defaults to ``["vision_tower"]``.
        language_model_keys: Keys matched against module names of the language model.
            Defaults to ``["language_model", "lm_head"]``.
        lora_conflict_keys: Keys of modules that must be excluded from LoRA targets.
            Defaults to ``[]``.

    Only ``None`` selects a default; an explicitly passed empty list is kept as-is.
    List arguments are copied, so later mutation by the caller does not affect the registry.
    Registering the same ``model_type`` again replaces the previous entry.
    """
    COMPOSITE_MODELS[model_type] = CompositeModel(
        model_type=model_type,
        projector_key="multi_modal_projector" if projector_key is None else projector_key,
        vision_model_keys=["vision_tower"] if vision_model_keys is None else list(vision_model_keys),
        language_model_keys=(
            ["language_model", "lm_head"] if language_model_keys is None else list(language_model_keys)
        ),
        lora_conflict_keys=[] if lora_conflict_keys is None else list(lora_conflict_keys),
    )


chenych's avatar
chenych committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
    r"""Multimodal projector for Yi-VL derived LLaVA models.

    Layer stack: ``linear_1`` (Linear) -> ``linear_2`` (LayerNorm) -> activation ->
    ``linear_3`` (Linear) -> ``linear_4`` (LayerNorm). Despite their names, ``linear_2`` and
    ``linear_4`` are LayerNorms; the attribute names presumably mirror checkpoint parameter
    names (TODO confirm against a Yi-VL checkpoint).

    When ``config`` is ``None`` no layers are created and a subclass must build them
    (as ``LlavaMultiModalProjectorForYiVLForVLLM`` does).
    """

    def __init__(self, config: "LlavaConfig") -> None:
        super().__init__()

        self.config = config
        if config is None:
            return

        self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]

    @staticmethod
    def _get_autocast_dtype() -> "torch.dtype":
        r"""Return the CUDA autocast dtype.

        ``torch.get_autocast_gpu_dtype()`` is deprecated since torch 2.4 (it emits a FutureWarning)
        in favour of ``torch.get_autocast_dtype("cuda")``; fall back to the old API on older torch.
        """
        get_autocast_dtype = getattr(torch, "get_autocast_dtype", None)
        if get_autocast_dtype is not None:
            return get_autocast_dtype("cuda")

        return torch.get_autocast_gpu_dtype()

    def forward(self, image_features: "torch.Tensor") -> "torch.Tensor":
        r"""Project vision features to the text hidden size.

        If the result is float32, cast it to the first available of: the active autocast dtype,
        ``config._pre_quantization_dtype``, or the dtype of ``linear_1``'s weight.
        """
        hidden_states = self.linear_1(image_features)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_3(hidden_states)
        hidden_states = self.linear_4(hidden_states)
        if hidden_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = self._get_autocast_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.linear_1.weight.dtype

            transformers_logger.warning_once("The hidden states seems to be silently casted in float32.")
            hidden_states = hidden_states.to(target_dtype)

        return hidden_states


class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
    r"""Yi-VL projector built from explicit sizes rather than a ``LlavaConfig``.

    Presumably used for the vLLM integration (per the class name). The parent is initialised with
    ``config=None`` so it builds no layers; the same four layers and activation are created here.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None:
        super().__init__(config=None)

        hidden = text_hidden_size
        # Same attribute names and creation order as the config-based parent, so parameter names match.
        self.linear_1 = torch.nn.Linear(vision_hidden_size, hidden, bias=True)
        self.linear_2 = torch.nn.LayerNorm(hidden, bias=True)
        self.linear_3 = torch.nn.Linear(hidden, hidden, bias=True)
        self.linear_4 = torch.nn.LayerNorm(hidden, bias=True)
        self.act = ACT2FN[projector_hidden_act]


luopl's avatar
luopl committed
129
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
    r"""Cast projector output to half precision for fine-tuning quantized VLMs.

    Does nothing unless ``model.quantization_method`` is truthy and ``model.config.model_type`` is
    registered in ``COMPOSITE_MODELS``. Otherwise, attaches a forward hook to the model's projector
    that converts its output to ``model_args.compute_dtype``.
    """
    if not getattr(model, "quantization_method", None):
        return

    model_type = getattr(model.config, "model_type", None)
    if model_type not in COMPOSITE_MODELS:
        return

    def _cast_projector_output(
        module: "torch.nn.Module", args: tuple["torch.Tensor", ...], output: "torch.Tensor"
    ) -> "torch.Tensor":
        # A value returned from a forward hook replaces the module's output.
        return output.to(model_args.compute_dtype)

    projector = COMPOSITE_MODELS[model_type].get_projector(model)
    logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
    projector.register_forward_hook(_cast_projector_output)


def configure_visual_model(config: "PretrainedConfig") -> None:
    r"""Patch VLMs before loading them.

    1. If ``config`` has a truthy ``text_config`` but no truthy top-level ``hidden_size``, copy
       ``text_config.hidden_size`` (or ``None`` when absent) onto ``config``.
    2. If ``config.is_yi_vl_derived_model`` is truthy, replace transformers'
       ``LlavaMultiModalProjector`` with ``LlavaMultiModalProjectorForYiVL``.
    """
    text_config = getattr(config, "text_config", None)
    if text_config and not getattr(config, "hidden_size", None):
        # required for ds zero3 and valuehead models
        config.hidden_size = getattr(text_config, "hidden_size", None)

    if not getattr(config, "is_yi_vl_derived_model", None):
        return

    logger.info_rank0("Detected Yi-VL model, applying projector patch.")
    transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
luopl's avatar
luopl committed
157
158


chenych's avatar
chenych committed
159
160
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
    r"""Freeze vision tower and language model for VLM full/freeze tuning.

    Collect the module keys of the ``CompositeModel`` registered for ``config.model_type`` that
    must not be trained. The ``freeze_vision_tower``, ``freeze_multi_modal_projector`` and
    ``freeze_language_model`` flags of ``finetuning_args`` decide which keys are included.

    Returns:
        A new set of module keys (callers may mutate it); empty when the model type is not
        registered in ``COMPOSITE_MODELS``.
    """
    forbidden_modules: set[str] = set()
    model_type = getattr(config, "model_type", None)
    if model_type not in COMPOSITE_MODELS:
        return forbidden_modules

    composite_model = COMPOSITE_MODELS[model_type]
    if finetuning_args.freeze_vision_tower:
        vision_model_keys = composite_model.vision_model_keys
        logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.")
        forbidden_modules.update(vision_model_keys)

    if finetuning_args.freeze_multi_modal_projector:
        projector_key = composite_model.projector_key
        logger.info_rank0(f"Set multi-modal projector not trainable: {projector_key}.")
        forbidden_modules.add(projector_key)

    if finetuning_args.freeze_language_model:
        language_model_keys = composite_model.language_model_keys
        logger.info_rank0(f"Set language model not trainable: {language_model_keys}.")
        forbidden_modules.update(language_model_keys)

    return forbidden_modules


def patch_target_modules(
    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: list[str]
) -> list[str]:
    r"""Freeze vision tower for VLM LoRA tuning.

    For model types registered in ``COMPOSITE_MODELS``, expand ``target_modules`` into the explicit
    list of names from ``model.named_modules()`` that contain at least one target key as a substring
    and none of the forbidden keys. The forbidden keys are the frozen keys from
    ``get_forbidden_modules`` plus the model's ``lora_conflict_keys``. For any other model type,
    ``target_modules`` is returned unchanged (the same list object).
    """
    model_type = getattr(model.config, "model_type", None)
    if model_type not in COMPOSITE_MODELS:
        return target_modules

    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
    forbidden_modules.update(COMPOSITE_MODELS[model_type].lora_conflict_keys)
    return [
        module_name
        for module_name, _ in model.named_modules()
        if any(target in module_name for target in target_modules)
        and not any(forbidden in module_name for forbidden in forbidden_modules)
    ]
luopl's avatar
luopl committed
200
201


shihm's avatar
uodata  
shihm committed
202
203
204
205
206
207
208
209
210
# Built-in composite model registrations.
# Any argument left out falls back to the defaults of ``_register_composite_model``.

_register_composite_model(
    model_type="dots_ocr",
    projector_key="vision_tower.merger",
    vision_model_keys=["vision_tower"],
    language_model_keys=["model", "lm_head"],
    lora_conflict_keys=["merger"],
)


_register_composite_model(model_type="gemma3")


# gemma3n also carries an audio encoder next to the vision tower.
_register_composite_model(
    model_type="gemma3n",
    vision_model_keys=["vision_tower", "audio_tower"],
    lora_conflict_keys=["timm_model", "subsample_conv_projection"],
)


# Same layout as qwen2_vl.
_register_composite_model(
    model_type="glm4v",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)


_register_composite_model(
    model_type="glm4v_moe",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)


_register_composite_model(model_type="internvl")


_register_composite_model(model_type="interns1")


_register_composite_model(
    model_type="Keye",
    projector_key="mlp_AR",
    vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
    language_model_keys=["model", "lm_head"],
    lora_conflict_keys=["patch_embedding"],
)


_register_composite_model(model_type="kimi_vl")


_register_composite_model(model_type="llama4", vision_model_keys=["vision_model"])


_register_composite_model(model_type="llava")


_register_composite_model(model_type="llava_next")


_register_composite_model(model_type="llava_next_video")


_register_composite_model(
    model_type="minicpmv",
    projector_key="resampler",
    vision_model_keys=["vpm"],
    language_model_keys=["llm"],
)


# minicpmo adds audio / TTS components, all treated as part of the "vision" side.
_register_composite_model(
    model_type="minicpmo",
    projector_key="resampler",
    vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
    language_model_keys=["llm"],
    lora_conflict_keys=["audio_projection_layer"],
)


_register_composite_model(model_type="mistral3", projector_key="model.multi_modal_projector")


_register_composite_model(model_type="mllama", vision_model_keys=["vision_model"])


_register_composite_model(model_type="paligemma")


_register_composite_model(model_type="qwen2_audio", vision_model_keys=["audio_tower"])


_register_composite_model(
    model_type="qwen2_5_omni_thinker",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
    language_model_keys=["model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)


# The language-model keys of qwen2_vl / qwen2_5_vl depend on the installed transformers version.
_register_composite_model(
    model_type="qwen2_vl",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=(
        ["language_model", "lm_head"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"]
    ),
    lora_conflict_keys=["patch_embed"],
)


_register_composite_model(
    model_type="qwen2_5_vl",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=(
        ["language_model", "lm_head"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"]
    ),
    lora_conflict_keys=["patch_embed"],
)


_register_composite_model(
    model_type="qwen3_vl",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)


_register_composite_model(
    model_type="qwen3_vl_moe",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)


_register_composite_model(
    model_type="qwen3_omni_moe_thinker",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
    language_model_keys=["model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)


_register_composite_model(model_type="video_llava")