# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MethodType
from typing import TYPE_CHECKING, Any

import torch
from peft import PeftModel
from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled

from ..extras import logging
from ..extras.misc import infer_optim_dtype
from ..extras.packages import is_transformers_version_greater_than
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training
from .model_utils.embedding import resize_embedding_layer
from .model_utils.kv_cache import configure_kv_cache
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import autocast_projector_dtype, configure_visual_model


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import ModelArguments

if is_transformers_version_greater_than("4.57.0"):
    from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe


logger = logging.get_logger(__name__)


def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
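    """Replace transformers' Qwen3OmniMoeThinkerTextSparseMoeBlock with LLaMA-Factory's fixed implementation."""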
    if is_transformers_version_greater_than("4.57.0") and not is_transformers_version_greater_than("4.58.0"):
        from .model_utils.moe import Qwen3OmniMoeThinkerTextSparseMoeBlock

        logger.warning_rank0(
            "With transformers 4.x, Qwen3OmniMoeThinkerTextSparseMoeBlock breaks DeepSpeed ZeRO-2 and FSDP2 "
            "training, so we patch the module here. Transformers v5.0.0rc0 fixes the issue, so you can also "
            "upgrade transformers to use qwen3_omni. See https://github.com/hiyouga/LLaMA-Factory/issues/9628."
        )

        modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock


def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
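    """Patch the tokenizer in place: restore the base `_pad`, enlarge the max length, and add user-specified tokens."""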
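    # restore the base `_pad` method if the tokenizer overrides it, ensuring consistent padding behavior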
    if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

    if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
        tokenizer.model_max_length = model_args.model_max_length  # enlarge the tokenizer max length

    if model_args.add_tokens is not None:
        num_added_tokens = tokenizer.add_tokens(new_tokens=model_args.add_tokens, special_tokens=False)
        logger.info_rank0("Added tokens {} to the tokenizer's vocabulary.".format(",".join(model_args.add_tokens)))
        if num_added_tokens > 0 and not model_args.resize_vocab:
            model_args.resize_vocab = True
            logger.warning_rank0("New tokens have been added; setting `resize_vocab` to True.")

    if model_args.add_special_tokens is not None:
        num_added_special_tokens = tokenizer.add_tokens(new_tokens=model_args.add_special_tokens, special_tokens=True)
        logger.info_rank0(
            "Added special tokens {} to the tokenizer's vocabulary.".format(",".join(model_args.add_special_tokens))
        )
        if num_added_special_tokens > 0 and not model_args.resize_vocab:
            model_args.resize_vocab = True
            logger.warning_rank0("New special tokens have been added; setting `resize_vocab` to True.")


def patch_processor(
    processor: "ProcessorMixin",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
) -> None:
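    """Attach the tokenizer and the image/video/audio preprocessing arguments to the processor."""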
    setattr(processor, "tokenizer", tokenizer)
    setattr(processor, "image_max_pixels", model_args.image_max_pixels)
    setattr(processor, "image_min_pixels", model_args.image_min_pixels)
    setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
    setattr(processor, "crop_to_patches", model_args.crop_to_patches)
    setattr(processor, "video_max_pixels", model_args.video_max_pixels)
    setattr(processor, "video_min_pixels", model_args.video_min_pixels)
    setattr(processor, "video_fps", model_args.video_fps)
    setattr(processor, "video_maxlen", model_args.video_maxlen)
    setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)


def patch_config(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    init_kwargs: dict[str, Any],
    is_trainable: bool,
) -> None:
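    """Patch the model config in place and prepare `init_kwargs` before the model is instantiated."""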
    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
        if model_args.infer_dtype != "auto" and not is_trainable:
            model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
        else:
            model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

    configure_attn_implementation(config, model_args)
    configure_rope(config, model_args)
    configure_longlora(config, model_args, is_trainable)
    configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
    configure_moe(config, model_args, is_trainable)
    configure_visual_model(config)
    configure_packing(model_args, is_trainable)
    configure_kv_cache(config, model_args, is_trainable)

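    # the legacy Qwen config toggles flash attention and compute dtype via boolean flags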
    if getattr(config, "model_type", None) == "qwen":
        setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
            setattr(config, dtype_name, model_args.compute_dtype == dtype)

    if getattr(config, "model_type", None) == "minicpmo":
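        # initialize the audio module but not the unused TTS module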
        setattr(config, "init_audio", True)
        setattr(config, "init_tts", False)

    # replace the top-k gating method
    if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
        setattr(config.text_config, "topk_method", "greedy")

    if "InternVLChatModel" in getattr(config, "architectures", []):
        raise ValueError(
            "Please download InternVL models in a Hugging Face-compatible format "
            "(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
        )

    if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
        raise ValueError("Please download LLaVA models in a Hugging Face-compatible format: https://huggingface.co/llava-hf")

    if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
        raise RuntimeError("The InternLM3 model requires transformers>=4.47.1; please upgrade it.")

    if getattr(config, "model_type", None) == "qwen3_omni_moe":
        patch_qwen3_omni_moe_thinker_text_sparse_moe_block()

    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

    # fsdp/deepspeed zero3 does not need device map
    if not (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) and init_kwargs["low_cpu_mem_usage"]:
        if "device_map" not in init_kwargs and model_args.device_map:
            init_kwargs["device_map"] = model_args.device_map  # device map requires low_cpu_mem_usage=True

        if init_kwargs.get("device_map", None) == "auto":
            init_kwargs["offload_folder"] = model_args.offload_folder


def patch_model(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    is_trainable: bool,
    add_valuehead: bool,
) -> None:
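    """Patch the loaded model: sanitize the generation config and apply training- or inference-time utilities."""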
    gen_config = model.generation_config  # check and fix generation config
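    # non-default temperature/top_p/typical_p only take effect when do_sample=True, so enable sampling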
    if not gen_config.do_sample and (
        (gen_config.temperature is not None and gen_config.temperature != 1.0)
        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
    ):
        gen_config.do_sample = True

    if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
        model.generate.__func__
    ):
        model.generate = MethodType(GenerationMixin.generate, model)

    if add_valuehead:
        prepare_valuehead_model(model)

    if model_args.resize_vocab:
        resize_embedding_layer(
            model,
            tokenizer,
            new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
            init_special_tokens=model_args.init_special_tokens,
        )

    if is_trainable:
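        # gemma3n does not work with gradient checkpointing, so force-disable it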
        if getattr(model.config, "model_type", None) == "gemma3n":
            setattr(model_args, "disable_gradient_checkpointing", True)

        prepare_model_for_training(model, model_args)
        autocast_projector_dtype(model, model_args)
        add_z3_leaf_module(model)

    if not model_args.use_unsloth:
        print_attn_implementation(model.config)

    try:
        model.add_model_tags(["llama-factory"])
    except Exception:
        logger.warning_rank0("Cannot properly tag the model.")


def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
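    """Bind the methods that `AutoModelForCausalLMWithValueHead` lacks by delegating to the wrapped model."""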
    def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
        if isinstance(self.pretrained_model, PreTrainedModel):
            self.pretrained_model.tie_weights()

    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_input_embeddings()

    def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_output_embeddings()

    def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
        if isinstance(self.pretrained_model, PeftModel):
            self.pretrained_model.create_or_update_model_card(output_dir)

    def get_rope_index_func(self: "AutoModelForCausalLMWithValueHead"):
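        """Fetch `get_rope_index` from the underlying base model (used by some multimodal models), or None."""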
        if isinstance(self.pretrained_model, PeftModel):
            base_model = self.pretrained_model.base_model.model
        else:
            base_model = self.pretrained_model

        if base_model and hasattr(base_model, "get_rope_index"):
            return base_model.get_rope_index
        elif base_model and hasattr(base_model, "model") and hasattr(base_model.model, "get_rope_index"):
            return base_model.model.get_rope_index
        else:
            return None

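    # skip saving the wrapped pretrained weights; only the value head parameters are checkpointed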
    ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
    setattr(model, "tie_weights", MethodType(tie_weights, model))
    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
    setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
    setattr(model, "get_rope_index", get_rope_index_func(model))
    setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))