callback.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import time

import json
from tqdm.auto import tqdm
from transformers.trainer_callback import (DefaultFlowCallback, ProgressCallback, TrainerCallback, TrainerControl,
                                           TrainerState)
from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics

from swift.utils import is_pai_training_job, use_torchacc
from .arguments import TrainingArguments


class ProgressCallbackNew(ProgressCallback):
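    """Progress callback that labels the Train/Val tqdm bars, appends each log record
    to logging.jsonl in the output dir, and collects warmup speed metrics under torchacc."""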

    def on_train_begin(self, args, state, control, **kwargs):
        if state.is_local_process_zero:
            self.training_bar = tqdm(desc='Train', total=state.max_steps, dynamic_ncols=True)
        self.current_step = 0
        if use_torchacc():
            self.warmup_start_time = 0
            self.warmup_metric = None
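            # A metric_warmup_step below 1 is interpreted as a fraction of max_steps.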
            self.metric_warmup_step = (int(args.metric_warmup_step * state.max_steps)
                                       if args.metric_warmup_step < 1 else args.metric_warmup_step)

    def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader=None, **kwargs):
        if state.is_local_process_zero and has_length(eval_dataloader):
            if self.prediction_bar is None:
                if self.training_bar is not None:
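                    # Start the Val bar on a fresh line below the Train bar.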
                    self.training_bar.fp.write('\n')
                self.prediction_bar = tqdm(
                    desc='Val', total=len(eval_dataloader), leave=True, dynamic_ncols=True, position=0)
            self.prediction_bar.update()

    def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):
        logs['global_step'] = state.global_step
        if use_torchacc():
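            # Mark when the warmup window ends, then report speed metrics over the
            # post-warmup steps once training reaches max_steps.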
            if state.global_step >= self.metric_warmup_step and self.warmup_start_time == 0:
                self.warmup_start_time = time.time()
                self.metric_warmup_step = state.global_step
            if state.max_steps == state.global_step and self.warmup_metric is None:
                num_steps = state.max_steps - self.metric_warmup_step
                num_total_samples = args.train_dataset_sample
                num_after_warmup_samples = int(num_total_samples / state.max_steps * num_steps)
                self.warmup_metric = speed_metrics('warmup_train', self.warmup_start_time, num_after_warmup_samples,
                                                   num_steps)
                self.warmup_metric['num_total_samples'] = num_total_samples
                self.warmup_metric['num_after_warmup_samples'] = num_after_warmup_samples
            # Attach the warmup metrics to the final train log; guard against training
            # that stopped before max_steps, in which case warmup_metric was never set.
            if 'train_samples_per_second' in logs and self.warmup_metric is not None:
                logs.update(self.warmup_metric)
                state.log_history[-1] = logs
        for k, v in logs.items():
            if isinstance(v, float):
                logs[k] = round(v, 8)
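        # Mirror the log record to logging.jsonl in output_dir (skipped on PAI training jobs).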
        if not is_pai_training_job() and state.is_local_process_zero:
            jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
            with open(jsonl_path, 'a', encoding='utf-8') as f:
                f.write(json.dumps(logs) + '\n')
        super().on_log(args, state, control, logs, **kwargs)
        if state.is_local_process_zero and self.training_bar is not None:
            self.training_bar.refresh()


class DefaultFlowCallbackNew(DefaultFlowCallback):
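    """Flow callback that additionally forces a final evaluation and checkpoint save on the last step."""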

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        control = super().on_step_end(args, state, control, **kwargs)
        # save the last ckpt
        if state.global_step == state.max_steps:
            if args.evaluation_strategy != IntervalStrategy.NO:
                control.should_evaluate = True
            if args.save_strategy != IntervalStrategy.NO:
                control.should_save = True
        return control


class PrinterCallbackNew(TrainerCallback):
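    """Printer callback that adds global_step to the logs, appends them to logging.jsonl
    and prints them on the local main process."""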

    def on_log(self, args, state, control, logs=None, **kwargs):
        logs['global_step'] = state.global_step
        for k, v in logs.items():
            if isinstance(v, float):
                logs[k] = round(v, 8)
        if not is_pai_training_job() and state.is_local_process_zero:
            jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
            with open(jsonl_path, 'a', encoding='utf-8') as f:
                f.write(json.dumps(logs) + '\n')

        _ = logs.pop('total_flos', None)
        if state.is_local_process_zero:
            print(logs, flush=True)
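

# Usage sketch (an assumption for illustration, not part of this module): these callbacks
# are intended to replace the stock transformers callbacks on an existing Trainer, e.g.:
#
#     from transformers.trainer_callback import DefaultFlowCallback, ProgressCallback
#
#     trainer.remove_callback(ProgressCallback)
#     trainer.remove_callback(DefaultFlowCallback)
#     trainer.add_callback(ProgressCallbackNew)
#     trainer.add_callback(DefaultFlowCallbackNew)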