Unverified Commit bb6196e7 authored by Fazzie-Maqianli's avatar Fazzie-Maqianli Committed by GitHub
Browse files

remove chatgpt (#3284)

parent b0ce5a10
from typing import Optional
import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM
from ..base import RewardModel
class LlamaRM(RewardModel):
"""
Llama Reward model.
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = LlamaForCausalLM(config)
else:
model = LlamaForCausalLM(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
super().__init__(model, lora_rank, lora_train_bias)
import math
from typing import Optional
import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F
class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
"""
def __init__(
self,
weight: nn.Parameter,
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
):
nn.Module.__init__(self)
lora.LoRALayer.__init__(self,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
merge_weights=merge_weights)
self.weight = weight
self.bias = bias
out_features, in_features = weight.shape
self.in_features = in_features
self.out_features = out_features
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Module.train(self, mode)
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Module.eval(self)
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, 'lora_A')
delattr(self, 'lora_B')
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias)
if self.r > 0:
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
return result
else:
return F.linear(x, T(self.weight), bias=self.bias)
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, lora_linear_wrapper(child, lora_rank))
else:
convert_to_lora_recursively(child, lora_rank)
class LoRAModule(nn.Module):
"""A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
This calss will convert all torch.nn.Linear layer to LoraLinear layer.
Args:
lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
lora_train_bias (str, optional): Whether LoRA train biases.
'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
Defaults to 'none'.
"""
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
def convert_to_lora(self) -> None:
if self.lora_rank <= 0:
return
convert_to_lora_recursively(self, self.lora_rank)
lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
from typing import Optional
import torch
import torch.nn as nn
from .utils import masked_mean
class GPTLMLoss(nn.Module):
"""
GPT Language Model Loss
"""
def __init__(self):
super().__init__()
self.loss = nn.CrossEntropyLoss()
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
class PolicyLoss(nn.Module):
"""
Policy Loss for PPO
"""
def __init__(self, clip_eps: float = 0.2) -> None:
super().__init__()
self.clip_eps = clip_eps
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
if action_mask is not None:
loss = masked_mean(loss, action_mask)
loss = loss.mean()
return loss
class ValueLoss(nn.Module):
"""
Value Loss for PPO
"""
def __init__(self, clip_eps: float = 0.4) -> None:
super().__init__()
self.clip_eps = clip_eps
def forward(self,
values: torch.Tensor,
old_values: torch.Tensor,
reward: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - reward)**2
surr2 = (values - reward)**2
loss = torch.max(surr1, surr2)
loss = loss.mean()
return loss
class PPOPtxActorLoss(nn.Module):
"""
To Do:
PPO-ptx Actor Loss
"""
def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
super().__init__()
self.pretrain_coef = pretrain_coef
self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
self.pretrain_loss_fn = pretrain_loss_fn
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
lm_logits: torch.Tensor,
lm_input_ids: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
return policy_loss + self.pretrain_coef * lm_loss
class LogSigLoss(nn.Module):
"""
Pairwise Loss for Reward Model
Details: https://arxiv.org/abs/2203.02155
"""
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
probs = torch.sigmoid(chosen_reward - reject_reward)
log_probs = torch.log(probs)
loss = -log_probs.mean()
return loss
class LogExpLoss(nn.Module):
"""
Pairwise Loss for Reward Model
Details: https://arxiv.org/abs/2204.05862
"""
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
return loss
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
from .opt_lm import OPTLM
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
from typing import Optional
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from ..base import Actor
class OPTActor(Actor):
"""
OPT Actor model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = OPTForCausalLM(config)
else:
model = OPTForCausalLM(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
from typing import Optional
import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTModel
from ..base import Critic
class OPTCritic(Critic):
"""
OPT Critic model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
from typing import Optional
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from ..base import LM
class OPTLM(LM):
"""
OPT language model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = OPTForCausalLM(config)
else:
model = OPTForCausalLM(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
from typing import Optional
import torch.nn as nn
from transformers import OPTConfig, OPTModel
from ..base import RewardModel
class OPTRM(RewardModel):
"""
OPT Reward model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.word_embed_proj_dim + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
from typing import Optional, Union
import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_approx_kl(log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
Args:
log_probs: Log probabilities of the new distribution.
log_probs_base: Log probabilities of the base distribution.
action_mask: Mask for actions.
"""
log_ratio = log_probs - log_probs_base
approx_kl = (log_ratio.exp() - 1) - log_ratio
if action_mask is not None:
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
return approx_kl
approx_kl = approx_kl.mean(dim=1)
return approx_kl
def compute_reward(r: Union[torch.Tensor, float],
kl_coef: float,
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if kl_coef <= 0.0:
return r
kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
reward = r - kl_coef * kl
return reward
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return log_probs_labels.squeeze(-1)
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
tensor = tensor * mask
tensor = tensor.sum(dim=dim)
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
tensor = tensor * mask
mean = masked_mean(tensor, mask, dim=dim)
mean_centered = tensor - mean
var = masked_mean(mean_centered**2, mask, dim=dim)
return mean_centered * var.clamp(min=eps).rsqrt()
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
mean = tensor.mean(dim)
mean_centered = tensor - mean
var = (mean_centered**2).mean(dim)
norm = mean_centered * var.clamp(min=eps).rsqrt()
return norm
def convert_to_lora(model: nn.Module,
input_size: int,
output_size: int,
lora_rank: int = 16,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False,
merge_weights: bool = True):
if lora_rank > min(input_size, output_size):
raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
module._modules[name] = lora.Linear(input_size,
output_size,
r=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
fan_in_fan_out=fan_in_fan_out,
merge_weights=merge_weights)
from .base import ReplayBuffer
from .naive import NaiveReplayBuffer
__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
from abc import ABC, abstractmethod
from typing import Any
from chatgpt.experience_maker.base import Experience
class ReplayBuffer(ABC):
"""Replay buffer base class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
"""
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
super().__init__()
self.sample_batch_size = sample_batch_size
# limit <= 0 means unlimited
self.limit = limit
@abstractmethod
def append(self, experience: Experience) -> None:
pass
@abstractmethod
def clear(self) -> None:
pass
@abstractmethod
def sample(self) -> Experience:
pass
@abstractmethod
def __len__(self) -> int:
pass
@abstractmethod
def __getitem__(self, idx: int) -> Any:
pass
@abstractmethod
def collate_fn(self, batch: Any) -> Experience:
pass
import random
from typing import List
import torch
from chatgpt.experience_maker.base import Experience
from .base import ReplayBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch
class NaiveReplayBuffer(ReplayBuffer):
"""Naive replay buffer class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
"""
def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
super().__init__(sample_batch_size, limit)
self.cpu_offload = cpu_offload
self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
# TODO(ver217): add prefetch
self.items: List[BufferItem] = []
@torch.no_grad()
def append(self, experience: Experience) -> None:
if self.cpu_offload:
experience.to_device(torch.device('cpu'))
items = split_experience_batch(experience)
self.items.extend(items)
if self.limit > 0:
samples_to_remove = len(self.items) - self.limit
if samples_to_remove > 0:
self.items = self.items[samples_to_remove:]
def clear(self) -> None:
self.items.clear()
@torch.no_grad()
def sample(self) -> Experience:
items = random.sample(self.items, self.sample_batch_size)
experience = make_experience_batch(items)
if self.cpu_offload:
experience.to_device(self.target_device)
return experience
def __len__(self) -> int:
return len(self.items)
def __getitem__(self, idx: int) -> BufferItem:
return self.items[idx]
def collate_fn(self, batch) -> Experience:
experience = make_experience_batch(batch)
return experience
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn.functional as F
from chatgpt.experience_maker.base import Experience
@dataclass
class BufferItem:
"""BufferItem is an item of experience data.
Shapes of each tensor:
sequences: (S)
action_log_probs: (A)
values: (1)
reward: (1)
advatanges: (1)
attention_mask: (S)
action_mask: (A)
"A" is the number of actions.
"""
sequences: torch.Tensor
action_log_probs: torch.Tensor
values: torch.Tensor
reward: torch.Tensor
advantages: torch.Tensor
attention_mask: Optional[torch.LongTensor]
action_mask: Optional[torch.BoolTensor]
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
vals = torch.unbind(value)
else:
# None
vals = [value for _ in range(batch_size)]
assert batch_size == len(vals)
for i, v in enumerate(vals):
batch_kwargs[i][key] = v
items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
return items
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
assert side in ('left', 'right')
max_len = max(seq.size(0) for seq in sequences)
padded_sequences = []
for seq in sequences:
pad_len = max_len - seq.size(0)
padding = (pad_len, 0) if side == 'left' else (0, pad_len)
padded_sequences.append(F.pad(seq, padding))
return torch.stack(padded_sequences, dim=0)
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
to_pad_keys = set(('action_log_probs', 'action_mask'))
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:
batch_data = zero_pad_sequences(vals)
else:
batch_data = torch.stack(vals, dim=0)
kwargs[key] = batch_data
return Experience(**kwargs)
from .base import Trainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from chatgpt.experience_maker import Experience, ExperienceMaker
from chatgpt.replay_buffer import ReplayBuffer
from torch import Tensor
from torch.utils.data import DistributedSampler
from tqdm import tqdm
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0
class Trainer(ABC):
"""
Base class for rlhf trainers.
Args:
strategy (Strategy):the strategy to use for training
experience_maker (ExperienceMaker): the experience maker to use for produce experience to fullfill replay buffer
replay_buffer (ReplayBuffer): the replay buffer to use for training
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process
tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(self,
strategy: Strategy,
experience_maker: ExperienceMaker,
replay_buffer: ReplayBuffer,
experience_batch_size: int = 8,
max_epochs: int = 1,
tokenizer: Optional[Callable[[Any], dict]] = None,
sample_replay_buffer: bool = False,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
super().__init__()
self.strategy = strategy
self.experience_maker = experience_maker
self.replay_buffer = replay_buffer
self.experience_batch_size = experience_batch_size
self.max_epochs = max_epochs
self.tokenizer = tokenizer
self.generate_kwargs = generate_kwargs
self.sample_replay_buffer = sample_replay_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
@abstractmethod
def training_step(self, experience: Experience) -> Dict[str, Any]:
pass
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
if isinstance(inputs, Tensor):
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
elif isinstance(inputs, dict):
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(inputs)}"')
def _sample_prompts(self, prompts) -> list:
indices = list(range(len(prompts)))
sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
return [prompts[i] for i in sampled_indices]
def _learn(self):
# replay buffer may be empty at first, we should rebuild at each training
if not self.sample_replay_buffer:
dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
device = torch.cuda.current_device()
if self.sample_replay_buffer:
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
for _ in pbar:
experience = self.replay_buffer.sample()
metrics = self.training_step(experience)
pbar.set_postfix(metrics)
else:
for epoch in range(self.max_epochs):
self._on_learn_epoch_start(epoch)
if isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(epoch)
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(device)
metrics = self.training_step(experience)
self._on_learn_batch_end(metrics, experience)
pbar.set_postfix(metrics)
self._on_learn_epoch_end(epoch)
def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
time = 0
sampler = self.strategy.setup_sampler(prompts)
self._on_fit_start()
for episode in range(num_episodes):
self._on_episode_start(episode)
for timestep in tqdm(range(max_timesteps),
desc=f'Episode [{episode+1}/{num_episodes}]',
disable=not is_rank_0()):
time += 1
rand_prompts = sampler.sample(self.experience_batch_size)
if self.tokenizer is not None:
inputs = self.tokenizer(rand_prompts)
else:
inputs = rand_prompts
self._on_make_experience_start()
experience = self._make_experience(inputs)
self._on_make_experience_end(experience)
self.replay_buffer.append(experience)
if time % update_timesteps == 0:
self._learn()
self.replay_buffer.clear()
self._on_episode_end(episode)
self._on_fit_end()
# TODO(ver217): maybe simplify these code using context
def _on_fit_start(self) -> None:
for callback in self.callbacks:
callback.on_fit_start()
def _on_fit_end(self) -> None:
for callback in self.callbacks:
callback.on_fit_end()
def _on_episode_start(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_start(episode)
def _on_episode_end(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_end(episode)
def _on_make_experience_start(self) -> None:
for callback in self.callbacks:
callback.on_make_experience_start()
def _on_make_experience_end(self, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_make_experience_end(experience)
def _on_learn_epoch_start(self, epoch: int) -> None:
for callback in self.callbacks:
callback.on_learn_epoch_start(epoch)
def _on_learn_epoch_end(self, epoch: int) -> None:
for callback in self.callbacks:
callback.on_learn_epoch_end(epoch)
def _on_learn_batch_start(self) -> None:
for callback in self.callbacks:
callback.on_learn_batch_start()
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_learn_batch_end(metrics, experience)
from .base import Callback
from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint
__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
from abc import ABC
from chatgpt.experience_maker import Experience
class Callback(ABC):
"""
Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
pass
def on_fit_end(self) -> None:
pass
def on_episode_start(self, episode: int) -> None:
pass
def on_episode_end(self, episode: int) -> None:
pass
def on_make_experience_start(self) -> None:
pass
def on_make_experience_end(self, experience: Experience) -> None:
pass
def on_learn_epoch_start(self, epoch: int) -> None:
pass
def on_learn_epoch_end(self, epoch: int) -> None:
pass
def on_learn_batch_start(self) -> None:
pass
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
pass
from time import time
from typing import Optional
import torch
import torch.distributed as dist
from chatgpt.experience_maker import Experience
from .base import Callback
def get_world_size() -> int:
if dist.is_initialized():
return dist.get_world_size()
return 1
def print_rank_0(*args, **kwargs) -> None:
if not dist.is_initialized() or dist.get_rank() == 0:
print(*args, **kwargs)
@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=torch.cuda.current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
class PerformanceEvaluator(Callback):
"""
Callback for valuate the performance of the model.
Args:
actor_num_params: The number of parameters of the actor model.
critic_num_params: The number of parameters of the critic model.
initial_model_num_params: The number of parameters of the initial model.
reward_model_num_params: The number of parameters of the reward model.
enable_grad_checkpoint: Whether to enable gradient checkpointing.
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
def __init__(self,
actor_num_params: int,
critic_num_params: int,
initial_model_num_params: int,
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
self.critic_num_params = critic_num_params
self.initial_model_num_params = initial_model_num_params
self.reward_model_num_params = reward_model_num_params
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_episodes = ignore_episodes
self.disable: bool = False
self.make_experience_duration: float = 0.
self.make_experience_start_time: Optional[float] = None
self.make_experience_num_samples: int = 0
self.make_experience_flop: int = 0
self.learn_duration: float = 0.
self.learn_start_time: Optional[float] = None
self.learn_num_samples: int = 0
self.learn_flop: int = 0
def on_episode_start(self, episode: int) -> None:
self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes
def on_make_experience_start(self) -> None:
if self.disable:
return
self.make_experience_start_time = time()
def on_make_experience_end(self, experience: Experience) -> None:
if self.disable:
return
self.make_experience_duration += time() - self.make_experience_start_time
batch_size, seq_len = experience.sequences.shape
self.make_experience_num_samples += batch_size
# actor generate
num_actions = experience.action_mask.size(1)
input_len = seq_len - num_actions
total_seq_len = (input_len + seq_len - 1) * num_actions / 2
self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
# actor forward
self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
# critic forward
self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
# initial model forward
self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
# reward model forward
self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
def on_learn_batch_start(self) -> None:
if self.disable:
return
self.learn_start_time = time()
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
if self.disable:
return
self.learn_duration += time() - self.learn_start_time
batch_size, seq_len = experience.sequences.shape
self.learn_num_samples += batch_size
# actor forward-backward, 3 means forward(1) + backward(2)
self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
# critic foward-backward
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
def on_fit_end(self) -> None:
avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size)
avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size)
avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12)
avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12)
print_rank_0(
f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, TFLOPS: {avg_make_experience_tflops:.3f}'
)
print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, TFLOPS: {avg_learn_tflops:.3f}')
import os
import torch.distributed as dist
from chatgpt.trainer.strategies import ColossalAIStrategy, Strategy
from chatgpt.trainer.utils import is_rank_0
from torch import nn
from torch.optim import Optimizer
from .base import Callback
class SaveCheckpoint(Callback):
"""
The callback for saving checkpoint for chatgpt.
Only support saving actor and critic model.
A typical architecture of the saved checkpoint would be:
- checkpoint
- episode_x
- actor.pt
- actor-optim-rank-0.pt
- actor-optim-rank-1.pt
- critic.pt
- critic-optim-rank-0.pt
- critic-optim-rank-1.pt
- ...
Args:
path(str): the base path you want to save checkpoint, the checkpoint would be saved at `path/checkpoint`
interval(int): the interval episode of saving checkpoint
strategy(Strategy): the strategy used to train
actor(nn.Module): the actor model
critic(nn.Module): the critic model
actor_optim(Optimizer): the optimizer of actor
critic_optim(Optimizer): the optimizer of critic
"""
def __init__(self,
path: str,
interval: int,
strategy: Strategy,
actor: nn.Module = None,
critic: nn.Module = None,
actor_optim: Optimizer = None,
critic_optim: Optimizer = None) -> None:
super().__init__()
self.path = os.path.join(path, 'checkpoint')
self.interval = interval
self.strategy = strategy
self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
def on_episode_end(self, episode: int) -> None:
if (episode + 1) % self.interval != 0:
return
base_path = os.path.join(self.path, f'episode_{episode}')
if not os.path.exists(base_path):
os.makedirs(base_path)
for model in self.model_dict.keys():
# save model
if self.model_dict[model][0] is None:
# saving only optimizer states is meaningless, so it would be skipped
continue
model_path = os.path.join(base_path, f'{model}.pt')
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
# save optimizer
if self.model_dict[model][1] is None:
continue
only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
rank = 0 if is_rank_0() else dist.get_rank()
optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
from typing import Any, Callable, Dict, List, Optional
import torch.nn as nn
from chatgpt.experience_maker import Experience, NaiveExperienceMaker
from chatgpt.models.base import Actor, Critic
from chatgpt.models.generation_utils import update_model_kwargs_fn
from chatgpt.models.loss import PolicyLoss, ValueLoss
from chatgpt.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer
from .base import Trainer
from .callbacks import Callback
from .strategies import Strategy
class PPOTrainer(Trainer):
"""
Trainer for PPO algorithm.
Args:
strategy (Strategy): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
critic (Critic): the critic model in ppo algorithm
reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor
actor_optim (Optimizer): the optimizer to use for actor model
critic_optim (Optimizer): the optimizer to use for critic model
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
train_batch_size (int, defaults to 8): the batch size to use for training
buffer_limit (int, defaults to 0): the max_size limitaiton of replay buffer
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process
tokenier (Callable, optional): the tokenizer to use for tokenizing the input
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
kl_coef: float = 0.1,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
value_clip: float = 0.4,
experience_batch_size: int = 8,
max_epochs: int = 1,
tokenizer: Optional[Callable[[Any], dict]] = None,
sample_replay_buffer: bool = False,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
self.actor = actor
self.critic = critic
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
self.actor_optim = actor_optim
self.critic_optim = critic_optim
def training_step(self, experience: Experience) -> Dict[str, float]:
self.actor.train()
self.critic.train()
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy._unwrap_actor(actor)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs:
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
return new_kwargs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment