tuner.py 9.03 KB
Newer Older
shihm's avatar
uodata  
shihm committed
1
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
chenych's avatar
chenych committed
17
from typing import TYPE_CHECKING, Any, Optional
chenych's avatar
chenych committed
18
19

import torch
chenych's avatar
chenych committed
20
import torch.distributed as dist
chenych's avatar
chenych committed
21
from transformers import EarlyStoppingCallback, PreTrainedModel
chenych's avatar
chenych committed
22
23

from ..data import get_template_and_fix_tokenizer
luopl's avatar
luopl committed
24
from ..extras import logging
chenych's avatar
chenych committed
25
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
chenych's avatar
chenych committed
26
from ..extras.misc import infer_optim_dtype
shihm's avatar
uodata  
shihm committed
27
from ..extras.packages import is_mcore_adapter_available, is_ray_available
luopl's avatar
luopl committed
28
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
chenych's avatar
chenych committed
29
from ..model import load_model, load_tokenizer
luopl's avatar
luopl committed
30
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
chenych's avatar
chenych committed
31
32
33
34
35
36
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
luopl's avatar
luopl committed
37
38
39
40
from .trainer_utils import get_ray_trainer, get_swanlab_callback


if is_ray_available():
chenych's avatar
chenych committed
41
    import ray
luopl's avatar
luopl committed
42
    from ray.train.huggingface.transformers import RayTrainReportCallback
chenych's avatar
chenych committed
43
44
45
46
47
48


if TYPE_CHECKING:
    from transformers import TrainerCallback


luopl's avatar
luopl committed
49
logger = logging.get_logger(__name__)
chenych's avatar
chenych committed
50
51


