# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union

import torch
import transformers
import transformers.models
from transformers.activations import ACT2FN

from ...extras import logging


if TYPE_CHECKING:
    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin

    from ...hparams import FinetuningArguments, ModelArguments


logger = logging.get_logger(__name__)
transformers_logger = transformers.utils.logging.get_logger(__name__)


class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
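    r"""
    Multimodal projector used by Yi-VL: Linear -> LayerNorm -> activation -> Linear -> LayerNorm.
    It replaces the default LLaVA projector when a Yi-VL derived model is detected.
    """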
    def __init__(self, config: "LlavaConfig") -> None:
        super().__init__()

        self.config = config
        if config is None:
            return

        self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]

    def forward(self, image_features: "torch.Tensor") -> "torch.Tensor":
        hidden_states = self.linear_1(image_features)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_3(hidden_states)
        hidden_states = self.linear_4(hidden_states)
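        # if the activations were silently upcast to float32 (e.g. by layer norms kept in fp32),
        # cast them back to the dtype the rest of the model runs in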
        if hidden_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.linear_1.weight.dtype

            transformers_logger.warning_once("The hidden states seem to have been silently cast to float32.")
            hidden_states = hidden_states.to(target_dtype)

        return hidden_states


class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
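    r"""
    Yi-VL projector built from plain hidden sizes instead of a LlavaConfig, matching the
    constructor signature used when the model is served with vLLM.
    """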
    def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None:
        super().__init__(config=None)

        self.linear_1 = torch.nn.Linear(vision_hidden_size, text_hidden_size, bias=True)
        self.linear_2 = torch.nn.LayerNorm(text_hidden_size, bias=True)
        self.linear_3 = torch.nn.Linear(text_hidden_size, text_hidden_size, bias=True)
        self.linear_4 = torch.nn.LayerNorm(text_hidden_size, bias=True)
        self.act = ACT2FN[projector_hidden_act]


def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
    r"""
    Casts projector output to half precision for fine-tuning quantized VLMs.
    """

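    # quantized backbones typically keep the projector in float32, so a post-forward hook
    # casts its output to `model_args.compute_dtype` before it reaches the language model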
    def _mm_projector_forward_post_hook(
        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
    ) -> "torch.Tensor":
        return output.to(model_args.compute_dtype)

    if getattr(model, "quantization_method", None):
        model_type = getattr(model.config, "model_type", None)
        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
            mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
        elif model_type == "qwen2_vl":
            mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
        else:
            return

        logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


def configure_visual_model(config: "PretrainedConfig") -> None:
    r"""
    Patches VLMs before loading them.
    """
    model_type = getattr(config, "model_type", None)
    if model_type in [
        "llava",
        "llava_next",
        "llava_next_video",
        "paligemma",
        "pixtral",
        "video_llava",
    ]:  # required for ds zero3 and valuehead models
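        # expose the language model's hidden size at the top level of the config so that
        # DeepSpeed ZeRO-3 and value-head wrappers can read it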
        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

    if getattr(config, "is_yi_vl_derived_model", None):
        logger.info_rank0("Detected Yi-VL model, applying projector patch.")
        transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL


def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
    r"""
    Freezes vision tower and language model for VLM full/freeze tuning.
    """
    model_type = getattr(config, "model_type", None)
    forbidden_modules = set()
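    # module names collected here are excluded from the trainable parameters during full/freeze tuning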
    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
        if finetuning_args.freeze_vision_tower:
            forbidden_modules.add("vision_tower")

        if finetuning_args.train_mm_proj_only:
            forbidden_modules.add("language_model")

    elif model_type == "qwen2_vl":
        if finetuning_args.freeze_vision_tower:
            forbidden_modules.add("visual")

        if finetuning_args.train_mm_proj_only:
            raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")

    return forbidden_modules


def get_image_seqlen(config: "PretrainedConfig") -> int:
    r"""
    Computes the number of special tokens per image.
    """
    model_type = getattr(config, "model_type", None)
    if model_type == "llava":
        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
            image_seqlen += 1
    elif model_type == "paligemma":
        image_seqlen = config.vision_config.num_image_tokens
    else:
        image_seqlen = -1

    return image_seqlen


def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
    r"""
    Computes the patch size of the vision transformer (ViT).
    """
    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
    return patch_size


def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> str:
    r"""
    Gets the vision_feature_select_strategy of the model config or processor.
    """
    vision_feature_select_strategy = getattr(
        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
    )
    return vision_feature_select_strategy


def patch_target_modules(
    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> Union[str, List[str]]:
    r"""
    Freezes vision tower for VLM LoRA tuning.
    """
    model_type = getattr(config, "model_type", None)
    if finetuning_args.freeze_vision_tower:
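        # the negative lookahead keeps LoRA off the vision encoder, e.g. with
        # target_modules = ["q_proj", "v_proj"] the pattern becomes
        # "^(?!.*vision_tower).*(?:q_proj|v_proj).*"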
        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
            return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
        elif model_type == "mllama":
            return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))
        elif model_type == "qwen2_vl":
            return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
        else:
            return target_modules
    else:
        if model_type == "qwen2_vl":
            return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
        elif model_type == "pixtral":
            return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
        else:
            return target_modules