Unverified Commit df5e9c53 authored by YeAnbang's avatar YeAnbang Committed by GitHub
Browse files

[ColossalChat] Update RLHF V2 (#5286)



* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporary

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------
Co-authored-by: default avatarTong Li <tong.li352711588@gmail.com>
parent 36c4bb28
from .base import OnPolicyTrainer, SLTrainer
from .base import OLTrainer, SLTrainer
from .dpo import DPOTrainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]
__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer"]
"""
Base trainers for online and offline training
SLTrainer: supervised learning trainer
pretrain, sft, dpo, reward model training
OLTrainer: online learning trainer
rlhf-ppo
"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import List
from typing import Callable, List
import torch.nn as nn
import tqdm
......@@ -8,8 +16,8 @@ from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
from torch.optim import Optimizer
from .callbacks import Callback
from .strategies import Strategy
from colossalai.booster import Booster
from .utils import is_rank_0
......@@ -26,16 +34,18 @@ class SLTrainer(ABC):
def __init__(
self,
strategy: Strategy,
booster: Booster,
max_epochs: int,
model: nn.Module,
optimizer: Optimizer,
start_epoch: int = 0,
) -> None:
super().__init__()
self.strategy = strategy
self.booster = booster
self.max_epochs = max_epochs
self.model = model
self.optimizer = optimizer
self.start_epoch = start_epoch
@abstractmethod
def _train(self, epoch):
......@@ -45,19 +55,20 @@ class SLTrainer(ABC):
def _eval(self, epoch):
raise NotImplementedError()
@abstractmethod
def _before_fit(self):
raise NotImplementedError()
def fit(self, *args, **kwargs):
self._before_fit(*args, **kwargs)
for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()):
for epoch in tqdm.trange(self.start_epoch, self.max_epochs, desc="Epochs", disable=not is_rank_0()):
self._train(epoch)
self._eval(epoch)
class OnPolicyTrainer(ABC):
class OLTrainer(ABC):
"""
Base class for on-policy rl trainers, e.g. PPO.
Base class for online learning trainers, e.g. PPO.
Args:
strategy (Strategy):the strategy to use for training
......@@ -69,14 +80,16 @@ class OnPolicyTrainer(ABC):
def __init__(
self,
strategy: Strategy,
actor_booster: Booster,
critic_booster: Booster,
data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = [],
callbacks: List[Callable] = [],
) -> None:
super().__init__()
self.strategy = strategy
self.actor_booster = actor_booster
self.critic_booster = critic_booster
self.data_buffer = data_buffer
self.sample_buffer = sample_buffer
self.dataloader_pin_memory = dataloader_pin_memory
......@@ -141,6 +154,20 @@ class OnPolicyTrainer(ABC):
"""
raise NotImplementedError()
@abstractmethod
def _setup_update_phrase_dataload(self):
"""
Implement this method to setup dataloader for update phase.
"""
raise NotImplementedError()
@abstractmethod
def _save_checkpoint(self, episode: int = 0):
"""
Implement this method to save checkpoint.
"""
raise NotImplementedError()
def _collect_phase(self, collect_step: int):
self._on_make_experience_start()
experience = self._make_experience(collect_step)
......@@ -178,11 +205,10 @@ class OnPolicyTrainer(ABC):
for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
self._collect_phase(collect_step)
if not self.sample_buffer:
# HACK(cwher): according to the design of boost API, dataloader should also be boosted,
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
# I only call strategy.setup_dataloader() to setup dataloader.
self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory)
self._setup_update_phrase_dataload()
for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
self._update_phase(update_step)
# NOTE: this is for on-policy algorithms
self.data_buffer.clear()
if self.save_interval > 0 and (episode + 1) % (self.save_interval) == 0:
self._save_checkpoint(episode + 1)
from .base import Callback
from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint
__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]
__all__ = ["Callback", "PerformanceEvaluator"]
......@@ -14,9 +14,11 @@ def get_world_size() -> int:
return 1
def print_rank_0(*args, **kwargs) -> None:
def save_eval_result_rank_0(s: str, save_path: str, **kwargs) -> None:
if not dist.is_initialized() or dist.get_rank() == 0:
print(*args, **kwargs)
with open(save_path, "a+") as f:
train_config = "; ".join([str(kwargs[key]) for key in kwargs])
f.write(train_config + "\n" + s + "\n")
def divide(x: float, y: float) -> float:
......@@ -74,6 +76,8 @@ class PerformanceEvaluator(Callback):
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0,
train_config: Optional[dict] = None,
save_path: Optional[str] = None,
) -> None:
super().__init__()
self.world_size = get_world_size()
......@@ -92,6 +96,8 @@ class PerformanceEvaluator(Callback):
self.make_experience_flop: int = 0
self.learn_num_samples: int = 0
self.learn_flop: int = 0
self.train_config = train_config
self.save_path = save_path
def on_episode_start(self, episode: int) -> None:
self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes
......@@ -172,12 +178,14 @@ class PerformanceEvaluator(Callback):
make_experience_time_per_sample = divide(avg_make_experience_duration, num_effective_samples)
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
print_rank_0(
save_eval_result_rank_0(
f"Performance summary:\n"
+ f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+ f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+ f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+ f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+ f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
+ f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
+ f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%",
self.save_path,
**self.train_config,
)
"""
Dpo trainer
"""
from typing import Any, Optional
import torch
from coati.models.loss import DpoLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import SLTrainer
from .utils import is_rank_0, to_device
class DPOTrainer(SLTrainer):
"""
Trainer for PPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
ref_model (Critic): the reference model in ppo algorithm
booster (Strategy): the strategy to use for training
actor_optim (Optimizer): the optimizer to use for actor model
actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
max_epochs (int, defaults to 1): the max number of epochs to train
beta (float, defaults to 0.1): the beta parameter in dpo loss
accumulation_steps (int): the number of steps to accumulate gradients
start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
save_interval (int): the interval to save model checkpoints, default to 0, which means no checkpoint will be saved during trainning
save_dir (str): the directory to save checkpoints
coordinator (DistCoordinator): the coordinator to use for distributed logging
"""
def __init__(
self,
actor: Any,
ref_model: Any,
booster: Booster,
actor_optim: Optimizer,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
beta: float = 0.1,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
self.ref_model = ref_model
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta)
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
self.num_train_step = 0
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
def _before_fit(
self,
train_preference_dataloader: DataLoader = None,
eval_preference_dataloader: DataLoader = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "dpo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _train(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
self.model.train()
self.accumulative_meter.reset()
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
reject_loss_mask[:, -1] = False
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"].to(torch.float32)
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
if self.ref_model is not None:
self.ref_model.eval()
with torch.no_grad():
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"].to(torch.float32)
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
)
else:
logprob_ref_chosen = None
logprob_ref_reject = None
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,
logprob_actor_reject,
logprob_ref_chosen if logprob_ref_chosen is not None else None,
logprob_ref_reject if logprob_ref_reject is not None else None,
chosen_loss_mask[:, 1:],
reject_loss_mask[:, 1:],
)
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
# DPO Loss
loss = losses.mean()
self.booster.backward(loss=loss, optimizer=self.optimizer)
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
step_bar.update()
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
)
self.accumulative_meter.reset()
if (self.num_train_step + 1) % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=i + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.ref_model.eval()
self.coordinator.print_on_master("\nStart evaluation...")
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
self.accumulative_meter.reset()
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"].to(torch.float32)
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
)
self.ref_model.eval()
ref_all_logits = self.ref_model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"].to(torch.float32)
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_ref_reject = calc_masked_log_probs(ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,
logprob_actor_reject,
logprob_ref_chosen if logprob_ref_chosen is not None else None,
logprob_ref_reject if logprob_ref_reject is not None else None,
chosen_loss_mask[:, 1:],
reject_loss_mask[:, 1:],
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()
loss = losses.mean()
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
self.accumulative_meter.add(
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
)
step_bar.update()
msg = "Evaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
step_bar.close()
"""
PPO trainer
"""
import os
from typing import Dict, List, Optional
import torch
import wandb
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models import Critic, RewardModel
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from coati.trainer.callbacks import Callback
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import GeminiStrategy, Strategy
from .base import OLTrainer
from .utils import CycledDataLoader, is_rank_0, to_device
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapped_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapped_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
def _set_default_generate_kwargs(actor: PreTrainedModel) -> Dict:
"""
Set default keyword arguments for generation based on the actor model.
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
Args:
actor (PreTrainedModel): The actor model.
Returns:
Dict: A dictionary containing the default keyword arguments for generation.
"""
unwrapped_model = actor.unwrap()
new_kwargs = {}
# use huggingface models method directly
if hasattr(unwrapped_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation
if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation
return new_kwargs
class PPOTrainer(OnPolicyTrainer):
class PPOTrainer(OLTrainer):
"""
Trainer for PPO algorithm.
Args:
strategy (Strategy): the strategy to use for training
strategy (Booster): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
critic (Critic): the critic model in ppo algorithm
reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
......@@ -61,13 +80,16 @@ class PPOTrainer(OnPolicyTrainer):
def __init__(
self,
strategy: Strategy,
actor: Actor,
actor_booster: Booster,
critic_booster: Booster,
actor: PreTrainedModel,
critic: Critic,
reward_model: RewardModel,
initial_model: Actor,
initial_model: PreTrainedModel,
actor_optim: Optimizer,
critic_optim: Optimizer,
actor_lr_scheduler: _LRScheduler,
critic_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
......@@ -76,25 +98,39 @@ class PPOTrainer(OnPolicyTrainer):
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
value_clip: float = 0.2,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
accumulation_steps: int = 1,
save_interval: int = 0,
save_dir: str = None,
use_tp: bool = False,
coordinator: DistCoordinator = None,
callbacks: List[Callback] = [],
**generate_kwargs,
) -> None:
if isinstance(strategy, GeminiStrategy):
if isinstance(actor_booster, GeminiPlugin):
assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)
super().__init__(
actor_booster, critic_booster, data_buffer, sample_buffer, dataloader_pin_memory, callbacks=callbacks
)
self.generate_kwargs = _set_default_generate_kwargs(actor)
self.generate_kwargs.update(generate_kwargs)
self.actor = actor
self.critic = critic
self.actor_booster = actor_booster
self.critic_booster = critic_booster
self.actor_scheduler = actor_lr_scheduler
self.critic_scheduler = critic_lr_scheduler
self.tokenizer = tokenizer
self.experience_maker = NaiveExperienceMaker(
self.actor, self.critic, reward_model, initial_model, self.tokenizer, kl_coef
)
self.train_batch_size = train_batch_size
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
......@@ -103,14 +139,21 @@ class PPOTrainer(OnPolicyTrainer):
self.ptx_coef = ptx_coef
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self.save_interval = save_interval
self.coordinator = coordinator
self.actor_save_dir = os.path.join(save_dir, "actor")
self.critic_save_dir = os.path.join(save_dir, "critic")
self.num_train_step = 0
self.accumulation_steps = accumulation_steps
self.use_tp = use_tp
self.accumulative_meter = AccumulativeMeanMeter()
self.offload_inference_models = offload_inference_models
self.device = get_accelerator().get_current_device()
self.device = get_current_device()
def _before_fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
pretrain_dataloader: Optional[DataLoader] = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
......@@ -120,14 +163,14 @@ class PPOTrainer(OnPolicyTrainer):
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) if pretrain_dataloader is not None else None
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-ppo", sync_tensorboard=True)
self.wandb_run = wandb.init(project="Coati-ppo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
......@@ -138,48 +181,163 @@ class PPOTrainer(OnPolicyTrainer):
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _setup_update_phrase_dataload(self):
"""
why not use distributed_dataloader?
if tp is used, input on each rank is the same and we use the same dataloader to feed same experience to all ranks
if tp is not used, input on each rank is different and we expect different experiences to be fed to each rank
"""
self.dataloader = DataLoader(
self.data_buffer,
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True,
pin_memory=self.dataloader_pin_memory,
collate_fn=self.data_buffer.collate_fn,
)
def _make_experience(self, collect_step: int) -> Experience:
"""
Make experience
"""
prompts = self.prompt_dataloader.next()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
return self.experience_maker.make_experience(
input_ids=prompts["input_ids"].to(get_current_device()),
attention_mask=prompts["attention_mask"].to(get_current_device()),
**self.generate_kwargs,
)
def _training_step(self, experience: Experience):
"""
Args:
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
self.num_train_step += 1
self.actor.train()
self.critic.train()
# policy loss
num_actions = experience.action_log_probs.size(1)
actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
# policy loss
actor_logits = self.actor(input_ids=experience.sequences, attention_mask=experience.attention_mask)[
"logits"
] # [batch size, prompt_length + response_length]
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(
actor_loss, to_skip, max_ratio = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
actor_loss = (1 - self.ptx_coef) * actor_loss
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
if not to_skip:
self.actor_booster.backward(loss=actor_loss, optimizer=self.actor_optim)
# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
self.strategy.backward(ptx_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
outputs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
ptx_loss = outputs.loss
ptx_loss = self.ptx_coef * ptx_loss
self.actor_booster.backward(loss=ptx_loss, optimizer=self.actor_optim)
# value loss
values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
values = self.critic(
input_ids=experience.sequences, attention_mask=experience.attention_mask
) # [batch size, prompt_length + response_length]
critic_loss = self.critic_loss_fn(
values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
)
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)
# sync
actor_loss_mean = all_reduce_mean(tensor=actor_loss)
critic_loss_mean = all_reduce_mean(tensor=critic_loss)
max_ratio_mean = all_reduce_mean(tensor=max_ratio)
reward_mean = all_reduce_mean(tensor=experience.reward.mean())
value_mean = all_reduce_mean(tensor=experience.values.mean())
advantages_mean = all_reduce_mean(tensor=experience.advantages.mean())
kl_mean = all_reduce_mean(tensor=experience.kl.mean())
if self.ptx_coef != 0:
ptx_loss_mean = all_reduce_mean(tensor=ptx_loss)
self.accumulative_meter.add("actor_loss", actor_loss_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("critic_loss", critic_loss_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("max_ratio", max_ratio_mean.to(torch.float16).item())
self.accumulative_meter.add("reward", reward_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("value", value_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("advantages", advantages_mean.to(torch.float16).item())
self.accumulative_meter.add("skip_ratio", 1.0 if to_skip else 0.0)
self.accumulative_meter.add("kl", kl_mean.to(torch.float16).item())
if self.ptx_coef != 0:
self.accumulative_meter.add("ptx_loss", ptx_loss_mean.to(torch.float16).mean().item())
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.actor_optim.step()
self.critic_optim.step()
self.actor_optim.zero_grad()
self.critic_optim.zero_grad()
self.actor_scheduler.step()
self.critic_scheduler.step()
# preparing logging model output and corresponding rewards.
if self.num_train_step % 10 == 1:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
for i in range(len(response_text)):
response_text[i] = response_text[i] + f"\n\nReward: {experience.reward[i]}"
if self.writer and is_rank_0() and "wandb_run" in self.__dict__:
# log output to wandb
my_table = wandb.Table(
columns=[f"sample response {i}" for i in range(len(response_text))], data=[response_text]
)
try:
self.wandb_run.log({"sample_response": my_table})
except OSError as e:
self.coordinator.print_on_master(e)
elif self.writer and is_rank_0():
for line in response_text:
self.coordinator.print_on_master(line)
if self.writer and is_rank_0():
self.writer.add_scalar("train/max_ratio", self.accumulative_meter.get("max_ratio"), self.num_train_step)
self.writer.add_scalar(
"train/skip_ratio", self.accumulative_meter.get("skip_ratio"), self.num_train_step
)
self.writer.add_scalar(
"train/actor_loss", self.accumulative_meter.get("actor_loss"), self.num_train_step
)
self.writer.add_scalar("train/lr_actor", self.actor_optim.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar("train/lr_critic", self.critic_optim.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/critic_loss", self.accumulative_meter.get("critic_loss"), self.num_train_step
)
if self.ptx_coef != 0:
self.writer.add_scalar(
"train/ptx_loss", self.accumulative_meter.get("ptx_loss"), self.num_train_step
)
self.writer.add_scalar("reward", self.accumulative_meter.get("reward"), self.num_train_step)
self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), self.num_train_step)
self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
self.accumulative_meter.reset()
def _learn(self, update_step: int):
"""
Perform the learning step of the PPO algorithm.
Args:
update_step (int): The current update step.
Returns:
None
"""
if self.offload_inference_models:
self.experience_maker.initial_model.to("cpu")
self.experience_maker.reward_model.to("cpu")
......@@ -200,3 +358,46 @@ class PPOTrainer(OnPolicyTrainer):
experience.to_device(self.device)
self._training_step(experience)
self._on_learn_batch_end(experience)
def _save_checkpoint(self, episode: int = 0):
"""
Save the actor and critic checkpoints with running states.
Args:
episode (int): The current episode number.
Returns:
None
"""
self.coordinator.print_on_master("\nStart saving actor checkpoint with running states")
save_checkpoint(
save_dir=self.actor_save_dir,
booster=self.actor_booster,
model=self.actor,
optimizer=self.actor_optim,
lr_scheduler=self.actor_scheduler,
epoch=0,
step=episode + 1,
batch_size=self.train_batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved actor checkpoint at episode {(episode + 1)} at folder {self.actor_save_dir}"
)
self.coordinator.print_on_master("\nStart saving critic checkpoint with running states")
save_checkpoint(
save_dir=self.critic_save_dir,
booster=self.critic_booster,
model=self.critic,
optimizer=self.critic_optim,
lr_scheduler=self.critic_scheduler,
epoch=0,
step=episode + 1,
batch_size=self.train_batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved critic checkpoint at episode {(episode + 1)} at folder {self.critic_save_dir}"
)
"""
Reward model trianer
"""
import os
from typing import Any, Callable, Optional
import torch
import tqdm
from coati.models import LogSigLoss
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import SLTrainer
from .utils import is_rank_0, to_device
class RewardModelTrainer(SLTrainer):
"""
Trainer for PPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
ref_model (Critic): the reference model in ppo algorithm
booster (Strategy): the strategy to use for training
actor_optim (Optimizer): the optimizer to use for actor model
actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
max_epochs (int, defaults to 1): the max number of epochs to train
beta (float, defaults to 0.1): the beta parameter in dpo loss
accumulation_steps (int): the number of steps to accumulate gradients
start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
save_interval (int): the interval to save model checkpoints, default to 0, which means no checkpoint will be saved during trainning
save_dir (str): the directory to save checkpoints
coordinator (DistCoordinator): the coordinator to use for distributed logging
"""
def __init__(
self,
model: Any,
booster: Booster,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
loss_fn: Optional[Callable] = None,
max_epochs: int = 1,
beta: float = 0.1,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch)
self.actor_scheduler = lr_scheduler
self.tokenizer = tokenizer
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
self.num_train_step = 0
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
def _before_fit(
self,
train_preference_dataloader: DataLoader = None,
eval_preference_dataloader: DataLoader = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-rm", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "rm")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
reject_input_ids,
reject_attention_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
)
batch_size = chosen_input_ids.size()[0]
# Concatenate for better parrallelism
reward = self.model(
torch.cat([chosen_input_ids, reject_input_ids], dim=0),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask], dim=0),
)
chosen_reward = reward[:batch_size]
reject_reward = reward[batch_size:]
loss = self.loss_fn(chosen_reward, reject_reward).mean()
self.booster.backward(loss=loss, optimizer=self.optimizer)
accuracy = (chosen_reward > reject_reward).float()
# Sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_reward)
rejected_rewards_mean = all_reduce_mean(tensor=reject_reward)
accuracy_mean = all_reduce_mean(tensor=accuracy)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", accuracy_mean.mean().to(torch.float16).item())
if (i + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
step_bar.update()
self.num_train_step += 1
# Logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/dist",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/reward_reject", self.accumulative_meter.get("rejected_rewards"), self.num_train_step
)
self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), self.num_train_step)
self.accumulative_meter.reset()
# Save checkpoint
if self.save_interval > 0 and (self.num_train_step + 1) % self.save_interval == 0:
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=i + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch):
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
step_bar = tqdm.trange(
len(self.eval_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
)
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
reject_input_ids,
reject_attention_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
)
chosen_reward = self.model(chosen_input_ids, attention_mask=chosen_attention_mask)
reject_reward = self.model(reject_input_ids, attention_mask=reject_attention_mask)
loss = self.loss_fn(chosen_reward, reject_reward).mean()
# Sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_reward)
rejected_rewards_mean = all_reduce_mean(tensor=reject_reward)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
step_bar.update()
msg = "Evaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
msg = (
msg
+ f"distance: {self.accumulative_meter.get('chosen_rewards')-self.accumulative_meter.get('rejected_rewards')}\n"
)
self.coordinator.print_on_master(msg)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()
"""
SFT trainer
"""
import os
from typing import Optional
import torch
import torch.distributed as dist
import tqdm
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from colossalai.logging import DistributedLogger
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from .base import SLTrainer
from .strategies import GeminiStrategy, Strategy
from .utils import is_rank_0, to_device
......@@ -30,75 +36,31 @@ class SFTTrainer(SLTrainer):
def __init__(
self,
model,
strategy: Strategy,
booster: Booster,
optim: Optimizer,
lr_scheduler: _LRScheduler,
max_epochs: int = 2,
accumulation_steps: int = 8,
start_epoch=0,
save_interval: int = None,
save_dir: str = None,
coordinator: Optional[DistCoordinator] = None,
) -> None:
if accumulation_steps > 1:
assert not isinstance(
strategy, GeminiStrategy
), "Accumulation steps are not supported in stage 3 of ColossalAI"
super().__init__(strategy, max_epochs, model, optim)
super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
self.accumulation_steps = accumulation_steps
self.scheduler = lr_scheduler
self.save_interval = save_interval
self.save_dir = save_dir
self.coordinator = coordinator
self.num_train_step = 0
self.num_eval_step = 0
def _train(self, epoch: int):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss / self.accumulation_steps
self.total_loss += loss.item()
self.strategy.backward(loss, self.model, self.optimizer)
# gradient accumulation
if (i + 1) % self.accumulation_steps == 0:
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
self.scheduler.step()
if self.writer:
self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
self.num_train_step += 1
self.total_loss = 0
step_bar.update()
step_bar.close()
def _eval(self, epoch: int):
if self.eval_dataloader is not None:
self.model.eval()
with torch.no_grad():
loss_sum, num_seen = 0, 0
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(
batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
)
loss_sum += outputs.loss.item()
num_seen += batch["input_ids"].size(0)
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
if self.writer:
self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
self.num_eval_step += 1
self.accumulative_meter = AccumulativeMeanMeter()
def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
......@@ -106,11 +68,12 @@ class SFTTrainer(SLTrainer):
Args:
train_dataloader: the dataloader to use for training
eval_dataloader: the dataloader to use for evaluation
log_dir: the directory to save logs
use_wandb: whether to use wandb for logging
"""
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.logger = logger
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
......@@ -127,4 +90,81 @@ class SFTTrainer(SLTrainer):
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
self.total_loss = 0
def _train(self, epoch: int):
self.model.train()
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
batch_size = batch["input_ids"].size(0)
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss
self.booster.backward(loss=loss, optimizer=self.optimizer)
loss_mean = all_reduce_mean(tensor=loss)
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
# Gradient accumulation
if (i + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
self.num_train_step += 1
self.accumulative_meter.reset()
step_bar.update()
# Save checkpoint
if (
self.save_dir is not None
and self.save_interval is not None
and (self.num_train_step + 1) % self.save_interval == 0
):
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.scheduler,
epoch=epoch,
step=self.num_train_step + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.accumulative_meter.reset()
self.model.eval()
with torch.no_grad():
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss_mean = all_reduce_mean(tensor=outputs.loss)
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
step_bar.update()
loss_mean = self.accumulative_meter.get("loss")
msg = "Evaluation Result:\n"
for tag in ["loss"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()
"""
Training utilities for Coati.
"""
from typing import Any
import torch
......@@ -8,10 +11,18 @@ from torch.utils.data import DataLoader
class CycledDataLoader:
"""
Why do we need this class?
In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" is used to sample a batch of prompts/pretrain.
However, this may be inefficient due to frequent re-initialization of the dataloader. (re-initialize workers...)
NOTE: next(iter(dataloader)) is not equivalent to for batch in dataloader: break, it causes slightly different behavior.
A data loader that cycles through the data when it reaches the end.
Args:
dataloader (DataLoader): The original data loader.
Attributes:
dataloader (DataLoader): The original data loader.
count (int): The number of times the data loader has been cycled.
dataloader_iter (iterable): The iterator for the data loader.
Methods:
next(): Returns the next batch of data from the data loader, cycling through the data if necessary.
"""
def __init__(
......@@ -24,6 +35,12 @@ class CycledDataLoader:
self.dataloader_iter = None
def next(self):
"""
Returns the next batch of data from the data loader, cycling through the data if necessary.
Returns:
Any: The next batch of data from the data loader.
"""
# defer initialization
if self.dataloader_iter is None:
self.dataloader_iter = iter(self.dataloader)
......@@ -38,13 +55,59 @@ class CycledDataLoader:
def is_rank_0() -> bool:
"""
Check if the current process is the rank 0 process in a distributed training setup.
Returns:
bool: True if the current process is the rank 0 process, False otherwise.
"""
return not dist.is_initialized() or dist.get_rank() == 0
def to_device(x: Any, device: torch.device) -> Any:
"""
Move the input tensor or nested structure of tensors to the specified device.
Args:
x (Any): The input tensor or nested structure of tensors.
device (torch.device): The target device to move the tensors to.
Returns:
Any: The tensor or nested structure of tensors moved to the target device.
"""
def _to(t: Any):
if isinstance(t, torch.Tensor):
return t.to(device)
return t
return tree_map(_to, x)
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
"""
Perform all-reduce operation on the given tensor and compute the mean across all processes.
Args:
tensor (torch.Tensor): The input tensor to be reduced.
Returns:
torch.Tensor: The reduced tensor with mean computed across all processes.
"""
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
"""
Performs an all-reduce operation to sum the values of the given tensor across all processes.
Args:
tensor (torch.Tensor): The input tensor to be reduced.
Returns:
torch.Tensor: The reduced tensor with the sum of values across all processes.
"""
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
return tensor
from .accumulative_meter import AccumulativeMeanMeter
from .ckpt_io import load_checkpoint, save_checkpoint
__all__ = ["load_checkpoint", "save_checkpoint", "AccumulativeMeanMeter"]
"""
A class that can be used to calculate the mean of a variable
"""
class AccumulativeMeanVariable:
"""
A class that calculates the accumulative mean of a variable.
"""
def __init__(self):
self._sum = 0
self._count = 0
def add(self, value, count_update=1):
"""
Adds a value to the sum and updates the count.
Args:
value (float): The value to be added.
count_update (int, optional): The amount to update the count by. Defaults to 1.
"""
self._sum += value
self._count += count_update
def get(self):
"""
Calculates and returns the accumulative mean.
Returns:
float: The accumulative mean.
"""
return self._sum / self._count if self._count > 0 else 0
def reset(self):
"""
Resets the sum and count to zero.
"""
self._sum = 0
self._count = 0
class AccumulativeMeanMeter:
"""
A class for calculating and storing the accumulative mean of variables.
Attributes:
variable_dict (dict): A dictionary to store the accumulative mean variables.
Methods:
add(name, value, count_update=1): Adds a value to the specified variable.
get(name): Retrieves the accumulative mean value of the specified variable.
reset(): Resets all the accumulative mean variables to their initial state.
"""
def __init__(self):
self.variable_dict = {}
def add(self, name, value, count_update=1):
if name not in self.variable_dict:
self.variable_dict[name] = AccumulativeMeanVariable()
self.variable_dict[name].add(value, count_update=count_update)
def get(self, name):
return self.variable_dict[name].get()
def reset(self):
for name in self.variable_dict:
self.variable_dict[name].reset()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Helper functions for IO save load checkpoints
"""
import json
import os
from typing import Any, Dict, Tuple, Union
import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
"""
Load file in JSON format
"""
with open(file=file_path, mode="r", encoding="utf-8") as fp:
return json.load(fp)
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
"""
Save as JSON format
"""
with open(file=file_path, mode="w", encoding="utf-8") as fp:
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
def save_checkpoint(
save_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
epoch: int,
step: int,
batch_size: int,
coordinator: DistCoordinator,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
"""
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
"""
Temporary disable the following as save_optimizer causes all processes to hang in a multi-gpu environment,
working on fixing this bug
"""
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step,
"sample_start_index": step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
def load_checkpoint(
load_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
"""
Load model checkpoint, optimizer, LR scheduler and intermedidate running states.
"""
# Update booster params states.
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
return (
running_states["epoch"],
running_states["step"],
running_states["sample_start_index"],
)
{
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
null
]
}
{
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\\'t know the answer to a question, please don\\'t share false information.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}
{
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}
{
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}
{
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant: ' + bos_token }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}
{
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}
{
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"system_message": null,
"stop_ids": [
2
]
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment