sft.py 4.06 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from abc import ABC
from typing import Optional
import loralib as lora
import torch
from chatgpt.dataset import SFTDataset
from chatgpt.models.loss import GPTLMLoss
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import torch.distributed as dist
from .strategies import Strategy
from .utils import is_rank_0
from colossalai.logging import get_dist_logger


class SFTTrainer(ABC):
    """
    Trainer for supervised fine-tuning (SFT) of a causal language model.

    Each training step feeds the batch's ``input_ids`` / ``attention_mask`` through
    the model and optimizes ``GPTLMLoss`` using the same ``input_ids`` as labels.
    After every epoch the model is evaluated on ``eval_dataset`` and the per-sample
    mean evaluation loss is logged on rank 0.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        train_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for training
        eval_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for evaluation
        sampler (DistributedSampler, optional): sampler for the training dataloader; when
            ``None`` the training data is shuffled instead
        batch_size (int, defaults to 1): the batch size for both training and evaluation
        max_epochs (int, defaults to 2): the number of epochs to train
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        train_dataset: SFTDataset,
        eval_dataset: SFTDataset,
        sampler: Optional[DistributedSampler] = None,
        batch_size: int = 1,
        max_epochs: int = 2,
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.epochs = max_epochs
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.sampler = sampler

        # DataLoader rejects shuffle=True combined with an explicit sampler,
        # so only shuffle when no sampler was supplied.
        self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
                                           sampler=sampler, batch_size=batch_size)
        self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)

        self.model = strategy.setup_model(model)
        # NOTE(review): for DDP strategies the DDP wrapper is stripped and the raw
        # module is trained directly. Running forward on the unwrapped module likely
        # skips DDP's gradient all-reduce unless the strategy syncs gradients itself
        # — confirm against the Strategy implementation.
        if "DDP" in str(self.strategy):
            self.model = self.model.module
        self.loss_fn = GPTLMLoss()
        self.optimizer = strategy.setup_optimizer(optim, self.model)

    @staticmethod
    def _get_rank():
        """Return this process's distributed rank, or 0 when torch.distributed is not initialized."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    @staticmethod
    def _prepare_batch(batch):
        """Return ``(input_ids, attention_mask)`` from ``batch``, squeezed on dim 1 and moved to CUDA.

        The ``squeeze(1)`` presumably drops a singleton dimension produced by the
        dataset's tokenization (e.g. ``(batch, 1, seq)``) — TODO confirm.
        """
        input_ids = batch["input_ids"].squeeze(1).cuda()
        attention_mask = batch["attention_mask"].squeeze(1).cuda()
        return input_ids, attention_mask

    def _train_epoch(self, epoch, logger, log_interval):
        """Run one optimization pass over ``self.train_dataloader``, logging every ``log_interval`` batches."""
        self.model.train()
        rank = self._get_rank()
        for batch_id, batch in enumerate(self.train_dataloader):
            prompt_ids, p_mask = self._prepare_batch(batch)
            prompt_logits = self.model(prompt_ids, attention_mask=p_mask)

            # The inputs double as labels (language-modeling objective).
            loss = self.loss_fn(prompt_logits, prompt_ids)
            self.strategy.backward(loss, self.model, self.optimizer)
            self.strategy.optimizer_step(self.optimizer)
            self.optimizer.zero_grad()
            # log_interval <= 0 disables per-batch logging instead of raising ZeroDivisionError.
            if log_interval > 0 and batch_id % log_interval == 0:
                logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {rank} loss {loss.item()}')

    def _evaluate(self, epoch, logger):
        """Compute the per-sample mean loss over ``self.eval_dataloader`` and log it on rank 0.

        Nothing is logged when the evaluation dataloader yields no samples.
        """
        self.model.eval()
        loss_sum = 0.0
        num_seen = 0
        with torch.no_grad():
            for batch in self.eval_dataloader:
                prompt_ids, p_mask = self._prepare_batch(batch)
                prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
                loss = self.loss_fn(prompt_logits, prompt_ids)
                batch_size = prompt_ids.size(0)
                # Assumes GPTLMLoss returns a batch-mean loss (confirm); weighting by the
                # batch size makes the final average per-sample even for a short last batch.
                loss_sum += loss.item() * batch_size
                num_seen += batch_size

        if num_seen == 0:
            return
        loss_mean = loss_sum / num_seen
        if self._get_rank() == 0:
            logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')

    def fit(self, logger, use_lora, log_interval=10):
        """Train for ``self.epochs`` epochs, running an evaluation pass after each epoch.

        Args:
            logger: object exposing ``info(str)``, used for progress messages
            use_lora: currently unused; kept for interface compatibility
            log_interval (int, defaults to 10): log the training loss every
                ``log_interval`` batches; values <= 0 disable per-batch logging
        """
        # Iterating the bar directly advances it once per epoch and closes it at the end.
        epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
        for epoch in epoch_bar:
            if isinstance(self.sampler, DistributedSampler):
                # Reseed the sampler so each epoch uses a different shuffle order.
                self.sampler.set_epoch(epoch)
            self._train_epoch(epoch, logger, log_interval)
            self._evaluate(epoch, logger)