from abc import ABC
from typing import Optional

import loralib as lora
import torch
from chatgpt.models.loss import GPTLMLoss
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import torch.distributed as dist

from .strategies import Strategy
from .utils import is_rank_0
from colossalai.logging import get_dist_logger


class SFTTrainer(ABC):
    """
    Trainer that fine-tunes a language model on batches of token ids with labels.

    Each training batch must be a mapping with the keys ``"input_ids"``,
    ``"attention_mask"`` and ``"labels"``; evaluation batches only need the
    first two. During training the model is called with ``labels=...`` and is
    expected to return a ``(loss, logits)`` pair.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy used to set up the model/optimizer and
            to run the backward pass and the optimizer step
        optim (Optimizer): the optimizer to use for training
        train_dataloader (DataLoader): the dataloader to use for training
        eval_dataloader (DataLoader, optional): the dataloader to use for
            evaluation; evaluation is skipped when it is None
        sampler (DistributedSampler, optional): sampler of the training
            dataloader; when given, ``set_epoch`` is called at the start of
            every epoch
        batch_size (int, defaults to 1): accepted for interface compatibility;
            not used by this class
        max_epochs (int, defaults to 2): the number of epochs to train
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        sampler: Optional[DistributedSampler] = None,
        batch_size: int = 1,
        max_epochs: int = 2,
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.epochs = max_epochs
        self.sampler = sampler

        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.model = strategy.setup_model(model)
        # NOTE(review): string match on the strategy's repr; presumably the DDP
        # strategy returns a wrapper whose underlying model lives in `.module`
        # -- confirm against the Strategy implementations.
        if "DDP" in str(self.strategy):
            self.model = self.model.module
        self.loss_fn = GPTLMLoss()
        self.optimizer = strategy.setup_optimizer(optim, self.model)

    def fit(self, logger, use_lora, log_interval=10):
        """
        Run the full training loop, evaluating after every epoch if possible.

        Args:
            logger: object exposing ``info(msg)``; used for all progress output
            use_lora: accepted for interface compatibility; not used here
            log_interval (int, defaults to 10): log the training loss every
                ``log_interval`` batches; a falsy value (0 / None) disables
                per-batch logging
        """
        # Iterating the bar itself keeps it in sync with the epoch loop and
        # closes it automatically once the last epoch is done.
        epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
        for epoch in epoch_bar:
            # Lets the distributed sampler reshuffle differently every epoch.
            if isinstance(self.sampler, DistributedSampler):
                self.sampler.set_epoch(epoch)

            self._train_epoch(epoch, logger, log_interval)

            if self.eval_dataloader is not None:
                self._evaluate(epoch, logger)

    def _train_epoch(self, epoch, logger, log_interval):
        """Run one optimisation pass over ``self.train_dataloader``."""
        self.model.train()
        for batch_id, batch in enumerate(self.train_dataloader):
            prompt_ids = batch["input_ids"].squeeze(1).cuda()
            p_mask = batch["attention_mask"].squeeze(1).cuda()
            # Labels must live on the same device as the inputs, otherwise the
            # loss computed inside the model can hit a device mismatch.
            labels = batch["labels"].cuda()

            # The model computes the loss itself when labels are supplied; the
            # logits it also returns are not needed for the update.
            loss, _ = self.model(prompt_ids, attention_mask=p_mask, labels=labels)

            self.strategy.backward(loss, self.model, self.optimizer)
            self.strategy.optimizer_step(self.optimizer)
            self.optimizer.zero_grad()

            # Guard against log_interval == 0, which would raise on the modulo.
            if log_interval and batch_id % log_interval == 0:
                logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')

    def _evaluate(self, epoch, logger):
        """Compute and log (on rank 0) the evaluation loss for ``epoch``."""
        self.model.eval()
        loss_sum = 0
        num_seen = 0
        with torch.no_grad():
            for batch in self.eval_dataloader:
                prompt_ids = batch["input_ids"].squeeze(1).cuda()
                p_mask = batch["attention_mask"].squeeze(1).cuda()

                prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
                # The inputs double as the targets for the language-model loss.
                loss = self.loss_fn(prompt_logits, prompt_ids)
                loss_sum += loss.item()
                num_seen += prompt_ids.size(0)

        # An empty evaluation dataloader would otherwise divide by zero.
        if num_seen == 0:
            if dist.get_rank() == 0:
                logger.info(f'Eval Epoch {epoch}/{self.epochs} skipped: no evaluation samples')
            return

        # NOTE(review): this divides the sum of per-batch losses by the number
        # of samples; if GPTLMLoss already averages within a batch, the value is
        # only a true mean for batch size 1 -- confirm the reduction it uses.
        loss_mean = loss_sum / num_seen
        if dist.get_rank() == 0:
            logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')