# sft.py: defines SFTTrainer, a gradient-accumulating supervised fine-tuning (SFT) trainer.
import math
import time
from typing import Optional, List

import loralib as lora
import torch
import torch.distributed as dist
import wandb
from coati.models.loss import GPTLMLoss
from torch import nn
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import get_scheduler

from colossalai.logging import get_dist_logger

from .callbacks import Callback
from .base import Trainer
from .strategies import Strategy
from .utils import is_rank_0


class SFTTrainer(Trainer):
    """
    Trainer for supervised fine-tuning (SFT).

    Every batch (train and eval) must be a mapping with ``input_ids``,
    ``attention_mask`` and ``labels`` tensors; the model is called as
    ``model(input_ids, attention_mask=..., labels=...)`` and must return an
    object exposing ``.loss``.

    Args:
        model (torch.nn.Module): the model to train; wrapped via ``strategy.setup_model``
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training; wrapped via ``strategy.setup_optimizer``
        train_dataloader (DataLoader): the dataloader to use for training
        eval_dataloader (DataLoader, optional): if given, evaluated once at the end of every epoch
        batch_size (int, defaults to 1): accepted for interface compatibility; not used by this trainer
        max_epochs (int, defaults to 2): the number of epochs to train
        accimulation_steps (int, defaults to 8): number of batches whose gradients are
            accumulated before each optimizer step (spelling kept for backward compatibility)
        callbacks (List[Callback], optional): the callbacks to call during training process

    Raises:
        ValueError: if ``accimulation_steps`` is smaller than 1.
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        batch_size: int = 1,
        max_epochs: int = 2,
        accimulation_steps: int = 8,
        callbacks: Optional[List[Callback]] = None,
    ) -> None:
        if accimulation_steps < 1:
            raise ValueError(f"accimulation_steps must be >= 1, got {accimulation_steps}")
        # A fresh list per instance instead of a shared mutable default.
        super().__init__(strategy, max_epochs, callbacks=callbacks if callbacks is not None else [])
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.model = strategy.setup_model(model)
        # NOTE(review): the DDP wrapper is stripped right after setup, so forward/backward
        # run on the bare module. Confirm gradient synchronisation is still handled
        # elsewhere before relying on this in multi-process runs.
        if "DDP" in str(self.strategy):
            self.model = self.model.module
        self.optimizer = strategy.setup_optimizer(optim, self.model)

        self.accimulation_steps = accimulation_steps
        # One optimizer step per full accumulation window; a trailing partial window at
        # the end of an epoch is discarded (see `_train_epoch`), so it is not counted here.
        num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
        max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)

        # Cosine schedule with warm-up over the first 3% of optimizer steps.
        self.scheduler = get_scheduler("cosine",
                                       self.optimizer,
                                       num_warmup_steps=math.ceil(max_steps * 0.03),
                                       num_training_steps=max_steps)

    def fit(self, logger, log_interval=10):
        """
        Train for ``self.max_epochs`` epochs, evaluating after each one if an eval loader was given.

        Args:
            logger: logger used for abnormal-loss warnings and evaluation results.
            log_interval (int, defaults to 10): currently unused; kept for interface compatibility.
        """
        if is_rank_0():
            # Only rank 0 ever calls wandb.log, so only rank 0 needs a wandb run.
            wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            wandb.watch(self.model)

        total_steps = len(self.train_dataloader) // self.accimulation_steps * self.max_epochs
        step_bar = tqdm(range(total_steps), desc='steps', disable=not is_rank_0())
        try:
            for epoch in range(self.max_epochs):
                self._train_epoch(epoch, logger, step_bar)
                if self.eval_dataloader is not None:
                    self._evaluate(epoch, logger)
        finally:
            step_bar.close()

    def _train_epoch(self, epoch: int, logger, step_bar) -> None:
        """Run one pass over ``train_dataloader`` with gradient accumulation."""
        self.model.train()
        # Start every epoch from clean gradients: when len(train_dataloader) is not a
        # multiple of accimulation_steps, the trailing batches of the previous epoch
        # would otherwise be folded into this epoch's first optimizer step.
        self.optimizer.zero_grad()
        total_loss = 0.0
        device = torch.cuda.current_device()

        for batch_id, batch in enumerate(self.train_dataloader):
            prompt_ids = batch["input_ids"].to(device)
            p_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
            loss = outputs.loss

            if is_rank_0() and loss >= 2.5:
                logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")

            # Scale so the accumulated gradient is the mean over the accumulation window.
            loss = loss / self.accimulation_steps
            self.strategy.backward(loss, self.model, self.optimizer)
            total_loss += loss.item()

            if (batch_id + 1) % self.accimulation_steps == 0:
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()
                self.scheduler.step()
                if is_rank_0():
                    wandb.log({
                        # Each term was already divided by accimulation_steps, so
                        # total_loss is the window mean and is logged as-is.
                        "loss": total_loss,
                        "lr": self.scheduler.get_last_lr()[0],
                        "epoch": epoch,
                        "batch_id": batch_id
                    })
                total_loss = 0.0
                step_bar.update()

    def _evaluate(self, epoch: int, logger) -> None:
        """Compute the per-sample mean loss over ``eval_dataloader`` and log it on rank 0."""
        self.model.eval()
        device = torch.cuda.current_device()
        loss_sum = 0.0
        num_seen = 0
        with torch.no_grad():
            for batch in self.eval_dataloader:
                prompt_ids = batch["input_ids"].to(device)
                p_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
                n = prompt_ids.size(0)
                # Assumes outputs.loss is already averaged over the batch (the usual
                # behaviour for models called with `labels`) — weight it by the batch
                # size so that dividing by num_seen gives a per-sample mean.
                loss_sum += outputs.loss.item() * n
                num_seen += n

        if num_seen == 0:
            # Empty eval loader: nothing to report (avoids division by zero).
            return
        loss_mean = loss_sum / num_seen
        if is_rank_0():
            logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')

    def save_model(self,
                   path: str,
                   only_rank0: bool = False,
                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        """Delegate checkpointing of the (possibly wrapped) model and optional tokenizer to the strategy."""
        self.strategy.save_model(model=self.model, path=path, only_rank0=only_rank0, tokenizer=tokenizer)