callbacks.py 16.1 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch
import transformers
from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
    is_safetensors_available,
)
luopl's avatar
luopl committed
34
from typing_extensions import override
chenych's avatar
chenych committed
35

luopl's avatar
luopl committed
36
from ..extras import logging
chenych's avatar
chenych committed
37
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
luopl's avatar
luopl committed
38
from ..extras.misc import get_peak_memory, use_ray
chenych's avatar
chenych committed
39
40
41
42
43
44


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file

luopl's avatar
luopl committed
45

chenych's avatar
chenych committed
46
47
48
49
if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments
    from trl import AutoModelForCausalLMWithValueHead

luopl's avatar
luopl committed
50
51
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

chenych's avatar
chenych committed
52

luopl's avatar
luopl committed
53
# Module-level logger from the project's `extras.logging` module; the callbacks in this file emit
# their messages through it via `info_rank0` / `warning_rank0_once`.
logger = logging.get_logger(__name__)
chenych's avatar
chenych committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""
    Split a freshly written value-head checkpoint into decoder weights and value-head weights.

    The model is already unwrapped. The combined weight file found in ``output_dir`` is loaded,
    deleted, and replaced by:

    - the decoder weights, re-saved through ``model.pretrained_model.save_pretrained``;
    - the ``v_head.*`` tensors, saved to ``V_HEAD_SAFE_WEIGHTS_NAME`` (safetensors) or
      ``V_HEAD_WEIGHTS_NAME`` (``torch.save``).

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.

    Args:
        model: the unwrapped value-head model whose checkpoint was just saved.
        output_dir: directory holding the combined weight file.
        safe_serialization: read/write safetensors files when True, torch pickle files otherwise.
    """
    backbone = model.pretrained_model
    if not isinstance(backbone, (PreTrainedModel, PeftModel)):
        # Unknown backbone type: leave the checkpoint untouched.
        return

    weights_file = os.path.join(output_dir, SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME)
    if safe_serialization:
        with safe_open(weights_file, framework="pt", device="cpu") as reader:
            state_dict: Dict[str, torch.Tensor] = {key: reader.get_tensor(key) for key in reader.keys()}
    else:
        state_dict = torch.load(weights_file, map_location="cpu")

    # The combined file is removed; its contents are rewritten below as separate decoder / v_head files.
    os.remove(weights_file)

    v_head_state_dict: Dict[str, torch.Tensor] = {
        key: tensor for key, tensor in state_dict.items() if key.startswith("v_head.")
    }
    # Strip only the first "pretrained_model." prefix (present in the zero3 layout, case 3).
    decoder_state_dict: Dict[str, torch.Tensor] = {
        key.replace("pretrained_model.", "", 1): tensor
        for key, tensor in state_dict.items()
        if not key.startswith("v_head.")
    }

    # An empty decoder dict (e.g. case 2, where only v_head tensors were saved) is passed as None
    # so that save_pretrained falls back to the backbone's own weights.
    backbone.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info_rank0(f"Value head model saved at: {output_dir}")
chenych's avatar
chenych committed
98
99
100


class FixValueHeadModelCallback(TrainerCallback):
    r"""
    A callback for fixing the checkpoint for valuehead models.

    After every checkpoint save, the weights written to ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``
    are split into decoder and value-head files by ``fix_valuehead_checkpoint``.
    """

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Nothing to do on processes where `args.should_save` is False.
        if not args.should_save:
            return

        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        fix_valuehead_checkpoint(
            model=kwargs.pop("model"), output_dir=checkpoint_dir, safe_serialization=args.save_safetensors
        )


class SaveProcessorCallback(TrainerCallback):
    r"""
    A callback for saving the processor.

    The processor is written into every intermediate checkpoint directory and, when training
    ends, into the top-level ``args.output_dir``.
    """

    def __init__(self, processor: "ProcessorMixin") -> None:
        # Processor persisted next to the model weights.
        self.processor = processor

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        self.processor.save_pretrained(checkpoint_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        self.processor.save_pretrained(args.output_dir)
chenych's avatar
chenych committed
132
133
134
135


class PissaConvertCallback(TrainerCallback):
    r"""
    A callback for converting the PiSSA adapter to a normal one.

    On train begin the initial adapter is saved to ``<output_dir>/pissa_init``. On train end the
    trained adapter is backed up to ``pissa_backup``, a converted LoRA adapter is written to
    ``pissa_converted`` (using ``pissa_init`` as the conversion reference), and the backup is
    reloaded as the active "default" adapter.
    """

    @staticmethod
    def _save_with_default_init(model: "PeftModel", save_dir: str, safe_serialization: bool) -> None:
        r"""
        Save ``model`` with the "default" adapter's ``init_lora_weights`` temporarily set to True.

        The original value is restored afterwards, even if saving raises, so a failed save cannot
        leave the live adapter config modified.
        """
        adapter_config = model.peft_config["default"]
        init_lora_weights = adapter_config.init_lora_weights
        adapter_config.init_lora_weights = True
        try:
            model.save_pretrained(save_dir, safe_serialization=safe_serialization)
        finally:
            adapter_config.init_lora_weights = init_lora_weights

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
            if isinstance(model, PeftModel):
                self._save_with_default_init(model, pissa_init_dir, args.save_safetensors)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
            # 1. save a pissa backup with init_lora_weights: True
            # 2. save a converted lora with init_lora_weights: pissa
            # 3. load the pissa backup with init_lora_weights: True
            # 4. delete the initial adapter and change init_lora_weights to pissa
            if isinstance(model, PeftModel):
                init_lora_weights = model.peft_config["default"].init_lora_weights
                self._save_with_default_init(model, pissa_backup_dir, args.save_safetensors)
                model.save_pretrained(
                    pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
                )  # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                if "pissa_init" in model.peft_config:  # backward compatibility (peft<0.12.0)
                    model.delete_adapter("pissa_init")

                # Re-index the config rather than reuse an earlier reference: reloading the backup
                # may have touched it. NOTE(review): confirm whether load_adapter replaces the object.
                model.peft_config["default"].init_lora_weights = init_lora_weights


class LogCallback(TrainerCallback):
    r"""
    A callback for logging training and evaluation status.

    Progress records (step counts, percentage, elapsed / remaining time and the latest metrics) are
    appended as JSON lines to ``TRAINER_LOG`` in ``args.output_dir`` by a single background worker.
    When ``LLAMABOARD_ENABLED`` is set and Ray is not in use, SIGABRT marks the run as aborted and
    log records are additionally routed to a handler rooted at ``LLAMABOARD_WORKDIR``.
    """

    def __init__(self) -> None:
        # Progress
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        # Status
        self.aborted = False
        self.do_train = False
        # Web UI
        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
        if self.webui_mode and not use_ray():
            # SIGABRT (presumably sent by the Web UI) only sets a flag; the step hooks act on it.
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
            logging.add_handler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        r"""Signal handler: flag the run as aborted."""
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        r"""Restart progress tracking with a new step budget (0 means "not tracking")."""
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        r"""Record ``cur_steps`` and refresh the human-readable elapsed / estimated remaining time."""
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        # Linear ETA based on the average duration of the steps completed so far.
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        r"""Append one JSON record to the trainer log (executed on the worker thread)."""
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        r"""Create the single-worker pool used for log writes, replacing any existing pool."""
        os.makedirs(output_dir, exist_ok=True)
        # Shut down a leftover pool first so its queued writes are flushed and its thread is not leaked.
        self._close_thread_pool()
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        r"""Wait for pending log writes to finish and release the worker thread."""
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    @override
    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        self._close_thread_pool()

    @override
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not self.do_train:
            self._close_thread_pool()
            # Forget the finished run: otherwise `max_steps` stays non-zero and a following
            # evaluate/predict call never recreates the pool in `on_prediction_step`.
            self._reset()

    @override
    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not self.do_train:
            self._close_thread_pool()
            # Same as `on_evaluate`: start the next standalone run from a clean progress state.
            self._reset()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        last_log = state.log_history[-1]
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=last_log.get("loss"),
            eval_loss=last_log.get("eval_loss"),
            predict_loss=last_log.get("predict_loss"),
            reward=last_log.get("reward"),
            accuracy=last_log.get("rewards/accuracies"),
            lr=last_log.get("learning_rate"),
            epoch=last_log.get("epoch"),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if state.num_input_tokens_seen:
            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
            logs["total_tokens"] = state.num_input_tokens_seen

        if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
            vram_allocated, vram_reserved = get_peak_memory()
            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

        # Drop metrics that were not reported at this step.
        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
            for extra_key in ("reward", "accuracy", "throughput"):
                if logs.get(extra_key):
                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

            logger.info_rank0("{" + log_str + "}")

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    @override
    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        # Progress during training is already reported through `on_log`.
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            # First step of a standalone evaluate/predict run: size the progress bar and start the writer.
            if self.max_steps == 0:
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            # Throttle file writes to every 5th prediction step.
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)
luopl's avatar
luopl committed
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395


class ReporterCallback(TrainerCallback):
    r"""
    A callback for reporting training status to external logger.

    On the main process it pushes the model / data / finetuning / generating arguments into the
    config of each enabled tracker: Weights & Biases when "wandb" is in ``args.report_to``, and
    SwanLab when ``finetuning_args.use_swanlab`` is set.
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.model_args = model_args
        self.data_args = data_args
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        # Default the W&B project name, keeping any value the user already exported.
        os.environ.setdefault("WANDB_PROJECT", "llamafactory")

    def _collect_args(self) -> Dict[str, Any]:
        r"""Serialize every argument group into one dict keyed by its attribute name."""
        groups = ("model_args", "data_args", "finetuning_args", "generating_args")
        return {group: getattr(self, group).to_dict() for group in groups}

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Only the global main process reports, to avoid duplicate config uploads.
        if not state.is_world_process_zero:
            return

        if "wandb" in args.report_to:
            import wandb

            wandb.config.update(self._collect_args())

        if self.finetuning_args.use_swanlab:
            import swanlab  # type: ignore

            swanlab.config.update(self._collect_args())