from typing import Callable, Optional, Tuple

import torch
import tqdm
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .base import SLTrainer
from .strategies import Strategy
from .utils import is_rank_0


class RewardModelTrainer(SLTrainer):
    """
    Trainer to use while training reward model.

    Every batch holds a (chosen, rejected) pair of inputs; the model scores both sides and
    ``loss_fn(chosen_reward, reject_reward)`` turns the two rewards into the training loss.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        lr_scheduler (_LRScheduler): the lr scheduler to use for training; it is stepped once
            every ``scheduler_step_interval`` optimizer steps, not after every step
        loss_fn (callable): the loss function to use for training, called as
            ``loss_fn(chosen_reward, reject_reward)``
        max_epochs (int, defaults to 1): the number of epochs to train
    """

    # The lr scheduler is stepped when the global step counter (which is not reset between
    # epochs) is a multiple of this value, so over a run it receives roughly
    # total_optimizer_steps / scheduler_step_interval steps. Override on a subclass/instance
    # to change the cadence.
    scheduler_step_interval: int = 100

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        loss_fn: Callable,
        max_epochs: int = 1,
    ) -> None:
        super().__init__(strategy, max_epochs, model, optim)

        self.loss_fn = loss_fn
        self.scheduler = lr_scheduler

        # Global optimizer-step counter across all epochs; used as the TensorBoard x-axis
        # and to decide when to step the lr scheduler.
        self.num_train_step = 0

    def _forward_pair(self, chosen_ids, c_mask, reject_ids, r_mask) -> Tuple[torch.Tensor, torch.Tensor]:
        """Move one (chosen, rejected) batch to the current CUDA device and score both sides.

        Every input has dimension 1 squeezed away before the forward pass; presumably the
        dataset yields (batch, 1, seq_len) tensors -- TODO confirm against the dataset code.

        Returns:
            (chosen_reward, reject_reward): the model outputs for the chosen and rejected inputs.
        """
        device = torch.cuda.current_device()
        chosen_ids = chosen_ids.squeeze(1).to(device)
        c_mask = c_mask.squeeze(1).to(device)
        reject_ids = reject_ids.squeeze(1).to(device)
        r_mask = r_mask.squeeze(1).to(device)
        chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
        reject_reward = self.model(reject_ids, attention_mask=r_mask)
        return chosen_reward, reject_reward

    def _eval(self, epoch):
        """Evaluate on ``self.eval_dataloader`` (if any) and store ``self.dist`` / ``self.acc``.

        ``self.dist`` is the per-sample mean reward margin (chosen - rejected) and ``self.acc``
        is the fraction of samples whose chosen reward is strictly higher than the rejected
        one. Both are logged under ``eval/`` when a TensorBoard writer is configured. Metrics
        cover only this process's batches; no cross-rank reduction is done here. If the eval
        dataloader yields no samples, nothing is stored or logged.

        Args:
            epoch (int): used as the TensorBoard step for the eval scalars.
        """
        if self.eval_dataloader is None:
            return

        # Left in eval mode on purpose; _train() switches back with model.train().
        self.model.eval()
        dist, num_correct, num_samples = 0.0, 0, 0
        with torch.no_grad():
            for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
                chosen_reward, reject_reward = self._forward_pair(chosen_ids, c_mask, reject_ids, r_mask)
                num_samples += chosen_ids.size(0)
                num_correct += (chosen_reward > reject_reward).sum().item()
                # Accumulate the per-sample sum (not the batch mean) so a smaller final batch
                # is weighted correctly in the overall average.
                dist += (chosen_reward - reject_reward).sum().item()

        if num_samples == 0:
            # Empty eval set: avoid dividing by zero and report nothing.
            return
        self.dist = dist / num_samples
        self.acc = num_correct / num_samples

        if self.writer:
            self.writer.add_scalar("eval/dist", self.dist, epoch)
            self.writer.add_scalar("eval/acc", self.acc, epoch)

    def _train(self, epoch):
        """Run one training epoch over ``self.train_dataloader``.

        Per batch: score chosen/rejected, compute ``loss_fn``, run backward and the optimizer
        step through the strategy, zero the gradients, log train scalars (if a writer is
        configured), advance the global step, and step the lr scheduler every
        ``scheduler_step_interval`` global steps.

        Args:
            epoch (int): shown in the progress bar as ``epoch + 1`` of ``self.max_epochs``.
        """
        self.model.train()
        # The context manager closes the progress bar even if a step raises.
        with tqdm.trange(
            len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
        ) as step_bar:
            for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
                chosen_reward, reject_reward = self._forward_pair(chosen_ids, c_mask, reject_ids, r_mask)
                loss = self.loss_fn(chosen_reward, reject_reward)
                self.strategy.backward(loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()

                if self.writer:
                    step = self.num_train_step
                    self.writer.add_scalar("train/loss", loss.item(), step)
                    self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], step)
                    self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), step)
                    self.writer.add_scalar(
                        "train/acc", (chosen_reward > reject_reward).float().mean().item(), step
                    )

                self.num_train_step += 1
                if self.num_train_step % self.scheduler_step_interval == 0:
                    self.scheduler.step()
                step_bar.update()

    def _before_fit(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader],
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Args:
            train_dataloader (DataLoader): the dataloader to use for training
            eval_dataloader (DataLoader, optional): the dataloader to use for evaluation;
                None skips evaluation
            log_dir (str, optional): root directory for TensorBoard logs; a run writes to
                ``<log_dir>/rm/<local timestamp>``. The writer is created on rank 0 only.
            use_wandb (bool, defaults to False): also call ``wandb.init`` on rank 0 with
                ``sync_tensorboard=True``; requires ``log_dir``.

        Raises:
            ValueError: if ``use_wandb`` is True but ``log_dir`` is None.
        """
        # Validate on every rank (not just rank 0) so a misconfiguration fails everywhere at
        # once instead of killing rank 0 while the other ranks carry on.
        if use_wandb and log_dir is None:
            raise ValueError("log_dir must be provided when use_wandb is True")

        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.writer = None
        if use_wandb and is_rank_0():
            import wandb

            wandb.init(project="Coati-rm", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "rm")
            # NOTE(review): ':' in this directory name is not valid on Windows filesystems.
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)