"applications/Chat/coati/trainer/sft.py" did not exist on "fd6add575d87728dbf27f682495fbbbe46c4f5bb"
sft.py 6.6 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
import math
import time
from typing import List, Optional

import torch
import torch.distributed as dist
import wandb
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import get_scheduler

from .base import Trainer
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0


class SFTTrainer(Trainer):
    """
        Trainer to use while training reward model.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim(Optimizer): the optimizer to use for training
        train_dataloader: the dataloader to use for training
        eval_dataloader: the dataloader to use for evaluation
        batch_size (int, defaults to 1): the batch size while training
        max_epochs (int, defaults to 2): the number of epochs to train
39
        callbacks (List[Callback], defaults to []): the callbacks to call during training process
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
40
41
42
43
44
45
46
47
48
49
50
51
52
        optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        batch_size: int = 1,
        max_epochs: int = 2,
        accumulation_steps: int = 8,
        callbacks: List[Callback] = [],
    ) -> None:
        super().__init__(strategy, max_epochs, callbacks=callbacks)
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.model = strategy.setup_model(model)
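        # DDP-based strategies hand back a wrapped model; unwrap it so the
        # underlying module's forward() is called directly below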
        if "DDP" in str(self.strategy):
            self.model = self.model.module
        self.optimizer = strategy.setup_optimizer(optim, self.model)

        self.accumulation_steps = accumulation_steps
        # number of optimizer steps per epoch, accounting for gradient accumulation
        num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
        max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)

        # cosine decay with roughly 3% of the total steps used for linear warmup
        self.scheduler = get_scheduler("cosine",
                                       self.optimizer,
                                       num_warmup_steps=math.ceil(max_steps * 0.03),
                                       num_training_steps=max_steps)

    def fit(self, logger, use_wandb: bool = False):
        if use_wandb:
            wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            wandb.watch(self.model)
        total_loss = 0
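        # the progress bar counts optimizer steps, i.e. batches // accumulation_steps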
        step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
                        desc='steps',
                        disable=not is_rank_0())
        for epoch in range(self.max_epochs):

            # train
            self.model.train()
            for batch_id, batch in enumerate(self.train_dataloader):

                prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
                p_mask = batch["attention_mask"].to(torch.cuda.current_device())
                labels = batch["labels"].to(torch.cuda.current_device())

                outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)

                loss = outputs.loss

                # flag abnormally large losses (threshold chosen heuristically)
                if loss >= 2.5 and is_rank_0():
                    logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")

                # scale the loss so gradients are averaged over the accumulation window
                loss = loss / self.accumulation_steps

                self.strategy.backward(loss, self.model, self.optimizer)

                total_loss += loss.item()

                # gradient accumulation: step the optimizer and scheduler once every accumulation_steps batches
                if (batch_id + 1) % self.accumulation_steps == 0:
                    self.strategy.optimizer_step(self.optimizer)
                    self.optimizer.zero_grad()
                    self.scheduler.step()
                    if is_rank_0() and use_wandb:
                        wandb.log({
                            "loss": total_loss / self.accimulation_steps,
                            "lr": self.scheduler.get_last_lr()[0],
                            "epoch": epoch,
                            "batch_id": batch_id
                        })
                    total_loss = 0
                    step_bar.update()

            # eval
            if self.eval_dataloader is not None:
                self.model.eval()
                with torch.no_grad():
                    loss_sum = 0
                    num_seen = 0
                    for batch in self.eval_dataloader:
                        prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
                        p_mask = batch["attention_mask"].to(torch.cuda.current_device())
                        labels = batch["labels"].to(torch.cuda.current_device())

                        outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
                        loss = outputs.loss

                        # weight each batch's mean loss by its size so loss_mean is a per-sample average
                        loss_sum += loss.item() * prompt_ids.size(0)
                        num_seen += prompt_ids.size(0)

                    loss_mean = loss_sum / num_seen
                    if dist.get_rank() == 0:
                        logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')

    def save_model(self,
                   path: str,
                   only_rank0: bool = False,
                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        self.strategy.save_model(model=self.model, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
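

# A minimal usage sketch (illustrative only; `NaiveStrategy`, the dataloaders,
# the logger, the tokenizer and the model below are assumptions, not defined here):
#
#   from torch.optim import Adam
#   from coati.trainer.strategies import NaiveStrategy
#
#   strategy = NaiveStrategy()
#   model = ...  # a causal LM whose forward() returns an output with .loss and .logits
#   optim = Adam(model.parameters(), lr=5e-6)
#   trainer = SFTTrainer(model, strategy, optim,
#                        train_dataloader=train_dataloader,
#                        eval_dataloader=eval_dataloader,
#                        max_epochs=3, accumulation_steps=8)
#   trainer.fit(logger, use_wandb=False)
#   trainer.save_model('sft_checkpoint', only_rank0=True, tokenizer=tokenizer)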