# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

luopl's avatar
luopl committed
15
import os
chenych's avatar
chenych committed
16
from typing import TYPE_CHECKING, Any, Optional, TypedDict
chenych's avatar
chenych committed
17
18

import torch
chenych's avatar
chenych committed
19
20
21
22
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
chenych's avatar
chenych committed
23
    AutoModelForTextToWaveform,
chenych's avatar
chenych committed
24
25
26
27
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
)
chenych's avatar
chenych committed
28
29
from trl import AutoModelForCausalLMWithValueHead

luopl's avatar
luopl committed
30
31
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
chenych's avatar
chenych committed
32
from ..extras.packages import is_transformers_version_greater_than
chenych's avatar
chenych committed
33
from .adapter import init_adapter
luopl's avatar
luopl committed
34
from .model_utils.liger_kernel import apply_liger_kernel
chenych's avatar
chenych committed
35
36
37
38
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
luopl's avatar
luopl committed
39
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
chenych's avatar
chenych committed
40
41


chenych's avatar
chenych committed
42
43
44
45
if is_transformers_version_greater_than("4.46.0"):
    from transformers import AutoModelForImageTextToText


chenych's avatar
chenych committed
46
47
48
49
50
51
if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

    from ..hparams import FinetuningArguments, ModelArguments


luopl's avatar
luopl committed
52
# Module-level logger from the project's ``extras.logging`` helper (used below for
# ``debug`` and ``info_rank0`` messages).
logger = logging.get_logger(__name__)
chenych's avatar
chenych committed
53
54
55
56
57
58
59


# Return type of ``load_tokenizer``: the loaded tokenizer together with the processor,
# which is ``None`` when no usable processor was found. Declared with the functional
# ``TypedDict`` syntax; keys and (forward-referenced) value types are unchanged.
TokenizerModule = TypedDict(
    "TokenizerModule",
    {
        "tokenizer": "PreTrainedTokenizer",
        "processor": Optional["ProcessorMixin"],
    },
)


chenych's avatar
chenych committed
60
61
def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
    r"""Build the keyword arguments shared by config, tokenizer and model loading.

    Side effect: ``model_args.model_name_or_path`` is overwritten in place with the value
    returned by ``try_download_model_from_other_hub(model_args)``.

    Returns:
        A dict with ``trust_remote_code``, ``cache_dir``, ``revision`` and ``token`` taken
        from ``model_args``.
    """
    skip_check_imports()
    resolved_path = try_download_model_from_other_hub(model_args)
    model_args.model_name_or_path = resolved_path
    return dict(
        trust_remote_code=model_args.trust_remote_code,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.hf_hub_token,
    )


def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
    r"""Load the pretrained tokenizer and, if one is available, the processor.

    Each loader first tries ``use_fast=model_args.use_fast_tokenizer`` and, on ``ValueError``,
    retries once with the opposite setting. Any failure, including a failed retry, is
    re-raised as ``OSError`` chained to the underlying exception.

    Note: including inplace operation of model_args (``model_name_or_path`` may be rewritten).

    Args:
        model_args: model arguments controlling where and how to load the tokenizer/processor.

    Returns:
        A ``TokenizerModule`` with the patched tokenizer and the processor. The processor is
        ``None`` when the loaded object's class name does not contain "Processor".

    Raises:
        OSError: if the tokenizer or the processor cannot be loaded.
    """
    init_kwargs = _get_init_kwargs(model_args)

    # The try statements are nested so that the outer handler also covers a failing fallback
    # call: handlers of one ``try`` statement never catch exceptions raised by a sibling handler.
    try:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name_or_path,
                use_fast=model_args.use_fast_tokenizer,
                split_special_tokens=model_args.split_special_tokens,
                padding_side="right",
                **init_kwargs,
            )
        except ValueError:  # try the other implementation (fast <-> slow)
            # NOTE: this fallback call does not pass ``split_special_tokens``.
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name_or_path,
                use_fast=not model_args.use_fast_tokenizer,
                padding_side="right",
                **init_kwargs,
            )
    except Exception as e:
        raise OSError("Failed to load tokenizer.") from e

    patch_tokenizer(tokenizer, model_args)

    try:
        try:
            processor = AutoProcessor.from_pretrained(
                model_args.model_name_or_path,
                use_fast=model_args.use_fast_tokenizer,
                **init_kwargs,
            )
        except ValueError:  # try the other implementation (fast <-> slow)
            processor = AutoProcessor.from_pretrained(
                model_args.model_name_or_path,
                use_fast=not model_args.use_fast_tokenizer,
                **init_kwargs,
            )
    except Exception as e:
        raise OSError("Failed to load processor.") from e

    patch_processor(processor, tokenizer, model_args)

    # AutoProcessor may hand back a non-processor object (e.g. a tokenizer) instead of a real
    # processor; drop it so callers only ever see an actual Processor or None. See:
    # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
    if processor is not None and "Processor" not in processor.__class__.__name__:
        logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
        processor = None

    return {"tokenizer": tokenizer, "processor": processor}


def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
    r"""Load the model configuration through ``AutoConfig.from_pretrained``.

    Note: ``_get_init_kwargs`` runs first and may rewrite ``model_args.model_name_or_path``
    in place; the (possibly rewritten) path is what gets loaded.
    """
    kwargs = _get_init_kwargs(model_args)
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **kwargs)
    return config


