import math
import time
from typing import List, Optional

import torch
import torch.distributed as dist
import wandb
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import get_scheduler

from .base import Trainer
from .callbacks import Callback
from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device


class SFTTrainer(Trainer):
    """
    Trainer for supervised fine-tuning (SFT) of a causal language model.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        train_dataloader (DataLoader): the dataloader to use for training
        eval_dataloader (DataLoader, optional): the dataloader to use for evaluation;
            when ``None``, evaluation is skipped
        max_epochs (int, defaults to 2): the number of epochs to train
        accumulation_steps (int, defaults to 8): number of micro-batches whose
            gradients are accumulated before each optimizer step
        callbacks (List[Callback], optional): the callbacks to call during training process
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader = None,
        max_epochs: int = 2,
        accumulation_steps: int = 8,
        callbacks: Optional[List[Callback]] = None,
    ) -> None:
        # ZeRO stage 3 shards gradients across ranks, which is incompatible with
        # accumulating them locally over several micro-batches.
        if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
            raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
        # `None` default (instead of a mutable `[]` default) avoids sharing one
        # list across every SFTTrainer instance.
        super().__init__(strategy, max_epochs, callbacks=callbacks if callbacks is not None else [])
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.model = model
        self.optimizer = optim

        self.accumulation_steps = accumulation_steps
        # Scheduler horizons are counted in optimizer (update) steps, not
        # micro-batches, hence the division by accumulation_steps.
        num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
        max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)

        # Cosine decay with 3% linear warmup.
        self.scheduler = get_scheduler("cosine",
                                       self.optimizer,
                                       num_warmup_steps=math.ceil(max_steps * 0.03),
                                       num_training_steps=max_steps)

    def fit(self, logger, use_wandb: bool = False):
        """Run the training loop, evaluating once per epoch when an eval
        dataloader was provided.

        Args:
            logger: logger used for warnings and eval-loss reporting
            use_wandb (bool): whether to log metrics to Weights & Biases (rank 0 only)
        """
        if use_wandb:
            wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            wandb.watch(self.model)
        total_loss = 0
        step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
                        desc='steps',
                        disable=not is_rank_0())
        for epoch in range(self.max_epochs):
            # train
            self.model.train()
            for batch_id, batch in enumerate(self.train_dataloader):

                batch = to_device(batch, torch.cuda.current_device())
                outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])

                loss = outputs.loss
                if loss >= 2.5 and is_rank_0():
                    logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")

                # Scale so the accumulated gradient equals the gradient of the
                # mean loss over the accumulation window.
                loss = loss / self.accumulation_steps

                self.strategy.backward(loss, self.model, self.optimizer)

                total_loss += loss.item()

                # gradient accumulation
                # NOTE(review): trailing micro-batches of an epoch (when the batch
                # count is not divisible by accumulation_steps) never trigger a
                # step here; their gradients carry over into the next epoch.
                if (batch_id + 1) % self.accumulation_steps == 0:
                    self.strategy.optimizer_step(self.optimizer)
                    self.optimizer.zero_grad()
                    self.scheduler.step()
                    if is_rank_0() and use_wandb:
                        wandb.log({
                            # `total_loss` already sums `loss / accumulation_steps`
                            # over the window, i.e. it is the mean micro-batch
                            # loss; dividing by accumulation_steps again (as the
                            # old code did) under-reported it by that factor.
                            "loss": total_loss,
                            "lr": self.scheduler.get_last_lr()[0],
                            "epoch": epoch,
                            "batch_id": batch_id
                        })
                    total_loss = 0
                    step_bar.update()

            # eval
            if self.eval_dataloader is not None:
                self.model.eval()
                with torch.no_grad():
                    loss_sum = 0
                    num_seen = 0
                    for batch in self.eval_dataloader:
                        batch = to_device(batch, torch.cuda.current_device())
                        outputs = self.model(batch["input_ids"],
                                             attention_mask=batch["attention_mask"],
                                             labels=batch["labels"])
                        loss = outputs.loss
                        # `outputs.loss` is a per-batch mean, so weight it by the
                        # batch size; otherwise dividing by `num_seen` (sequences)
                        # yields a meaningless average.
                        batch_size = batch["input_ids"].size(0)
                        loss_sum += loss.item() * batch_size
                        num_seen += batch_size

                    loss_mean = loss_sum / num_seen
                    if is_rank_0():
                        logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')