# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
chenych's avatar
chenych committed
22
from typing import TYPE_CHECKING, Any, Optional
chenych's avatar
chenych committed
23
24
25
26
27
28
29
30
31
32
33

import torch
import transformers
from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
    is_safetensors_available,
)
luopl's avatar
luopl committed
34
from typing_extensions import override
chenych's avatar
chenych committed
35

luopl's avatar
luopl committed
36
from ..extras import logging
chenych's avatar
chenych committed
37
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
chenych's avatar
chenych committed
38
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
chenych's avatar
chenych committed
39
40
41
42
43
44


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file

luopl's avatar
luopl committed
45

chenych's avatar
chenych committed
46
47
48
49
if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments
    from trl import AutoModelForCausalLMWithValueHead

luopl's avatar
luopl committed
50
51
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

chenych's avatar
chenych committed
52

luopl's avatar
luopl committed
53
logger = logging.get_logger(__name__)
chenych's avatar
chenych committed
54
55
56
57
58


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""Fix the valuehead checkpoint files.

    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    # Load the combined state dict that the trainer just wrote to disk.
    if safe_serialization:
        checkpoint_file = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
            state_dict: dict[str, torch.Tensor] = {name: f.get_tensor(name) for name in f.keys()}
    else:
        checkpoint_file = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: dict[str, torch.Tensor] = torch.load(checkpoint_file, map_location="cpu", weights_only=True)

    # The combined file is superseded by the two files written below.
    os.remove(checkpoint_file)

    # Split the weights: value-head tensors get their own file, everything else
    # belongs to the decoder (dropping one leading deepspeed zero3 prefix).
    decoder_state_dict: dict[str, torch.Tensor] = {}
    v_head_state_dict: dict[str, torch.Tensor] = {}
    for name, tensor in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = tensor
        else:
            decoder_state_dict[name.replace("pretrained_model.", "", 1)] = tensor

    # An empty decoder dict (lora case) is passed as None so that
    # save_pretrained falls back to the model's own weights.
    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info_rank0(f"Value head model saved at: {output_dir}")
chenych's avatar
chenych committed
99
100
101


class FixValueHeadModelCallback(TrainerCallback):
    r"""A callback for fixing the checkpoint for valuehead models."""

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Only the process responsible for saving rewrites the checkpoint files.
        if not args.should_save:
            return

        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        fix_valuehead_checkpoint(
            model=kwargs.pop("model"), output_dir=checkpoint_dir, safe_serialization=args.save_safetensors
        )