def _get_auto_model_class(config: "PretrainedConfig") -> type:
    r"""Return the ``transformers`` auto class used to instantiate a model for ``config``.

    Candidates are checked in a fixed priority order. The first auto class whose model
    mapping contains ``type(config)`` is returned; ``AutoModelForCausalLM`` is the fallback.
    """
    config_type = type(config)
    if config_type in AutoModelForVision2Seq._model_mapping.keys():  # image-text
        return AutoModelForVision2Seq

    # ``AutoModelForImageTextToText`` is only imported at module level under this same version
    # check, so the version test must short-circuit before the class is referenced.
    if (
        is_transformers_version_greater_than("4.46.0")
        and config_type in AutoModelForImageTextToText._model_mapping.keys()
    ):  # image-text
        return AutoModelForImageTextToText

    if config_type in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
        return AutoModelForSeq2SeqLM

    if config_type in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
        return AutoModelForTextToWaveform

    return AutoModelForCausalLM


def _format_param_stats(trainable_params: int, all_param: int, is_trainable: bool) -> str:
    r"""Build the parameter-count summary that ``load_model`` logs after loading."""
    if not is_trainable:
        return f"all params: {all_param:,}"

    # Guard against a zero total so the percentage never raises ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param > 0 else 0.0
    return (
        f"trainable params: {trainable_params:,} || "
        f"all params: {all_param:,} || trainable%: {trainable_pct:.4f}"
    )


def load_model(
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool = False,
    add_valuehead: bool = False,
) -> "PreTrainedModel":
    r"""Load pretrained model.

    Pipeline: load and patch the config, apply liger kernels, and build the model. The model
    is built via unsloth, mixture-of-depths, or a ``transformers`` auto class, or it is left
    as ``None`` when unsloth is used together with adapters. The model is then patched,
    wrapped by ``init_adapter``, and optionally given a value head. Finally it is switched to
    ``train()``, or frozen and switched to ``eval()``.

    Note: including inplace operation of model_args (``model_name_or_path`` may be rewritten).

    Args:
        tokenizer: tokenizer forwarded to the config/model patchers.
        model_args: model loading arguments.
        finetuning_args: fine-tuning arguments. ``stage`` decides whether liger kernels must
            keep logits (required unless the stage is "pt" or "sft").
        is_trainable: prepare the model for training. Otherwise parameters are frozen, fp32
            parameters are cast to ``model_args.compute_dtype``, and ``eval()`` is called.
        add_valuehead: wrap the model with ``AutoModelForCausalLMWithValueHead`` and load
            value-head weights when ``load_valuehead_params`` finds them.

    Returns:
        The loaded model.
    """
    init_kwargs = _get_init_kwargs(model_args)
    config = load_config(model_args)
    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
    apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))

    model = None
    lazy_load = False
    if model_args.use_unsloth:
        if model_args.adapter_name_or_path is not None:
            # The model stays None here and is passed to init_adapter as-is
            # (presumably loaded there -- confirm in init_adapter).
            lazy_load = True
        elif is_trainable:
            model = load_unsloth_pretrained_model(config, model_args, finetuning_args)

    # Regular loading path, also taken when the unsloth loader returned None.
    if model is None and not lazy_load:
        init_kwargs["config"] = config
        init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path

        if model_args.mixture_of_depths == "load":
            model = load_mod_pretrained_model(**init_kwargs)
        else:
            load_class = _get_auto_model_class(config)
            if model_args.train_from_scratch:
                model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
            else:
                model = load_class.from_pretrained(**init_kwargs)
                if getattr(model.config, "model_type", None) == "qwen2_5_omni":
                    model = model.thinker  # use part of Omni model

        if model_args.mixture_of_depths == "convert":
            model = convert_pretrained_model_to_mod(model, config, model_args)

    if not lazy_load:
        patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
        register_autoclass(config, model, tokenizer)

    model = init_adapter(config, model, model_args, finetuning_args, is_trainable)

    if add_valuehead:
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
        patch_valuehead_model(model)

        # Value-head weights are looked up next to the last adapter if any, else the base model.
        if model_args.adapter_name_or_path is not None:
            vhead_path = model_args.adapter_name_or_path[-1]
        else:
            vhead_path = model_args.model_name_or_path

        vhead_params = load_valuehead_params(vhead_path, model_args)
        if vhead_params is not None:
            model.load_state_dict(vhead_params, strict=False)
            logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")

    if not is_trainable:
        model.requires_grad_(False)
        # Loop-invariant: only cast when the compute dtype differs from fp32.
        cast_fp32_params = model_args.compute_dtype != torch.float32
        for param in model.parameters():
            if cast_fp32_params and param.data.dtype == torch.float32:
                param.data = param.data.to(model_args.compute_dtype)

        model.eval()
    else:
        model.train()

    trainable_params, all_param = count_parameters(model)
    logger.info_rank0(_format_param_stats(trainable_params, all_param, is_trainable))

    if model_args.print_param_status and int(os.getenv("LOCAL_RANK", "0")) == 0:
        for name, param in model.named_parameters():
            print(f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}")

    return model