chenych's avatar
chenych committed
52
def _training_function(config: dict[str, Any]) -> None:
    """Run a single training job.

    ``config`` carries two keys: ``args`` (raw CLI/dict arguments) and
    ``callbacks`` (a mutable list of trainer callbacks this function extends
    and forwards to the selected stage runner).
    """
    args = config.get("args")
    callbacks: list[Any] = config.get("callbacks")
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

    # Assemble the callback list: logging first, reporter last.
    callbacks.append(LogCallback())
    if finetuning_args.pissa_convert:
        callbacks.append(PissaConvertCallback())

    if finetuning_args.use_swanlab:
        callbacks.append(get_swanlab_callback(finetuning_args))

    if finetuning_args.early_stopping_steps is not None:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))

    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last

    stage = finetuning_args.stage
    if stage in ("pt", "sft", "dpo") and finetuning_args.use_mca:
        # mcore-adapter backend only supports a subset of the training stages.
        if not is_mcore_adapter_available():
            raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")

        if stage == "pt":
            from .mca import run_pt as run_pt_mca

            run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
        elif stage == "sft":
            from .mca import run_sft as run_sft_mca

            run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
        else:  # stage == "dpo"
            from .mca import run_dpo as run_dpo_mca

            run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
    elif stage in ("sft", "ppo"):
        # These runners additionally consume the generation arguments.
        runner = run_sft if stage == "sft" else run_ppo
        runner(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif stage in ("pt", "rm", "dpo", "kto"):
        runner = {"pt": run_pt, "rm": run_rm, "dpo": run_dpo, "kto": run_kto}[stage]
        runner(model_args, data_args, training_args, finetuning_args, callbacks)
    else:
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")

    if is_ray_available() and ray.is_initialized():
        return  # if ray is initialized it will destroy the process group on return

    # Best-effort teardown of the torch.distributed process group.
    try:
        if dist.is_initialized():
            dist.destroy_process_group()
    except Exception as e:
        logger.warning(f"Failed to destroy process group: {e}.")

chenych's avatar
chenych committed
109

chenych's avatar
chenych committed
110
def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
    """Parse arguments and launch training, either in-process or through Ray."""
    args = read_args(args)
    if "-h" in args or "--help" in args:
        get_train_args(args)  # lets the argument parser print usage and exit

    ray_args = get_ray_args(args)
    if not callbacks:
        callbacks = []

    if not ray_args.use_ray:
        # Local path: run the training function directly in this process.
        _training_function(config={"args": args, "callbacks": callbacks})
        return

    # Distributed path: hand the training function over to Ray workers.
    callbacks.append(RayTrainReportCallback())
    trainer = get_ray_trainer(
        training_function=_training_function,
        train_loop_config={"args": args, "callbacks": callbacks},
        ray_args=ray_args,
    )
    trainer.fit()


chenych's avatar
chenych committed
129
def export_model(args: Optional[dict[str, Any]] = None) -> None:
    """Export the (merged) model and tokenizer assets to ``export_dir``.

    Optionally pushes the artifacts to the Hugging Face Hub, copies the
    value-head weights for reward models, and writes an Ollama Modelfile.
    """
    model_args, data_args, finetuning_args, _ = get_infer_args(args)

    if model_args.export_dir is None:
        raise ValueError("Please specify `export_dir` to save model.")

    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
        raise ValueError("Please merge adapters before quantizing the model.")

    modules = load_tokenizer(model_args)
    tokenizer = modules["tokenizer"]
    processor = modules["processor"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab

    already_quantized = getattr(model, "quantization_method", None) is not None
    if already_quantized and model_args.adapter_name_or_path is not None:
        raise ValueError("Cannot merge adapters to a quantized model.")

    if not isinstance(model, PreTrainedModel):
        raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

    if already_quantized:  # quantized model adopts float16 type
        model.config.torch_dtype = torch.float16
    else:
        if model_args.infer_dtype == "auto":
            target_dtype = getattr(model.config, "torch_dtype", torch.float32)
            if target_dtype == torch.float32:  # if infer_dtype is auto, try using half precision first
                target_dtype = infer_optim_dtype(torch.bfloat16)
        else:
            target_dtype = getattr(torch, model_args.infer_dtype)

        model.config.torch_dtype = target_dtype
        model = model.to(target_dtype)
        logger.info_rank0(f"Convert model dtype to: {target_dtype}.")

    shard_size = f"{model_args.export_size}GB"
    use_safetensors = not model_args.export_legacy_format
    model.save_pretrained(
        save_directory=model_args.export_dir,
        max_shard_size=shard_size,
        safe_serialization=use_safetensors,
    )
    if model_args.export_hub_model_id is not None:
        model.push_to_hub(
            model_args.export_hub_model_id,
            token=model_args.hf_hub_token,
            max_shard_size=shard_size,
            safe_serialization=use_safetensors,
        )

    if finetuning_args.stage == "rm":
        # The value head lives beside the last adapter (or the base model).
        if model_args.adapter_name_or_path is not None:
            head_dir = model_args.adapter_name_or_path[-1]
        else:
            head_dir = model_args.model_name_or_path

        # Prefer the safetensors value head, falling back to the legacy format.
        for weights_name in (V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME):
            src_path = os.path.join(head_dir, weights_name)
            if os.path.exists(src_path):
                shutil.copy(src_path, os.path.join(model_args.export_dir, weights_name))
                logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
                break

    try:
        tokenizer.padding_side = "left"  # restore padding side
        tokenizer.init_kwargs["padding_side"] = "left"
        tokenizer.save_pretrained(model_args.export_dir)
        if model_args.export_hub_model_id is not None:
            tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

        if processor is not None:
            processor.save_pretrained(model_args.export_dir)
            if model_args.export_hub_model_id is not None:
                processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
    except Exception as err:
        logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {err}.")

    modelfile_path = os.path.join(model_args.export_dir, "Modelfile")
    with open(modelfile_path, "w", encoding="utf-8") as fp:
        fp.write(template.get_ollama_modelfile(tokenizer))
        logger.info_rank0(f"Ollama modelfile saved in {modelfile_path}")