# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from typing import TYPE_CHECKING, Any, Optional

import torch
import torch.distributed as dist
from transformers import PreTrainedModel

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_ray_trainer, get_swanlab_callback


if is_ray_available():
    import ray
    from ray.train.huggingface.transformers import RayTrainReportCallback


if TYPE_CHECKING:
    from transformers import TrainerCallback


logger = logging.get_logger(__name__)


def _training_function(config: dict[str, Any]) -> None:
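    r"""Run the training loop for one experiment.

    Shared by local runs and Ray Train workers: parses the training arguments
    from ``config["args"]``, assembles the callbacks (logging, optional PiSSA
    conversion, optional SwanLab, and the final reporter), dispatches to the
    runner for ``finetuning_args.stage``, then destroys the torch.distributed
    process group unless Ray manages it.
    """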
    args = config.get("args")
    callbacks: list[Any] = config.get("callbacks")
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

    callbacks.append(LogCallback())
    if finetuning_args.pissa_convert:
        callbacks.append(PissaConvertCallback())

    if finetuning_args.use_swanlab:
        callbacks.append(get_swanlab_callback(finetuning_args))

    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last

    if finetuning_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "sft":
        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "rm":
        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "ppo":
        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "dpo":
        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "kto":
        run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
    else:
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")

    if is_ray_available() and ray.is_initialized():
        return  # if ray is initialized, it will destroy the process group on return

    try:
        if dist.is_initialized():
            dist.destroy_process_group()
    except Exception as e:
        logger.warning(f"Failed to destroy process group: {e}.")


def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
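    r"""Run the experiment described by ``args``.

    Reads arguments from the command line or the given dict, prints the help
    message on ``-h``/``--help``, and executes ``_training_function`` either
    through a Ray trainer (when ``use_ray`` is enabled) or directly in the
    current process.
    """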
    args = read_args(args)
    if "-h" in args or "--help" in args:
        get_train_args(args)  # let the argument parser print the help message

    ray_args = get_ray_args(args)
    callbacks = callbacks or []
    if ray_args.use_ray:
        callbacks.append(RayTrainReportCallback())
        trainer = get_ray_trainer(
            training_function=_training_function,
            train_loop_config={"args": args, "callbacks": callbacks},
            ray_args=ray_args,
        )
        trainer.fit()
    else:
        _training_function(config={"args": args, "callbacks": callbacks})


def export_model(args: Optional[dict[str, Any]] = None) -> None:
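    r"""Export the tuned model to ``export_dir``.

    Loads the tokenizer and model, casts an unquantized model to the inferred
    dtype, saves both (optionally pushing them to the Hugging Face Hub), copies
    the value head weights when exporting a reward model, and writes an Ollama
    Modelfile for the exported checkpoint.
    """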
    model_args, data_args, finetuning_args, _ = get_infer_args(args)

    if model_args.export_dir is None:
        raise ValueError("Please specify `export_dir` to save model.")

    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
        raise ValueError("Please merge adapters before quantizing the model.")

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    model = load_model(tokenizer, model_args, finetuning_args)  # must be called after fixing the tokenizer to resize the vocab

    if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
        raise ValueError("Cannot merge adapters to a quantized model.")

    if not isinstance(model, PreTrainedModel):
        raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

    if getattr(model, "quantization_method", None) is not None:  # quantized model adopts float16 type
        setattr(model.config, "torch_dtype", torch.float16)
    else:
        if model_args.infer_dtype == "auto":
            output_dtype = getattr(model.config, "torch_dtype", torch.float32)
            if output_dtype == torch.float32:  # if infer_dtype is auto, try using half precision first
                output_dtype = infer_optim_dtype(torch.bfloat16)
        else:
            output_dtype = getattr(torch, model_args.infer_dtype)

        setattr(model.config, "torch_dtype", output_dtype)
        model = model.to(output_dtype)
        logger.info_rank0(f"Convert model dtype to: {output_dtype}.")

    model.save_pretrained(
        save_directory=model_args.export_dir,
        max_shard_size=f"{model_args.export_size}GB",
        safe_serialization=(not model_args.export_legacy_format),
    )
    if model_args.export_hub_model_id is not None:
        model.push_to_hub(
            model_args.export_hub_model_id,
            token=model_args.hf_hub_token,
            max_shard_size=f"{model_args.export_size}GB",
            safe_serialization=(not model_args.export_legacy_format),
        )

    # a reward model keeps its value head in a separate weights file, copy it next to the exported weights
    if finetuning_args.stage == "rm":
        if model_args.adapter_name_or_path is not None:
            vhead_path = model_args.adapter_name_or_path[-1]
        else:
            vhead_path = model_args.model_name_or_path

        if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
            shutil.copy(
                os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
                os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
            )
            logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
        elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
            shutil.copy(
                os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
                os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
            )
            logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")

    try:
        tokenizer.padding_side = "left"  # restore padding side
        tokenizer.init_kwargs["padding_side"] = "left"
        tokenizer.save_pretrained(model_args.export_dir)
        if model_args.export_hub_model_id is not None:
            tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

        if processor is not None:
            processor.save_pretrained(model_args.export_dir)
            if model_args.export_hub_model_id is not None:
                processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

    except Exception as e:
        logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")

    ollama_modelfile = os.path.join(model_args.export_dir, "Modelfile")
    with open(ollama_modelfile, "w", encoding="utf-8") as f:
        f.write(template.get_ollama_modelfile(tokenizer))
        logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")