# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from typing import TYPE_CHECKING, Any, Optional

import torch
import torch.distributed as dist
from transformers import EarlyStoppingCallback, PreTrainedModel

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_ray_trainer, get_swanlab_callback


if is_ray_available():
    import ray
    from ray.train.huggingface.transformers import RayTrainReportCallback


if TYPE_CHECKING:
    from transformers import TrainerCallback


logger = logging.get_logger(__name__)


def _training_function(config: dict[str, Any]) -> None:
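    r"""Run a single training job from a config dict containing the parsed CLI args and trainer callbacks."""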
    args = config.get("args")
    callbacks: list[Any] = config.get("callbacks")
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

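    # assemble the trainer callbacks requested by the finetuning arguments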
    callbacks.append(LogCallback())
    if finetuning_args.pissa_convert:
        callbacks.append(PissaConvertCallback())

    if finetuning_args.use_swanlab:
        callbacks.append(get_swanlab_callback(finetuning_args))

    if finetuning_args.early_stopping_steps is not None:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))

    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add as the last callback

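    # dispatch to the runner of the selected training stage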
    if finetuning_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "sft":
        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "rm":
        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "ppo":
        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "dpo":
        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "kto":
        run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
    else:
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")

    if is_ray_available() and ray.is_initialized():
        return  # if ray is initialized, it will destroy the process group on return

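    # release the default process group if torch.distributed was initialized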
    try:
        if dist.is_initialized():
            dist.destroy_process_group()
    except Exception as e:
        logger.warning(f"Failed to destroy process group: {e}.")


def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
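    r"""Read the arguments and launch training, either locally or through a Ray trainer."""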
    args = read_args(args)
    if "-h" in args or "--help" in args:
        get_train_args(args)

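    # run training via Ray when `use_ray` is enabled, otherwise call the training function directly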
    ray_args = get_ray_args(args)
    callbacks = callbacks or []
    if ray_args.use_ray:
        callbacks.append(RayTrainReportCallback())
        trainer = get_ray_trainer(
            training_function=_training_function,
            train_loop_config={"args": args, "callbacks": callbacks},
            ray_args=ray_args,
        )
        trainer.fit()
    else:
        _training_function(config={"args": args, "callbacks": callbacks})


def export_model(args: Optional[dict[str, Any]] = None) -> None:
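    r"""Export the model and tokenizer to `export_dir`, optionally pushing them to the Hugging Face Hub."""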
    model_args, data_args, finetuning_args, _ = get_infer_args(args)

    if model_args.export_dir is None:
        raise ValueError("Please specify `export_dir` to save model.")

    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
        raise ValueError("Please merge adapters before quantizing the model.")

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    model = load_model(tokenizer, model_args, finetuning_args)  # must be called after fixing the tokenizer to resize the vocab

    if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
        raise ValueError("Cannot merge adapters to a quantized model.")

    if not isinstance(model, PreTrainedModel):
        raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

    if getattr(model, "quantization_method", None) is not None:  # quantized model adopts float16 type
        setattr(model.config, "torch_dtype", torch.float16)
    else:
        if model_args.infer_dtype == "auto":
            output_dtype = getattr(model.config, "torch_dtype", torch.float32)
            if output_dtype == torch.float32:  # if infer_dtype is auto, try using half precision first
                output_dtype = infer_optim_dtype(torch.bfloat16)
        else:
            output_dtype = getattr(torch, model_args.infer_dtype)

        setattr(model.config, "torch_dtype", output_dtype)
        model = model.to(output_dtype)
        logger.info_rank0(f"Convert model dtype to: {output_dtype}.")

    model.save_pretrained(
        save_directory=model_args.export_dir,
        max_shard_size=f"{model_args.export_size}GB",
        safe_serialization=(not model_args.export_legacy_format),
    )
    if model_args.export_hub_model_id is not None:
        model.push_to_hub(
            model_args.export_hub_model_id,
            token=model_args.hf_hub_token,
            max_shard_size=f"{model_args.export_size}GB",
            safe_serialization=(not model_args.export_legacy_format),
        )

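    # for reward models, also copy the value head weights into the export directory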
    if finetuning_args.stage == "rm":
        if model_args.adapter_name_or_path is not None:
            vhead_path = model_args.adapter_name_or_path[-1]
        else:
            vhead_path = model_args.model_name_or_path

        if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
            shutil.copy(
                os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
                os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
            )
            logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
        elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
            shutil.copy(
                os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
                os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
            )
            logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")

    try:
        tokenizer.padding_side = "left"  # restore padding side
        tokenizer.init_kwargs["padding_side"] = "left"
        tokenizer.save_pretrained(model_args.export_dir)
        if model_args.export_hub_model_id is not None:
            tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

        if processor is not None:
            processor.save_pretrained(model_args.export_dir)
            if model_args.export_hub_model_id is not None:
                processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

    except Exception as e:
        logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")

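    # write an Ollama Modelfile so the exported model can be served with Ollama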
    ollama_modelfile = os.path.join(model_args.export_dir, "Modelfile")
    with open(ollama_modelfile, "w", encoding="utf-8") as f:
        f.write(template.get_ollama_modelfile(tokenizer))
        logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")