Unverified Commit df5e9c53 authored by YeAnbang, committed by GitHub

[ColossalChat] Update RLHF V2 (#5286)



* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporary

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
parent 36c4bb28
+"""
+loss functions
+"""
+
-from typing import Optional
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
@@ -28,9 +31,10 @@ class PolicyLoss(nn.Module):
     Policy Loss for PPO
     """

-    def __init__(self, clip_eps: float = 0.2) -> None:
+    def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None:
         super().__init__()
         self.clip_eps = clip_eps
+        self.skip_threshold = skip_threshold

     def forward(
         self,
@@ -39,14 +43,20 @@ class PolicyLoss(nn.Module):
         advantages: torch.Tensor,
         action_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        ratio = (log_probs - old_log_probs).exp()
+        skip = False
+        ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+        # note that if dropout is disabled (recommended), ratio will always be 1.
+        if ratio_.mean() > self.skip_threshold:
+            skip = True
+        ratio = ratio_.clamp(0.0, 10.0)
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
         loss = -torch.min(surr1, surr2)
-        if action_mask is not None:
-            loss = masked_mean(loss, action_mask)
+        loss = masked_mean(loss, action_mask)
         loss = loss.mean()
-        return loss
+        return loss, skip, ratio_.max()
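As an aside, a minimal standalone sketch of how this clipped-surrogate loss and the skip flag could be consumed; the tensor shapes, values, and the skip handling below are illustrative assumptions, not code from this PR.

import torch

batch, num_actions = 4, 8
log_probs = torch.randn(batch, num_actions, requires_grad=True) * 0.01
old_log_probs = log_probs.detach() + torch.randn(batch, num_actions) * 0.01
advantages = torch.randn(batch, num_actions)
action_mask = torch.ones(batch, num_actions)

# mirrors PolicyLoss.forward above (clip_eps=0.2, skip_threshold=20.0)
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
skip = ratio_.mean() > 20.0
ratio = ratio_.clamp(0.0, 10.0)
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - 0.2, 1 + 0.2) * advantages
loss = -torch.min(surr1, surr2)
loss = ((loss * action_mask).sum(dim=1) / (action_mask.sum(dim=1) + 1e-8)).mean()

if skip:
    # a trainer can drop updates whose importance ratio exploded
    print("skipping update, max ratio:", ratio_.max().item())
else:
    loss.backward()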
 class ValueLoss(nn.Module):
@@ -54,7 +64,7 @@ class ValueLoss(nn.Module):
     Value Loss for PPO
     """

-    def __init__(self, clip_eps: float = 0.4) -> None:
+    def __init__(self, clip_eps: float = 0.2) -> None:
         super().__init__()
         self.clip_eps = clip_eps

@@ -62,17 +72,82 @@ class ValueLoss(nn.Module):
         self,
         values: torch.Tensor,
         old_values: torch.Tensor,
-        reward: torch.Tensor,
+        advantage: torch.Tensor,
         action_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        returns = advantage + old_values
         values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
-        surr1 = (values_clipped - reward) ** 2
-        surr2 = (values - reward) ** 2
-        loss = torch.max(surr1, surr2)
-        loss = loss.mean()
+        surr1 = (values_clipped - returns) ** 2
+        surr2 = (values - returns) ** 2
+        loss = torch.max(surr1, surr2) / torch.sum(action_mask)
+        loss = torch.sum(loss * action_mask)
         return 0.5 * loss
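The returns target is reconstructed as advantage + old_values because the advantages were computed as returns minus the old value estimates. A small self-contained sketch with made-up shapes (an illustration, not the trainer code from this PR):

import torch

values = torch.randn(4, 8, requires_grad=True)
old_values = values.detach() + 0.05 * torch.randn(4, 8)
advantage = torch.randn(4, 8)
action_mask = torch.ones(4, 8)
clip_eps = 0.2

# mirrors ValueLoss.forward above
returns = advantage + old_values
values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)
surr1 = (values_clipped - returns) ** 2
surr2 = (values - returns) ** 2
loss = 0.5 * torch.sum(torch.max(surr1, surr2) * action_mask) / torch.sum(action_mask)
loss.backward()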
class DpoLoss(nn.Module):
    """
    DPO loss
    Details: https://arxiv.org/pdf/2305.18290.pdf
    """

    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta

    def forward(
        self,
        logprob_actor_chosen: torch.Tensor,
        logprob_actor_reject: torch.Tensor,
        logprob_ref_chosen: torch.Tensor,
        logprob_ref_reject: torch.Tensor,
        chosen_mask: torch.Tensor,
        reject_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the DPO loss for a batch of policy and reference model log probabilities.
        # adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328

        Args:
            logprob_actor_chosen: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            chosen_mask: Mask over the tokens of the chosen responses.
            reject_mask: Mask over the tokens of the rejected responses.

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the DPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        """
        logprob_actor_chosen = logprob_actor_chosen * chosen_mask
        logprob_actor_reject = logprob_actor_reject * reject_mask

        if logprob_ref_chosen is not None and logprob_ref_reject is not None:
            logprob_ref_chosen = logprob_ref_chosen * chosen_mask
            logprob_ref_reject = logprob_ref_reject * reject_mask
            if len(logprob_ref_chosen.shape) == 2:
                ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1)
            else:
                ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze()
        else:
            # If no reference model is provided
            ref_logratios = 0.0

        pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
        logits = pi_logratios - ref_logratios
        losses = -torch.nn.functional.logsigmoid(self.beta * logits)

        # Calculate rewards for logging
        if logprob_ref_chosen is not None:
            chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach()
        else:
            chosen_rewards = self.beta * logprob_actor_chosen.sum(-1).detach()
        if logprob_ref_reject is not None:
            rejected_rewards = self.beta * (logprob_actor_reject.sum(-1) - logprob_ref_reject.sum(-1)).detach()
        else:
            rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach()

        return losses, chosen_rewards, rejected_rewards
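For reference, a self-contained sketch of the same DPO objective on dummy per-token log probabilities; shapes and values are assumptions made for illustration only.

import torch
import torch.nn.functional as F

beta = 0.1
batch, seq = 2, 6
logp_actor_chosen = -torch.rand(batch, seq)
logp_actor_reject = -torch.rand(batch, seq)
logp_ref_chosen = -torch.rand(batch, seq)
logp_ref_reject = -torch.rand(batch, seq)
chosen_mask = torch.ones(batch, seq)
reject_mask = torch.ones(batch, seq)

# mirrors DpoLoss.forward above for the 2D (per-token) case with a reference model
pi_logratios = (logp_actor_chosen * chosen_mask).sum(-1) - (logp_actor_reject * reject_mask).sum(-1)
ref_logratios = (logp_ref_chosen * chosen_mask).sum(-1) - (logp_ref_reject * reject_mask).sum(-1)
losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
print(losses.mean().item())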
class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
@@ -80,10 +155,7 @@ class LogSigLoss(nn.Module):
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
-        probs = torch.sigmoid(chosen_reward - reject_reward)
-        log_probs = torch.log(probs)
-        loss = -log_probs.mean()
-        return loss
+        return -torch.nn.functional.logsigmoid(chosen_reward - reject_reward).mean()
class LogExpLoss(nn.Module):
...
"""
reward model
"""
from typing import Optional
import torch
import torch.nn as nn
from coati.models import BaseModel
from transformers import PretrainedConfig
class RewardModel(BaseModel):
"""
Reward model class.
Args:
pretrained str: huggingface or local model path
config: PretrainedConfig object
**kwargs: all other kwargs as in AutoModel.from_pretrained
"""
def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None:
super().__init__(pretrained=pretrained, config=config, **kwargs)
self.value_head = nn.Linear(self.last_hidden_state_size, 1)
self.value_head.weight.data.normal_(mean=0.0, std=1 / (self.last_hidden_state_size + 1))
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
outputs = self.model(input_ids, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"]
sequence_lengths = torch.max(attention_mask * torch.arange(input_ids.size(1), device=input_ids.device), dim=1)[
0
]
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths].type(
self.value_head.weight.dtype
)
values = self.value_head(sequence_hidden_states).squeeze(-1) # Ensure shape is (B,)
return values
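A standalone sketch of the pooling used by this forward pass: the hidden state of the last non-padded token of each sequence is fed to the value head. The shapes and mask below are made up for illustration.

import torch

batch, seq_len, hidden = 2, 5, 8
last_hidden_states = torch.randn(batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# index of the last position where attention_mask == 1 (assumes left-aligned, right-padded inputs)
sequence_lengths = torch.max(attention_mask * torch.arange(seq_len), dim=1)[0]
pooled = last_hidden_states[torch.arange(batch), sequence_lengths]  # (batch, hidden)

value_head = torch.nn.Linear(hidden, 1)
values = value_head(pooled).squeeze(-1)  # (batch,)
print(values.shape)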
import json
import os
from typing import Any, Dict, Optional, Union

import torch
import torch.nn.functional as F


def get_model_numel(model: torch.nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())
def compute_reward(
    r: Union[torch.Tensor, float],
    kl_coef: float,
    log_probs: torch.Tensor,
    log_probs_base: torch.Tensor,
    action_mask: Optional[torch.Tensor] = None,
    reward_eps=5,
) -> torch.Tensor:
    """
    Args:
        log_probs: [batch_size, response_length]
        log_probs_base: [batch_size, response_length]
        action_mask: [batch_size, response_length]
        r: float
    Returns:
        reward: [batch_size, response_length]
    """
    log_ratio = log_probs - log_probs_base  # address numerical instability issue
    kl = -kl_coef * log_ratio * action_mask
    reward = kl

    r_clip = torch.clamp(r, -reward_eps, reward_eps)
    for i in range(action_mask.size(0)):
        assert action_mask[i].sum() > 0
        reward[i, : action_mask[i].sum()] += r_clip[i]
        reward[i, action_mask[i].sum() :] *= 0

    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
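A worked illustration of this reward shaping with made-up numbers: a per-token KL penalty, plus the clipped scalar reward added over the valid response tokens, with padding positions zeroed out.

import torch

kl_coef, reward_eps = 0.1, 5
log_probs = torch.tensor([[-1.0, -1.2, -0.8, -0.9]])
log_probs_base = torch.tensor([[-1.1, -1.0, -0.8, -1.3]])
action_mask = torch.tensor([[1, 1, 1, 0]])
r = torch.tensor([7.0])  # raw scalar reward; clipped to reward_eps below

log_ratio = log_probs - log_probs_base
reward = -kl_coef * log_ratio * action_mask
r_clip = torch.clamp(r, -reward_eps, reward_eps)
n_valid = int(action_mask[0].sum())
reward[0, :n_valid] += r_clip[0]
reward[0, n_valid:] *= 0
print(reward)  # KL penalty plus 5.0 on the three valid tokens, 0 on the padding token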
def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Compute the log probabilities from logits for the given labels.

    Args:
        logits (torch.Tensor): The input logits.
        labels (torch.Tensor): The target labels.

    Returns:
        torch.Tensor: The log probabilities corresponding to the labels.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    return log_probs_labels.squeeze(-1)
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
    """Calculate action log probs.

    Args:
        logits (torch.Tensor): Output logits of Actor.forward.
        sequences (torch.LongTensor): Input sequences.
        num_actions (int): Number of actions.

    Returns:
        torch.Tensor: Action log probs.
    """
    log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
    return log_probs[:, -num_actions:]
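A quick sketch of the shift-and-gather above: the logits at position t score the token at position t + 1, and only the final num_actions (generated) positions are kept. Dummy shapes are used for illustration.

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 1, 6, 10
num_actions = 3
logits = torch.randn(batch, seq_len, vocab)
sequences = torch.randint(0, vocab, (batch, seq_len))

log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
per_token = log_probs.gather(dim=-1, index=sequences[:, 1:].unsqueeze(-1)).squeeze(-1)  # (batch, seq_len - 1)
action_log_probs = per_token[:, -num_actions:]  # (batch, num_actions)
print(action_log_probs.shape)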
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Compute the masked mean of a tensor along a specified dimension.

    Args:
        tensor (torch.Tensor): The input tensor.
        mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
        dim (int, optional): The dimension along which to compute the mean. Default is 1.

    Returns:
        torch.Tensor: The masked mean tensor.
    """
    tensor = tensor * mask
    tensor = tensor.sum(dim=dim)
    mask_sum = mask.sum(dim=dim)
    mean = tensor / (mask_sum + 1e-8)
    return mean
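For example, with a padding mask along dim=1 (made-up values):

import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0],
                     [1.0, 1.0, 1.0]])
print((x * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8))  # ~tensor([1.5000, 5.0000])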
def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Calculate the masked log probabilities for a given sequence of logits.

    Args:
        logits (torch.Tensor): The input logits tensor of shape (batch_size, sequence_length, vocab_size).
        sequences (torch.LongTensor): The input sequence tensor of shape (batch_size, sequence_length).
        mask (torch.Tensor): The mask tensor of shape (batch_size, sequence_length).

    Returns:
        torch.Tensor: The masked log probabilities tensor of shape (batch_size, sequence_length - 1).
    """
    # logits at position t predict the token at position t + 1, so shift the labels by one
    log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
    return log_probs * mask
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
    """
    Load file in JSON format
    """
    with open(file=file_path, mode="r", encoding="utf-8") as fp:
        return json.load(fp)


def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
    """
    Save as JSON format
    """
    with open(file=file_path, mode="w", encoding="utf-8") as fp:
        json.dump(data, fp=fp, ensure_ascii=False, indent=4)
def disable_dropout(model: torch.nn.Module):
    """
    Disables dropout in a PyTorch model. This is used in PPO training.

    Args:
        model (torch.nn.Module): The PyTorch model.

    Returns:
        None
    """
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0
@@ -75,7 +75,9 @@ def get_strategy_from_args(strategy: str):
     elif strategy == "colossalai_zero2":
         strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     elif strategy == "colossalai_gemini_cpu":
-        strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy_ = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
     elif strategy == "colossalai_zero2_cpu":
         strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:
...