import torch
from torch import nn
from libs.configs.default import counter
from transformers import Trainer
from typing import Any, Dict, Union
from torch.utils.data.distributed import DistributedSampler
from libs.utils.comm import distributed, get_rank, get_world_size
from transformers.trainer import *
from .nan_detector import NanDetector

class PreTrainer(Trainer):
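    """Trainer subclass used for pre-training.

    It moves batch tensors onto the training device, builds the train
    DataLoader (with a DistributedSampler when running distributed), and
    extends the periodic logging with GPU memory usage and the metrics
    collected by the project-wide ``counter``.
    """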

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
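        """Move every tensor-like value in ``inputs`` to ``self.args.device``,
        and forward cached ``mems`` when ``past_index`` is enabled."""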
        for k, v in inputs.items():
            if hasattr(v, "to") and hasattr(v, "device"):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs
    
    def get_train_dataloader(self):
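        """Build the training DataLoader: a DistributedSampler is used under
        distributed training, otherwise a plain shuffled DataLoader."""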
        
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        
        if distributed():
            # shard the dataset across ranks, reshuffling every epoch
            sampler = DistributedSampler(self.train_dataset, num_replicas=get_world_size(), rank=get_rank(), shuffle=True)
            dataloader = torch.utils.data.DataLoader(
                self.train_dataset,
                sampler=sampler,
                num_workers=self.args.dataloader_num_workers,
                batch_size=self.args.train_batch_size,
                collate_fn=self.data_collator,
                drop_last=self.args.dataloader_drop_last,
                pin_memory=self.args.dataloader_pin_memory,
            )
        else:
            dataloader = torch.utils.data.DataLoader(
                self.train_dataset,
                num_workers=self.args.dataloader_num_workers,
                batch_size=self.args.train_batch_size,
                collate_fn=self.data_collator,
                shuffle=True,
                drop_last=self.args.dataloader_drop_last,
                pin_memory=self.args.dataloader_pin_memory
            )
        return dataloader
    
    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
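        """Log, evaluate and/or save according to ``self.control``, extending
        the stock logging with peak CUDA memory and the ``counter`` metrics."""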
        if self.control.should_log:
            logs: Dict[str, float] = {}
            tr_loss_scalar = tr_loss.item()
            # reset tr_loss to zero in place so accumulation continues on the same tensor
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = round(self._get_learning_rate(), 10)
            # peak GPU memory allocated so far, in MiB
            logs["cuda_max_memory"] = int(torch.cuda.max_memory_allocated() / 1024 / 1024)
            # merge in the project-specific counter metrics
            logs = dict(logs, **counter.dict_mean(sync=False))
            
            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self.evaluate()
            self._report_to_hp_search(trial, epoch, metrics)

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
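

# Minimal usage sketch, assuming a standard transformers setup; the names
# model, training_args, train_dataset and collator are placeholders:
#
#     trainer = PreTrainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         data_collator=collator,
#     )
#     trainer.train()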