rm.py 5.24 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
from datetime import datetime
2
from typing import List, Optional
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
3
4
5
6
7
8
9
10
11

import pandas as pd
import torch
import torch.distributed as dist
from torch.optim import Optimizer, lr_scheduler
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

12
from .base import Trainer
13
from .callbacks import Callback
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
14
15
16
17
from .strategies import Strategy
from .utils import is_rank_0


18
class RewardModelTrainer(Trainer):
    """
    Trainer to use while training reward model.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        loss_fn (callable): the loss function to use for training; called as
            ``loss_fn(chosen_reward, reject_reward)``
        train_dataloader (DataLoader): the dataloader to use for training
        valid_dataloader (DataLoader): the dataloader used for the periodic
            validation run every 100 training steps
        eval_dataloader (DataLoader): the dataloader used for the evaluation
            run at the end of every epoch
        max_epochs (int, defaults to 1): the number of epochs to train
        callbacks (List[Callback], optional): the callbacks to call during the
            training process; defaults to a fresh empty list
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        loss_fn,
        train_dataloader: DataLoader,
        valid_dataloader: DataLoader,
        eval_dataloader: DataLoader,
        max_epochs: int = 1,
        callbacks: Optional[List[Callback]] = None,
    ) -> None:
        # Build a fresh list per instance: a ``[]`` default would be one list
        # object shared by every trainer constructed without ``callbacks``.
        if callbacks is None:
            callbacks = []
        super().__init__(strategy, max_epochs, callbacks=callbacks)

        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.eval_dataloader = eval_dataloader

        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optim
        # ``fit`` steps the scheduler once per 100 training batches (the counter
        # resets every epoch), i.e. ``len(train_dataloader) // 100`` times per
        # epoch, so the LR anneals to its minimum over roughly one epoch.
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, len(self.train_dataloader) // 100)

    def _reward_pair(self, chosen_ids, c_mask, reject_ids, r_mask):
        """Score one (chosen, rejected) batch with ``self.model``.

        Each tensor has dim 1 squeezed (a no-op unless that dim has size 1) and
        is moved to the current CUDA device before the two forward passes.

        Returns:
            tuple: ``(chosen_reward, reject_reward)`` exactly as returned by the model.
        """
        device = torch.cuda.current_device()
        chosen_ids = chosen_ids.squeeze(1).to(device)
        c_mask = c_mask.squeeze(1).to(device)
        reject_ids = reject_ids.squeeze(1).to(device)
        r_mask = r_mask.squeeze(1).to(device)
        chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
        reject_reward = self.model(reject_ids, attention_mask=r_mask)
        return chosen_reward, reject_reward

    def eval_acc(self, dataloader):
        """Measure how well the model ranks chosen responses above rejected ones.

        The model runs in eval mode under ``torch.no_grad()`` and is switched
        back to train mode before returning.

        Args:
            dataloader (DataLoader): yields ``(chosen_ids, c_mask, reject_ids, r_mask)``.

        Returns:
            tuple: ``(dist_mean, acc)`` where ``dist_mean`` is the batch-mean of
            ``chosen_reward - reject_reward`` averaged over ``len(dataloader)``
            batches, and ``acc`` is the fraction of reward entries with
            ``chosen_reward > reject_reward`` (one per pair, assuming the model
            returns one reward per sample — TODO confirm).

        Raises:
            ZeroDivisionError: if ``dataloader`` is empty.
        """
        # Not called ``dist``: that would shadow ``torch.distributed as dist``.
        reward_gap = 0.0
        correct = 0
        total = 0
        self.model.eval()
        with torch.no_grad():
            for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
                chosen_reward, reject_reward = self._reward_pair(chosen_ids, c_mask, reject_ids, r_mask)
                # One vectorised comparison; the old per-element ``if tensor``
                # loop forced a separate device->host sync for every sample.
                total += len(chosen_reward)
                correct += (chosen_reward > reject_reward).sum().item()
                reward_gap += (chosen_reward - reject_reward).mean().item()
            dist_mean = reward_gap / len(dataloader)
            acc = correct / total
        self.model.train()
        return dist_mean, acc

    def fit(self):
        """Train the reward model for ``self.max_epochs`` epochs.

        Every 100 training batches (counter reset each epoch) the LR scheduler
        is stepped and the model is evaluated on ``valid_dataloader``; on rank 0
        a ``step, loss, dist, acc`` row is appended to ``log_<start time>.csv``.
        At the end of each epoch the model is evaluated on ``eval_dataloader``
        and rank 0 appends the same columns to ``log.csv``.
        """
        start_time = datetime.now()
        epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
        # Remains None only if the train dataloader never yields a batch; checked
        # before the end-of-epoch log instead of raising NameError.
        loss = None
        for epoch in range(self.max_epochs):
            step_bar = tqdm(range(len(self.train_dataloader)),
                            desc='Train step of epoch %d' % epoch,
                            disable=not is_rank_0())
            # train
            self.model.train()
            steps_since_eval = 0
            acc = 0
            reward_gap = 0
            for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
                chosen_reward, reject_reward = self._reward_pair(chosen_ids, c_mask, reject_ids, r_mask)
                loss = self.loss_fn(chosen_reward, reject_reward)
                self.strategy.backward(loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()
                steps_since_eval += 1
                if steps_since_eval == 100:
                    self.scheduler.step()
                    reward_gap, acc = self.eval_acc(self.valid_dataloader)
                    steps_since_eval = 0
                    if is_rank_0():
                        # NOTE(review): str(datetime) contains ':' characters,
                        # which are not allowed in Windows filenames.
                        log = pd.DataFrame([[step_bar.n, loss.item(), reward_gap, acc]],
                                           columns=['step', 'loss', 'dist', 'acc'])
                        log.to_csv('log_%s.csv' % start_time, mode='a', header=False, index=False)
                step_bar.update()
                step_bar.set_postfix({'dist': reward_gap, 'acc': acc})

            # eval
            reward_gap, acc = self.eval_acc(self.eval_dataloader)
            if is_rank_0():
                loss_value = loss.item() if loss is not None else float('nan')
                # NOTE(review): end-of-epoch rows go to 'log.csv' while periodic
                # validation rows go to the timestamped file — confirm intended.
                log = pd.DataFrame([[step_bar.n, loss_value, reward_gap, acc]],
                                   columns=['step', 'loss', 'dist', 'acc'])
                log.to_csv('log.csv', mode='a', header=False, index=False)
            epoch_bar.update()
            step_bar.set_postfix({'dist': reward_gap, 'acc': acc})
            step_bar.close()
        epoch_bar.close()