from datetime import datetime
from typing import Callable

import pandas as pd
import torch
import tqdm
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .base import SLTrainer
from .strategies import Strategy
from .utils import is_rank_0


class RewardModelTrainer(SLTrainer):
    """
        Trainer to use while training reward model.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
23
24
        optim (Optimizer): the optimizer to use for training
        lr_scheduler (_LRScheduler): the lr scheduler to use for training
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
25
26
27
28
29
30
31
32
33
        loss_fn (callable): the loss function to use for training
        max_epochs (int, defaults to 2): the number of epochs to train
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        loss_fn: Callable,
        max_epochs: int = 1,
    ) -> None:
        super().__init__(strategy, max_epochs, model, optim)

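        # `loss_fn` is expected to map a (chosen_reward, reject_reward) pair of
        # tensors to a scalar loss, e.g. a pairwise log-sigmoid ranking loss.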
        self.loss_fn = loss_fn
        self.scheduler = lr_scheduler

    def _eval(self, epoch):
        if self.eval_dataloader is not None:
            self.model.eval()
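            # Accumulators: `dist` sums the per-batch mean reward gap
            # (chosen - rejected), `on` counts correctly ranked pairs,
            # and `cnt` counts all pairs seen.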
            dist, on, cnt = 0, 0, 0
            with torch.no_grad():
                for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
                    chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
                    c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
                    reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
                    r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
                    chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                    reject_reward = self.model(reject_ids, attention_mask=r_mask)
                    for i in range(len(chosen_reward)):
                        cnt += 1
                        if chosen_reward[i] > reject_reward[i]:
                            on += 1
                    dist += (chosen_reward - reject_reward).mean().item()
                self.dist = dist / len(self.eval_dataloader)
                self.acc = on / cnt

            if is_rank_0():
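                # Only rank 0 appends to log.csv; `self.loss` here is the loss
                # from the last training step of the epoch.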
                log = pd.DataFrame(
                    [[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]],
                    columns=["step", "loss", "dist", "acc"],
                )
                log.to_csv("log.csv", mode="a", header=False, index=False)

    def _train(self, epoch):
        self.model.train()
        step_bar = tqdm.trange(
            len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0()
        )
        cnt = 0
        for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
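            # Batches carry a singleton sequence dim; squeeze it and move the
            # tensors to the current CUDA device.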
            chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
            c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
            reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
            r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
            chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
            reject_reward = self.model(reject_ids, attention_mask=r_mask)
            self.loss = self.loss_fn(chosen_reward, reject_reward)
            self.strategy.backward(self.loss, self.model, self.optimizer)
            self.strategy.optimizer_step(self.optimizer)
            self.optimizer.zero_grad()
            cnt += 1
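            # Step the LR scheduler once every 100 optimizer steps.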
            if cnt % 100 == 0:
                self.scheduler.step()
            step_bar.update()
        step_bar.close()

    def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader):
        """
        Args:
            train_dataloader (DataLoader): the dataloader to use for training
            valid_dataloader (DataLoader): the dataloader to use for validation
            eval_dataloader (DataLoader): the dataloader to use for evaluation
        """
        super()._before_fit()
        self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.eval_dataloader = eval_dataloader
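
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `PairwiseRewardModel`, `make_strategy`,
# and `pairwise_rank_loss` are hypothetical stand-ins for the model, strategy
# factory, and loss the surrounding project supplies; the trainer itself only
# relies on the interfaces exercised above (Strategy.backward/optimizer_step,
# a model returning a scalar reward per sequence, and a loss over a
# (chosen, rejected) reward pair). Assuming the base SLTrainer.fit forwards
# the dataloaders to _before_fit, training would look roughly like:
#
#   from torch.optim import Adam
#   from torch.optim.lr_scheduler import CosineAnnealingLR
#
#   model = PairwiseRewardModel()             # hypothetical reward model
#   strategy = make_strategy()                # hypothetical Strategy factory
#   optim = Adam(model.parameters(), lr=5e-6)
#   scheduler = CosineAnnealingLR(optim, T_max=1000)
#   trainer = RewardModelTrainer(model, strategy, optim, scheduler,
#                                loss_fn=pairwise_rank_loss, max_epochs=1)
#   trainer.fit(train_dataloader, valid_dataloader, eval_dataloader)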