Unverified Commit b0ce5a10 authored by Fazzie-Maqianli's avatar Fazzie-Maqianli Committed by GitHub
Browse files

[Coati] first commit (#3283)

parent fd6add57
from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_lm import LlamaLM
from .llama_rm import LlamaRM
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
from typing import Optional
import torch
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
from ..base import Actor
class LlamaActor(Actor):
"""
Llama Actor model.
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = LlamaForCausalLM(config)
else:
model = LlamaForCausalLM(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
from typing import Optional
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
from ..base import Critic
class LlamaCritic(Critic):
"""
Llama Critic model.
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = LlamaForCausalLM(config)
else:
model = LlamaForCausalLM(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
from typing import Optional
from transformers import LlamaConfig, LlamaForCausalLM
from ..base import LM
class LlamaLM(LM):
"""
Llama language model.
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = LlamaForCausalLM(config)
else:
model = LlamaForCausalLM(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
from typing import Optional
import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
from ..base import RewardModel
class LlamaRM(RewardModel):
"""
Llama Reward model.
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
model = LlamaModel(config)
else:
model = LlamaModel(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
import math
from typing import Optional
import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F
class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
"""
def __init__(
self,
weight: nn.Parameter,
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
):
nn.Module.__init__(self)
lora.LoRALayer.__init__(self,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
merge_weights=merge_weights)
self.weight = weight
self.bias = bias
out_features, in_features = weight.shape
self.in_features = in_features
self.out_features = out_features
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Module.train(self, mode)
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Module.eval(self)
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, 'lora_A')
delattr(self, 'lora_B')
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias)
if self.r > 0:
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
return result
else:
return F.linear(x, T(self.weight), bias=self.bias)
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, lora_linear_wrapper(child, lora_rank))
else:
convert_to_lora_recursively(child, lora_rank)
class LoRAModule(nn.Module):
"""A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
This calss will convert all torch.nn.Linear layer to LoraLinear layer.
Args:
lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
lora_train_bias (str, optional): Whether LoRA train biases.
'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
Defaults to 'none'.
"""
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
def convert_to_lora(self) -> None:
if self.lora_rank <= 0:
return
convert_to_lora_recursively(self, self.lora_rank)
lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
from typing import Optional
import torch
import torch.nn as nn
from .utils import masked_mean
class GPTLMLoss(nn.Module):
"""
GPT Language Model Loss
"""
def __init__(self):
super().__init__()
self.loss = nn.CrossEntropyLoss()
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
class PolicyLoss(nn.Module):
"""
Policy Loss for PPO
"""
def __init__(self, clip_eps: float = 0.2) -> None:
super().__init__()
self.clip_eps = clip_eps
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
if action_mask is not None:
loss = masked_mean(loss, action_mask)
loss = loss.mean()
return loss
class ValueLoss(nn.Module):
"""
Value Loss for PPO
"""
def __init__(self, clip_eps: float = 0.4) -> None:
super().__init__()
self.clip_eps = clip_eps
def forward(self,
values: torch.Tensor,
old_values: torch.Tensor,
reward: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - reward)**2
surr2 = (values - reward)**2
loss = torch.max(surr1, surr2)
loss = loss.mean()
return loss
class PPOPtxActorLoss(nn.Module):
"""
To Do:
PPO-ptx Actor Loss
"""
def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
super().__init__()
self.pretrain_coef = pretrain_coef
self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
self.pretrain_loss_fn = pretrain_loss_fn
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
lm_logits: torch.Tensor,
lm_input_ids: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
return policy_loss + self.pretrain_coef * lm_loss
class LogSigLoss(nn.Module):
"""
Pairwise Loss for Reward Model
Details: https://arxiv.org/abs/2203.02155
"""
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
probs = torch.sigmoid(chosen_reward - reject_reward)
log_probs = torch.log(probs)
loss = -log_probs.mean()
return loss
class LogExpLoss(nn.Module):
"""
Pairwise Loss for Reward Model
Details: https://arxiv.org/abs/2204.05862
"""
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
return loss
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_lm import OPTLM
from .opt_rm import OPTRM
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
from typing import Optional
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from ..base import Actor
class OPTActor(Actor):
"""
OPT Actor model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = OPTForCausalLM(config)
else:
model = OPTForCausalLM(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
from typing import Optional
import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTModel
from ..base import Critic
class OPTCritic(Critic):
"""
OPT Critic model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
from typing import Optional
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from ..base import LM
class OPTLM(LM):
"""
OPT language model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = OPTForCausalLM(config)
else:
model = OPTForCausalLM(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
from typing import Optional
import torch.nn as nn
from transformers import OPTConfig, OPTModel
from ..base import RewardModel
class OPTRM(RewardModel):
"""
OPT Reward model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
from typing import Optional, Union
import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_approx_kl(log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
Args:
log_probs: Log probabilities of the new distribution.
log_probs_base: Log probabilities of the base distribution.
action_mask: Mask for actions.
"""
log_ratio = log_probs - log_probs_base
approx_kl = (log_ratio.exp() - 1) - log_ratio
if action_mask is not None:
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
return approx_kl
approx_kl = approx_kl.mean(dim=1)
return approx_kl
def compute_reward(r: Union[torch.Tensor, float],
kl_coef: float,
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if kl_coef <= 0.0:
return r
kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
reward = r - kl_coef * kl
return reward
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return log_probs_labels.squeeze(-1)
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
tensor = tensor * mask
tensor = tensor.sum(dim=dim)
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
tensor = tensor * mask
mean = masked_mean(tensor, mask, dim=dim)
mean_centered = tensor - mean
var = masked_mean(mean_centered**2, mask, dim=dim)
return mean_centered * var.clamp(min=eps).rsqrt()
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
mean = tensor.mean(dim)
mean_centered = tensor - mean
var = (mean_centered**2).mean(dim)
norm = mean_centered * var.clamp(min=eps).rsqrt()
return norm
def convert_to_lora(model: nn.Module,
input_size: int,
output_size: int,
lora_rank: int = 16,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False,
merge_weights: bool = True):
if lora_rank > min(input_size, output_size):
raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
module._modules[name] = lora.Linear(input_size,
output_size,
r=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
fan_in_fan_out=fan_in_fan_out,
merge_weights=merge_weights)
from .base import ReplayBuffer
from .naive import NaiveReplayBuffer
__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
from abc import ABC, abstractmethod
from typing import Any
from coati.experience_maker.base import Experience
class ReplayBuffer(ABC):
"""Replay buffer base class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
"""
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
super().__init__()
self.sample_batch_size = sample_batch_size
# limit <= 0 means unlimited
self.limit = limit
@abstractmethod
def append(self, experience: Experience) -> None:
pass
@abstractmethod
def clear(self) -> None:
pass
@abstractmethod
def sample(self) -> Experience:
pass
@abstractmethod
def __len__(self) -> int:
pass
@abstractmethod
def __getitem__(self, idx: int) -> Any:
pass
@abstractmethod
def collate_fn(self, batch: Any) -> Experience:
pass
import random
from typing import List
import torch
from coati.experience_maker.base import Experience
from .base import ReplayBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch
class NaiveReplayBuffer(ReplayBuffer):
"""Naive replay buffer class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
"""
def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
super().__init__(sample_batch_size, limit)
self.cpu_offload = cpu_offload
self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
# TODO(ver217): add prefetch
self.items: List[BufferItem] = []
@torch.no_grad()
def append(self, experience: Experience) -> None:
if self.cpu_offload:
experience.to_device(torch.device('cpu'))
items = split_experience_batch(experience)
self.items.extend(items)
if self.limit > 0:
samples_to_remove = len(self.items) - self.limit
if samples_to_remove > 0:
self.items = self.items[samples_to_remove:]
def clear(self) -> None:
self.items.clear()
@torch.no_grad()
def sample(self) -> Experience:
items = random.sample(self.items, self.sample_batch_size)
experience = make_experience_batch(items)
if self.cpu_offload:
experience.to_device(self.target_device)
return experience
def __len__(self) -> int:
return len(self.items)
def __getitem__(self, idx: int) -> BufferItem:
return self.items[idx]
def collate_fn(self, batch) -> Experience:
experience = make_experience_batch(batch)
return experience
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn.functional as F
from coati.experience_maker.base import Experience
@dataclass
class BufferItem:
"""BufferItem is an item of experience data.
Shapes of each tensor:
sequences: (S)
action_log_probs: (A)
values: (1)
reward: (1)
advatanges: (1)
attention_mask: (S)
action_mask: (A)
"A" is the number of actions.
"""
sequences: torch.Tensor
action_log_probs: torch.Tensor
values: torch.Tensor
reward: torch.Tensor
advantages: torch.Tensor
attention_mask: Optional[torch.LongTensor]
action_mask: Optional[torch.BoolTensor]
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
vals = torch.unbind(value)
else:
# None
vals = [value for _ in range(batch_size)]
assert batch_size == len(vals)
for i, v in enumerate(vals):
batch_kwargs[i][key] = v
items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
return items
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
assert side in ('left', 'right')
max_len = max(seq.size(0) for seq in sequences)
padded_sequences = []
for seq in sequences:
pad_len = max_len - seq.size(0)
padding = (pad_len, 0) if side == 'left' else (0, pad_len)
padded_sequences.append(F.pad(seq, padding))
return torch.stack(padded_sequences, dim=0)
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
to_pad_keys = set(('action_log_probs', 'action_mask'))
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:
batch_data = zero_pad_sequences(vals)
else:
batch_data = torch.stack(vals, dim=0)
kwargs[key] = batch_data
return Experience(**kwargs)
from .base import Trainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from coati.experience_maker import Experience, ExperienceMaker
from coati.replay_buffer import ReplayBuffer
from torch import Tensor
from torch.utils.data import DistributedSampler
from tqdm import tqdm
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0
class Trainer(ABC):
"""
Base class for rlhf trainers.
Args:
strategy (Strategy):the strategy to use for training
experience_maker (ExperienceMaker): the experience maker to use for produce experience to fullfill replay buffer
replay_buffer (ReplayBuffer): the replay buffer to use for training
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process
tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(self,
strategy: Strategy,
experience_maker: ExperienceMaker,
replay_buffer: ReplayBuffer,
experience_batch_size: int = 8,
max_epochs: int = 1,
tokenizer: Optional[Callable[[Any], dict]] = None,
sample_replay_buffer: bool = False,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
super().__init__()
self.strategy = strategy
self.experience_maker = experience_maker
self.replay_buffer = replay_buffer
self.experience_batch_size = experience_batch_size
self.max_epochs = max_epochs
self.tokenizer = tokenizer
self.generate_kwargs = generate_kwargs
self.sample_replay_buffer = sample_replay_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
@abstractmethod
def training_step(self, experience: Experience) -> Dict[str, Any]:
pass
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
if isinstance(inputs, Tensor):
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
elif isinstance(inputs, dict):
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(inputs)}"')
def _sample_prompts(self, prompts) -> list:
indices = list(range(len(prompts)))
sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
return [prompts[i] for i in sampled_indices]
def _learn(self):
# replay buffer may be empty at first, we should rebuild at each training
if not self.sample_replay_buffer:
dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
device = torch.cuda.current_device()
if self.sample_replay_buffer:
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
for _ in pbar:
experience = self.replay_buffer.sample()
metrics = self.training_step(experience)
pbar.set_postfix(metrics)
else:
for epoch in range(self.max_epochs):
self._on_learn_epoch_start(epoch)
if isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(epoch)
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(device)
metrics = self.training_step(experience)
self._on_learn_batch_end(metrics, experience)
pbar.set_postfix(metrics)
self._on_learn_epoch_end(epoch)
def fit(self,
prompt_dataloader,
pretrain_dataloader,
num_episodes: int = 50000,
max_timesteps: int = 500,
update_timesteps: int = 5000) -> None:
time = 0
self.pretrain_dataloader = pretrain_dataloader
self.prompt_dataloader = prompt_dataloader
self._on_fit_start()
for episode in range(num_episodes):
self._on_episode_start(episode)
for timestep in tqdm(range(max_timesteps),
desc=f'Episode [{episode+1}/{num_episodes}]',
disable=not is_rank_0()):
time += 1
prompts = next(iter(self.prompt_dataloader))
self._on_make_experience_start()
self.experience_maker.initial_model.to(torch.cuda.current_device())
self.experience_maker.reward_model.to(torch.cuda.current_device())
experience = self._make_experience(prompts)
self._on_make_experience_end(experience)
self.replay_buffer.append(experience)
if time % update_timesteps == 0:
self.experience_maker.initial_model.to('cpu')
self.experience_maker.reward_model.to('cpu')
self._learn()
self.replay_buffer.clear()
self._on_episode_end(episode)
self._on_fit_end()
# TODO(ver217): maybe simplify these code using context
def _on_fit_start(self) -> None:
for callback in self.callbacks:
callback.on_fit_start()
def _on_fit_end(self) -> None:
for callback in self.callbacks:
callback.on_fit_end()
def _on_episode_start(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_start(episode)
def _on_episode_end(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_end(episode)
def _on_make_experience_start(self) -> None:
for callback in self.callbacks:
callback.on_make_experience_start()
def _on_make_experience_end(self, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_make_experience_end(experience)
def _on_learn_epoch_start(self, epoch: int) -> None:
for callback in self.callbacks:
callback.on_learn_epoch_start(epoch)
def _on_learn_epoch_end(self, epoch: int) -> None:
for callback in self.callbacks:
callback.on_learn_epoch_end(epoch)
def _on_learn_batch_start(self) -> None:
for callback in self.callbacks:
callback.on_learn_batch_start()
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_learn_batch_end(metrics, experience)
from .base import Callback
from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint
__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment