callbacks.py 14.5 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch
import transformers
from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
    is_safetensors_available,
)
luopl's avatar
luopl committed
34
from typing_extensions import override
chenych's avatar
chenych committed
35

luopl's avatar
luopl committed
36
from ..extras import logging
chenych's avatar
chenych committed
37
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
luopl's avatar
luopl committed
38
from ..extras.misc import get_peak_memory
chenych's avatar
chenych committed
39
40
41
42
43
44
45
46
47
48
49


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file

if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments
    from trl import AutoModelForCausalLMWithValueHead


luopl's avatar
luopl committed
50
logger = logging.get_logger(__name__)
chenych's avatar
chenych committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""
    Splits a saved value-head checkpoint into decoder weights and value-head weights.

    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.

    Args:
        model: the value-head model whose `pretrained_model` is re-saved.
        output_dir: the directory that holds the checkpoint written by the trainer.
        safe_serialization: whether the checkpoint uses the safetensors format.

    Returns nothing. Does nothing when `model.pretrained_model` is neither a
    `PreTrainedModel` nor a `PeftModel`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    state_dict: Dict[str, torch.Tensor]
    if safe_serialization:
        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        # The file only needs to contain tensors, so refuse arbitrary pickled objects.
        state_dict = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)

    # Drop the combined checkpoint; its contents are re-saved below as separate
    # decoder and value-head files.
    os.remove(path_to_checkpoint)
    decoder_state_dict, v_head_state_dict = {}, {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param

    # An empty decoder state dict (case 2) is passed as None.
    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info_rank0(f"Value head model saved at: {output_dir}")
chenych's avatar
chenych committed
95
96
97


class FixValueHeadModelCallback(TrainerCallback):
    r"""
    A callback for fixing the checkpoint for valuehead models.
    """

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.

        Rewrites the checkpoint folder of the current global step via
        `fix_valuehead_checkpoint`. Only runs when `args.should_save` is true.
        """
        if args.should_save:
            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
            )


class SaveProcessorCallback(TrainerCallback):
    r"""
    A callback for saving the processor.
    """

    def __init__(self, processor: "ProcessorMixin") -> None:
        r"""
        Stores the processor that will be saved next to the model.
        """
        self.processor = processor

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Saves the processor into the checkpoint folder of the current global step.
        """
        if args.should_save:
            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            self.processor.save_pretrained(output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Saves the processor into the top-level output directory.
        """
        if args.should_save:
            self.processor.save_pretrained(args.output_dir)
chenych's avatar
chenych committed
132
133
134
135


class PissaConvertCallback(TrainerCallback):
    r"""
    A callback for converting the PiSSA adapter to a normal one.
    """

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.

        Saves the initial adapter to `<output_dir>/pissa_init` with
        `init_lora_weights` temporarily set to True.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
            if isinstance(model, PeftModel):
                init_lora_weights = model.peft_config["default"].init_lora_weights
                model.peft_config["default"].init_lora_weights = True
                try:
                    model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                finally:
                    # Restore the original setting even if saving fails.
                    model.peft_config["default"].init_lora_weights = init_lora_weights

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Saves a backup of the trained adapter and a converted adapter, then reloads the backup.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
            # 1. save a pissa backup with init_lora_weights: True
            # 2. save a converted lora with init_lora_weights: pissa
            # 3. load the pissa backup with init_lora_weights: True
            # 4. delete the initial adapter and change init_lora_weights to pissa
            if isinstance(model, PeftModel):
                init_lora_weights = model.peft_config["default"].init_lora_weights
                model.peft_config["default"].init_lora_weights = True
                try:
                    model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
                finally:
                    # Restore the original setting even if saving fails.
                    model.peft_config["default"].init_lora_weights = init_lora_weights

                model.save_pretrained(
                    pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
                )  # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                if "pissa_init" in model.peft_config:  # backward compatibility (peft<0.12.0)
                    model.delete_adapter("pissa_init")

                # Look the config up again rather than reusing an earlier reference,
                # since the "default" adapter was reloaded above.
                model.peft_config["default"].init_lora_weights = init_lora_weights


class LogCallback(TrainerCallback):
    r"""
    A callback for logging training and evaluation status.
    """

    def __init__(self) -> None:
        r"""
        Initializes progress counters and, in Web UI mode, the abort handler and log handler.
        """
        # Progress
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        # Status
        self.aborted = False
        self.do_train = False
        # Web UI
        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
        if self.webui_mode:
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
            logging.add_handler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        r"""
        Signal handler that marks the run as aborted.
        """
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        r"""
        Restarts the timer and clears the progress counters.
        """
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        r"""
        Updates the current step and the elapsed / estimated remaining time strings.
        """
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        # Avoid dividing by zero before the first step has completed.
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        r"""
        Appends one JSON line to the trainer log file in `output_dir`.
        """
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        r"""
        Ensures `output_dir` exists and creates a single-worker thread pool for log writes.
        """
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        r"""
        Waits for pending log writes and releases the thread pool, if any.
        """
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    @override
    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Deletes a previous trainer log when the output directory is being overwritten.
        """
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning_once("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Marks the run as a training run, resets the progress and starts the log writer.
        """
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Flushes and closes the log writer.
        """
        self._close_thread_pool()

    @override
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Requests that training stop if an abort was signalled.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Requests that training stop if an abort was signalled.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Closes the log writer after evaluation when no training is running.
        """
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Closes the log writer after prediction when no training is running.
        """
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Builds a progress record from the latest log entry and queues it for writing.
        """
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        # Read the latest entry once; fall back to an empty record if nothing was logged yet.
        last_log = state.log_history[-1] if state.log_history else {}
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=last_log.get("loss"),
            eval_loss=last_log.get("eval_loss"),
            predict_loss=last_log.get("predict_loss"),
            reward=last_log.get("reward"),
            accuracy=last_log.get("rewards/accuracies"),
            lr=last_log.get("learning_rate"),
            epoch=last_log.get("epoch"),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if state.num_input_tokens_seen:
            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
            logs["total_tokens"] = state.num_input_tokens_seen

        if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
            vram_allocated, vram_reserved = get_peak_memory()
            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

        # Keep only the fields that are present in this log entry.
        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
            for extra_key in ("reward", "accuracy", "throughput"):
                if logs.get(extra_key):
                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

            logger.info_rank0("{" + log_str + "}")

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    @override
    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Tracks evaluation / prediction progress when no training is running.

        Exits the process if an abort was signalled. A progress record is queued
        every 5 steps.
        """
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            # First prediction step: size the progress bar and start the log writer.
            if self.max_steps == 0:
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)