class SaveProcessorCallback(TrainerCallback):
    r"""A callback for saving the processor."""

    def __init__(self, processor: "ProcessorMixin") -> None:
        self.processor = processor

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Mirror the trainer's checkpoint layout: one processor copy per checkpoint dir.
        if not args.should_save:
            return

        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        self.processor.save_pretrained(checkpoint_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Final copy goes to the top-level output directory.
        if args.should_save:
            self.processor.save_pretrained(args.output_dir)
chenych's avatar
chenych committed
129
130
131


class PissaConvertCallback(TrainerCallback):
    r"""A callback for converting the PiSSA adapter to a normal one."""

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        model = kwargs.pop("model")
        pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
        logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
        if isinstance(model, PeftModel):
            # Flip init_lora_weights to True around the save so the initial
            # adapter is written without PiSSA-specific handling, then restore it.
            saved_init_flag = getattr(model.peft_config["default"], "init_lora_weights")
            setattr(model.peft_config["default"], "init_lora_weights", True)
            model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
            setattr(model.peft_config["default"], "init_lora_weights", saved_init_flag)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        model = kwargs.pop("model")
        pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
        pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
        pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
        logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
        # 1. save a pissa backup with init_lora_weights: True
        # 2. save a converted lora with init_lora_weights: pissa
        # 3. load the pissa backup with init_lora_weights: True
        # 4. delete the initial adapter and change init_lora_weights to pissa
        if isinstance(model, PeftModel):
            saved_init_flag = getattr(model.peft_config["default"], "init_lora_weights")
            setattr(model.peft_config["default"], "init_lora_weights", True)
            model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
            setattr(model.peft_config["default"], "init_lora_weights", saved_init_flag)
            # Converting subtracts the initial adapter from the trained one,
            # which mutates the in-memory weights — hence the backup above.
            model.save_pretrained(
                pissa_convert_dir,
                safe_serialization=args.save_safetensors,
                path_initial_model_for_weight_conversion=pissa_init_dir,
            )
            model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
            model.set_adapter("default")
            setattr(model.peft_config["default"], "init_lora_weights", saved_init_flag)


class LogCallback(TrainerCallback):
    r"""A callback for logging training and evaluation status."""

    def __init__(self) -> None:
        # Progress
        self.start_time = 0  # wall-clock time when _reset() was last called
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""  # "H:MM:SS" string, updated by _timing()
        self.remaining_time = ""  # "H:MM:SS" string, updated by _timing()
        # Single-worker pool so log writes never block the training loop.
        self.thread_pool: Optional[ThreadPoolExecutor] = None
        # Status
        self.aborted = False  # set asynchronously by the SIGABRT handler
        self.do_train = False
        # Web UI
        self.webui_mode = is_env_enabled("LLAMABOARD_ENABLED")
        if self.webui_mode and not use_ray():
            # The web UI aborts a run by sending SIGABRT to this process.
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = logging.LoggerHandler(os.getenv("LLAMABOARD_WORKDIR"))
            logging.add_handler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        # Signal handler: only flips a flag; the training loop reacts at the
        # next (sub)step boundary via on_substep_end / on_step_end.
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        # Restart progress tracking for a new train or eval run.
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        # Update elapsed/remaining time estimates from the average step time.
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
        # Append one JSON object per line (JSONL) to the trainer log file.
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        # max_workers=1 keeps log lines in submission order.
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        # Flush pending writes before discarding the pool.
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    @override
    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Drop a stale trainer log when the run overwrites its output dir.
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        self._close_thread_pool()

    @override
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Honor an abort requested via SIGABRT (web UI stop button).
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Same abort handling as on_substep_end, checked at full-step boundaries.
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Eval-only runs own the pool created in on_prediction_step; close it here.
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Predict-only runs own the pool created in on_prediction_step; close it here.
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        # Collect the most recent entry of log_history; missing metrics stay
        # None and are filtered out below.
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss"),
            eval_loss=state.log_history[-1].get("eval_loss"),
            predict_loss=state.log_history[-1].get("predict_loss"),
            reward=state.log_history[-1].get("reward"),
            accuracy=state.log_history[-1].get("rewards/accuracies"),
            lr=state.log_history[-1].get("learning_rate"),
            epoch=state.log_history[-1].get("epoch"),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if state.num_input_tokens_seen:
            # Tokens per second since training started.
            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
            logs["total_tokens"] = state.num_input_tokens_seen

        if is_env_enabled("RECORD_VRAM"):
            vram_allocated, vram_reserved = get_peak_memory()
            # Convert bytes to GiB.
            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
            # Compact one-line summary shown in the web UI console.
            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
            for extra_key in ("reward", "accuracy", "throughput"):
                if logs.get(extra_key):
                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

            logger.info_rank0("{" + log_str + "}")

        if self.thread_pool is not None:
            # Offload the file append so logging never stalls training.
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    @override
    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        # During training, eval progress is reported via on_log instead.
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            if self.max_steps == 0:
                # First prediction step of an eval-only run: initialize progress
                # tracking and the log-writer pool.
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            # Throttle: write progress only every 5 steps.
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)
luopl's avatar
luopl committed
340
341
342


class ReporterCallback(TrainerCallback):
    r"""A callback for reporting training status to external logger."""

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.model_args = model_args
        self.data_args = data_args
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")

    def _dump_args(self) -> dict[str, Any]:
        # Serialize every argument group once; shared by wandb and swanlab.
        return {
            "model_args": self.model_args.to_dict(),
            "data_args": self.data_args.to_dict(),
            "finetuning_args": self.finetuning_args.to_dict(),
            "generating_args": self.generating_args.to_dict(),
        }

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Only the main process talks to external experiment trackers.
        if not state.is_world_process_zero:
            return

        if "wandb" in args.report_to:
            import wandb

            wandb.config.update(self._dump_args())

        if self.finetuning_args.use_swanlab:
            import swanlab  # type: ignore

            swanlab.config.update(self._dump_args())