from abc import ABC
from typing import Optional
import loralib as lora
import torch
from chatgpt.models.loss import GPTLMLoss
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import torch.distributed as dist
from .strategies import Strategy
from .utils import is_rank_0
from colossalai.logging import get_dist_logger


class SFTTrainer(ABC):
    """
    Trainer for supervised fine-tuning (SFT) of a causal language model.

    The model is expected to compute the LM loss internally when called with
    ``labels=`` (the train loop relies on this); evaluation instead calls
    :class:`GPTLMLoss` on the raw forward output.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        train_dataloader (DataLoader): the dataloader to use for training
        eval_dataloader (DataLoader, optional): the dataloader to use for
            evaluation; when ``None``, evaluation is skipped entirely
        sampler (DistributedSampler, optional): sampler whose epoch is advanced
            every training epoch so each epoch sees a fresh shuffle
        batch_size (int, defaults to 1): kept for backward compatibility; not
            used internally (batching is governed by the dataloaders)
        max_epochs (int, defaults to 2): the number of epochs to train
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader = None,
        sampler: Optional[DistributedSampler] = None,
        batch_size: int = 1,
        max_epochs: int = 2,
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.epochs = max_epochs
        self.sampler = sampler

        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        # Let the strategy wrap the model (e.g. DDP / ZeRO). For DDP, unwrap
        # back to the inner module so forward() accepts `labels=` directly.
        self.model = strategy.setup_model(model)
        if "DDP" in str(self.strategy):
            self.model = self.model.module
        self.loss_fn = GPTLMLoss()
        self.optimizer = strategy.setup_optimizer(optim, self.model)

    def fit(self, logger, use_lora, log_interval=10):
        """Run training (and optional evaluation) for ``self.epochs`` epochs.

        Args:
            logger: logger used for train/eval progress messages
            use_lora: accepted for interface compatibility; not used here
            log_interval (int): log the training loss every ``log_interval``
                batches
        """
        epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
        for epoch in range(self.epochs):
            if isinstance(self.sampler, DistributedSampler):
                # Re-seed the distributed sampler so shuffling differs per epoch.
                self.sampler.set_epoch(epoch)

            # ---- train ----
            self.model.train()
            for batch_id, batch in enumerate(self.train_dataloader):
                # Dataloader yields (batch, 1, seq_len)-shaped tensors; drop the
                # singleton dim and move to the GPU.
                prompt_ids = batch["input_ids"].squeeze(1).cuda()
                p_mask = batch["attention_mask"].squeeze(1).cuda()
                # BUGFIX: labels were left on the CPU while the inputs were moved
                # to CUDA; keep them on the same device as the model inputs.
                labels = batch["labels"].to(prompt_ids.device)

                # The model computes the LM loss itself when `labels` is given.
                loss, prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)

                self.strategy.backward(loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()
                if batch_id % log_interval == 0:
                    logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')

            # ---- eval ----
            if self.eval_dataloader is not None:
                self.model.eval()
                with torch.no_grad():
                    loss_sum = 0.0
                    num_seen = 0
                    for batch in self.eval_dataloader:
                        prompt_ids = batch["input_ids"].squeeze(1).cuda()
                        p_mask = batch["attention_mask"].squeeze(1).cuda()

                        prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
                        loss = self.loss_fn(prompt_logits, prompt_ids)
                        # BUGFIX: `loss` is a per-batch mean (assumes GPTLMLoss
                        # reduces with 'mean' — TODO confirm), so weight it by the
                        # batch size before dividing by the total sequence count.
                        # The original divided a sum of batch MEANS by a count of
                        # SEQUENCES, understating the loss for batch sizes > 1.
                        cur_batch = prompt_ids.size(0)
                        loss_sum += loss.item() * cur_batch
                        num_seen += cur_batch

                    loss_mean = loss_sum / num_seen
                    if dist.get_rank() == 0:
                        logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')

            epoch_bar.update()