import torch import torch.nn.functional as F from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase): @staticmethod def preference_loss_fn( chosen_logps, rejected_logps, full_target, ref_chosen_logps=None, ref_rejected_logps=None, beta=0.1, loss_type="sigmoid", ): """ Paper: https://arxiv.org/pdf/2305.18290 Formula: L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ] Where: - π(y|x): Policy (model) probability - π_ref(y|x): Reference model probability - y_w: Chosen sequence - y_l: Rejected sequence - β: Weight for the direct preference loss - E: Expected value over the dataset Args: chosen_logps: Log probabilities of chosen tokens (batch_size,) rejected_logps: Log probabilities of rejected tokens (batch_size,) full_target: Non chunked full target tensor ref_chosen_logps: Reference log probs of chosen tokens (batch_size,) ref_rejected_logps: Reference log probs of rejected tokens (batch_size,) beta: Weight for the direct preference loss """ if ref_chosen_logps is None: ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device) if ref_rejected_logps is None: ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device) chosen_logratios = chosen_logps - ref_chosen_logps rejected_logratios = rejected_logps - ref_rejected_logps chosen_rewards = beta * chosen_logratios rejected_rewards = beta * rejected_logratios if loss_type == "sigmoid": logits_diff = beta * (chosen_logratios - rejected_logratios) loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2) elif loss_type == "apo_zero": # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) # Use this loss when you believe the chosen outputs are better than your model's default output losses_chosen = 1 - F.sigmoid(beta * chosen_logratios) # Increase chosen likelihood losses_rejected = F.sigmoid(beta * rejected_logratios) losses = losses_chosen + losses_rejected loss = losses.sum() / (full_target.shape[0] // 2) elif loss_type == "apo_down": # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) # Use this loss when you believe the chosen outputs are worse than your model's default output. # Decrease chosen likelihood and decrease rejected likelihood more losses_chosen = F.sigmoid(beta * chosen_logratios) losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios)) losses = losses_chosen + losses_rejected loss = losses.sum() / (full_target.shape[0] // 2) elif loss_type == "sppo_hard": # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is # set to 1 for the winner and 0 for the loser. a = chosen_logps - ref_chosen_logps b = rejected_logps - ref_rejected_logps losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2 loss = losses.sum() / (full_target.shape[0] // 2) elif loss_type == "nca_pair": losses = ( -F.logsigmoid(chosen_rewards) - 0.5 * F.logsigmoid(-chosen_rewards) - 0.5 * F.logsigmoid(-rejected_rewards) ) loss = losses.sum() / (full_target.shape[0] // 2) else: raise ValueError( f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair" ) return loss, chosen_rewards, rejected_rewards @classmethod def forward( cls, ctx, _input, weight, target, bias=None, ref_input=None, ref_weight=None, ref_bias=None, ignore_index=-100, beta=0.1, compute_nll_loss=False, compiled=True, use_ref_model=True, average_log_prob=False, chunk_size=1, loss_type="sigmoid", ): """ Fused linear layer with DPO loss. Args: _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size) weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size) target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,) bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,) ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size) ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size) ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,) ignore_index (int): Index to ignore in loss computation beta (float): Weight for the odds ratio loss compute_nll_loss (bool): Whether to compute the NLL loss compiled (bool): Whether to use torch compile use_ref_model (bool): Whether to use a reference model average_log_prob (bool): Whether to average the log probability per non-masked token chunk_size (int): Size of chunks for processing. Returns: torch.Tensor: Computed loss """ return super().forward( cls=cls, ctx=ctx, _input=_input, weight=weight, target=target, bias=bias, ignore_index=ignore_index, beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, use_ref_model=use_ref_model, ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias, average_log_prob=average_log_prob, chunk_size=chunk_size, loss_type=loss_type, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] return *grads, None, None, None, None, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): """ Fused linear layer with DPO loss. """ def __init__( self, ignore_index: int = -100, beta: float = 0.1, compute_nll_loss: bool = False, compiled: bool = True, use_ref_model: bool = True, average_log_prob: bool = False, chunk_size: int = 1, loss_type: str = "sigmoid", ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. compute_nll_loss (bool): Whether to compute the NLL loss. compiled (bool): Whether to use the torch compiled kernel. use_ref_model (bool): Whether to use a reference model for the DPO loss. average_log_prob (bool): Whether to average the log probability per non-masked token. chunk_size (int): Size of chunks for processing. """ super().__init__() self.ignore_index = ignore_index self.beta = beta self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.use_ref_model = use_ref_model self.average_log_prob = average_log_prob self.chunk_size = chunk_size self.loss_type = loss_type supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"} if self.loss_type not in supported_loss_types: raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}") def forward( self, lin_weight, _input, target, bias=None, ref_input=None, ref_weight=None, ref_bias=None, ): return LigerFusedLinearDPOFunction.apply( _input, lin_weight, target, bias, ref_input, ref_weight, ref_bias, self.ignore_index, self.beta, self.compute_nll_loss, self.compiled, self.use_ref_model, self.average_log_prob, self.chunk_size, self.loss_type, )