# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional

import torch
import transformers
from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from typing_extensions import override

from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
from ..extras.packages import is_safetensors_available


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""Fix the valuehead checkpoint files.

    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    # Load the combined checkpoint the trainer wrote, then delete the file:
    # it will be re-written below with the decoder and value head separated.
    if safe_serialization:
        checkpoint_file = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(checkpoint_file, framework="pt", device="cpu") as shard:
            state_dict: dict[str, torch.Tensor] = {name: shard.get_tensor(name).clone() for name in shard.keys()}
    else:
        checkpoint_file = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: dict[str, torch.Tensor] = torch.load(checkpoint_file, map_location="cpu", weights_only=True)

    os.remove(checkpoint_file)

    # Route value-head tensors to their own file and strip the zero3
    # "pretrained_model." prefix from the decoder tensors.
    v_head_state_dict: dict[str, torch.Tensor] = {}
    decoder_state_dict: dict[str, torch.Tensor] = {}
    for key, tensor in state_dict.items():
        if key.startswith("v_head."):
            v_head_state_dict[key] = tensor
        else:
            decoder_state_dict[key.replace("pretrained_model.", "", 1)] = tensor

    # An empty decoder dict (lora case) falls back to None so save_pretrained
    # uses the model's own weights.
    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info_rank0(f"Value head model saved at: {output_dir}")


class FixValueHeadModelCallback(TrainerCallback):
    r"""A callback for fixing the checkpoint for valuehead models."""

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Only the process responsible for saving rewrites the checkpoint.
        if not args.should_save:
            return

        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        fix_valuehead_checkpoint(
            model=kwargs.pop("model"), output_dir=checkpoint_dir, safe_serialization=args.save_safetensors
        )


class SaveProcessorCallback(TrainerCallback):
    r"""A callback for saving the processor."""

    def __init__(self, processor: "ProcessorMixin") -> None:
        # Kept so the processor can be persisted alongside every checkpoint.
        self.processor = processor

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        self.processor.save_pretrained(checkpoint_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        # Final copy goes to the top-level output directory.
        self.processor.save_pretrained(args.output_dir)


class PissaConvertCallback(TrainerCallback):
    r"""A callback for converting the PiSSA adapter to a normal one.

    Saves the initial PiSSA adapter at train begin, then at train end converts
    the trained adapter into a plain LoRA adapter using that initial snapshot.
    """

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
            if isinstance(model, PeftModel):
                # Temporarily set init_lora_weights to True so save_pretrained
                # writes the raw (unconverted) adapter weights, then restore it.
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
            # 1. save a pissa backup with init_lora_weights: True
            # 2. save a converted lora with init_lora_weights: pissa
            # 3. load the pissa backup with init_lora_weights: True
            # 4. delete the initial adapter and change init_lora_weights to pissa
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
                # Passing the initial-adapter path makes peft emit a converted
                # (residual) LoRA adapter instead of the raw PiSSA weights.
                model.save_pretrained(
                    pissa_convert_dir,
                    safe_serialization=args.save_safetensors,
                    path_initial_model_for_weight_conversion=pissa_init_dir,
                )
                # Restore the in-memory adapter from the backup, since the
                # conversion above mutates the adapter weights.
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)


class LogCallback(TrainerCallback):
    r"""A callback for logging training and evaluation status.

    Tracks step progress and timing, mirrors trainer metrics to the trainer log
    file (one JSON object per line, written asynchronously through a
    single-worker thread pool), and honors abort requests delivered via
    SIGABRT when running under the LlamaBoard web UI.
    """

    def __init__(self) -> None:
        # Progress
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        # Single worker keeps log writes ordered without blocking the training loop.
        self.thread_pool: Optional[ThreadPoolExecutor] = None
        # Status
        self.aborted = False
        self.do_train = False
        # Web UI
        self.webui_mode = is_env_enabled("LLAMABOARD_ENABLED")
        if self.webui_mode and not use_ray():
            # Abort requests arrive as SIGABRT (registered only in web-UI mode).
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = logging.LoggerHandler(os.getenv("LLAMABOARD_WORKDIR"))
            logging.add_handler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        # Signal handler: just set a flag; it is polled from the step callbacks.
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        # Restart progress tracking for a new training or evaluation run.
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        # Re-estimate elapsed/remaining time from the average time per step.
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
        # Append one JSON record per line to the trainer log file.
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        # Flush any pending log writes before dropping the pool.
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    @override
    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Remove a stale trainer log when the user asked to overwrite the output dir.
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        self._close_thread_pool()

    @override
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Honor an abort request between gradient-accumulation substeps.
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Honor an abort request at optimizer-step boundaries.
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # In standalone evaluation (no training) this callback owns the pool.
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        # Pull the most recent metrics entry; missing keys become None and are
        # filtered out below.
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss"),
            eval_loss=state.log_history[-1].get("eval_loss"),
            predict_loss=state.log_history[-1].get("predict_loss"),
            reward=state.log_history[-1].get("reward"),
            accuracy=state.log_history[-1].get("rewards/accuracies"),
            lr=state.log_history[-1].get("learning_rate"),
            epoch=state.log_history[-1].get("epoch"),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if state.num_input_tokens_seen:
            # Tokens per second since training started.
            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
            logs["total_tokens"] = state.num_input_tokens_seen

        if is_env_enabled("RECORD_VRAM"):
            vram_allocated, vram_reserved = get_peak_memory()
            # Convert bytes to GiB.
            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
            # Mirror a compact metrics line to the console for the web UI.
            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
            for extra_key in ("reward", "accuracy", "throughput"):
                if logs.get(extra_key):
                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

            logger.info_rank0("{" + log_str + "}")

        if self.thread_pool is not None:
            # Offload the file write so logging never stalls training.
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    @override
    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        # Progress reporting for standalone evaluation/prediction only;
        # during training, on_log handles progress.
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            if self.max_steps == 0:
                # First prediction step: size progress by the dataloader length.
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            # Throttle progress writes to every 5 steps.
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)


class ReporterCallback(TrainerCallback):
    r"""A callback for reporting training status to external logger."""

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.model_args = model_args
        self.data_args = data_args
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        # Default the wandb project name without overriding a user-provided one.
        os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")

    def _dump_args(self) -> dict[str, Any]:
        # Collect every argument group as a plain dict for external loggers.
        # (Previously duplicated inline for wandb and swanlab.)
        return {
            "model_args": self.model_args.to_dict(),
            "data_args": self.data_args.to_dict(),
            "finetuning_args": self.finetuning_args.to_dict(),
            "generating_args": self.generating_args.to_dict(),
        }

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Only the main process reports to external experiment trackers.
        if not state.is_world_process_zero:
            return

        if "wandb" in args.report_to:
            import wandb

            wandb.config.update(self._dump_args())

        if self.finetuning_args.use_swanlab:
            import swanlab  # type: ignore

            swanlab.config.update(self._dump_args())