# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch


class AdaptiveKLController:
    """
    Adaptive KL controller described in the paper:
    https://arxiv.org/pdf/1909.08593.pdf
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.value = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        target = self.target
        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult


class FixedKLController:
    """Fixed KL controller."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        pass


def get_kl_controller(kl_ctrl):
    if kl_ctrl.type == "fixed":
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)
    elif kl_ctrl.type == "adaptive":
        assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
    else:
        raise NotImplementedError


def compute_onlinedpo_pref(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Computes preferences between pairs of sequences based on summed rewards and returns
    a mask aligned with the interleaved batch.

    Assumes inputs are interleaved: [Resp1_Prompt0, Resp2_Prompt0, Resp1_Prompt1, Resp2_Prompt1, ...]

    Args:
        token_level_rewards: Tensor of shape [batch_size * 2, seq_len]
        response_mask: Tensor of shape [batch_size * 2, seq_len]

    Returns:
        torch.Tensor: A boolean mask of shape [batch_size * 2], where True indicates the
            corresponding entry is the chosen response for its pair.
            Example: [True, False, False, True, ...] means for prompt 0, response 1 was chosen;
            for prompt 1, response 2 was chosen.
    """
    # print(f"---- [DEBUG] Inside compute_onlinedpo_pref ----")
    if token_level_rewards.shape[0] % 2 != 0 or response_mask.shape[0] % 2 != 0:
        raise ValueError(
            f"Input tensor batch dimension must be even for pair comparison, got shapes: "
            f"{token_level_rewards.shape}, {response_mask.shape}"
        )
    if token_level_rewards.shape != response_mask.shape:
        raise ValueError(f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}")

    # 1. Calculate sequence scores
    scores = (token_level_rewards * response_mask).sum(dim=-1)
    # print(f"  Calculated sequence scores shape: {scores.shape}")  # [batch_size * 2]

    # 2. Reshape scores to group pairs: [batch_size, 2]
    try:
        score_pairs = scores.view(-1, 2)
    except RuntimeError as e:
        print(f"ERROR reshaping scores (shape {scores.shape}) into pairs: {e}")
        raise
    # print(f"  Reshaped score pairs shape: {score_pairs.shape}")  # [batch_size, 2]

    # 3. Compare scores to find which index (0 or 1) is the winner within each pair
    # winner_indices[i] = 0 if score_pairs[i, 0] >= score_pairs[i, 1] else 1
    winner_indices = torch.argmax(score_pairs, dim=1)  # 0 if first is max, 1 if second is max
    # Ties go to index 0 (response 1), since torch.argmax returns the first maximal element.
    # Alternatively: winner_mask_original = score_pairs[:, 0] >= score_pairs[:, 1]
    # print(f"  Winner indices shape: {winner_indices.shape}")  # [batch_size]
    # print(f"  Number where Response 2 (index 1) is preferred: {winner_indices.sum().item()}")  # Counts number of 1s

    # 4. Create the final [batch_size * 2] mask
    num_pairs = score_pairs.shape[0]
    full_batch_size = num_pairs * 2

    # Create indices for the full batch [0, 1, 2, 3, ..., N*2-1]
    # full_indices = torch.arange(full_batch_size, device=scores.device)

    # Map each pair's winner back to its original (interleaved) index.
    # E.g., if winner_indices is [0, 1, 0], pair_indices is [0, 1, 2]
    # winner_global_indices = (pair_indices * 2) + winner_indices -> [(0*2)+0, (1*2)+1, (2*2)+0] -> [0, 3, 4]
    pair_indices = torch.arange(num_pairs, device=scores.device)
    winner_global_indices = (pair_indices * 2) + winner_indices

    # Create boolean mask - True at the winner's position
    output_preference_mask = torch.zeros(full_batch_size, dtype=torch.bool, device=scores.device)
    output_preference_mask[winner_global_indices] = True

    # print(f"  Output preference mask shape: {output_preference_mask.shape}")  # Should be [batch_size * 2]
    # print(f"  Output mask True count (Chosen): {output_preference_mask.sum().item()}")  # Should be batch_size
    # print(f"  Output mask False count (Rejected): {(~output_preference_mask).sum().item()}")  # Should be batch_size
    # print(f"---- [DEBUG] Exiting compute_onlinedpo_pref ----")
    return output_preference_mask


def compute_online_dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float,
    label_smoothing: float = 0.0,
    loss_type: str = "sigmoid",
    reference_free: bool = False,
) -> torch.Tensor:
    import torch.nn.functional as F

    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps

    if reference_free:
        ref_logratios = torch.zeros_like(pi_logratios)

    logits = pi_logratios - ref_logratios

    if loss_type == "sigmoid":
        losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing
    elif loss_type == "ipo":
        losses = (logits - 1 / (2 * beta)) ** 2
    else:
        raise ValueError(f"Unsupported loss_type: {loss_type}. Choose 'sigmoid' or 'ipo'.")

    return losses.mean()


def get_batch_logps(
    logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False
) -> torch.FloatTensor:
    """
    Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (e.g., huggingface CausalLMOutputs `logits`).
            Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for computing the sequence log probabilities.
            Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the average log probability per sequence.
            Otherwise, return the sum.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given sequences.
""" if logits.shape[:-1] != labels.shape: raise ValueError("Logits and labels must have the same shape[:-1]") # Ensure labels are contiguous and on the same device as logits labels = labels.contiguous().to(logits.device) # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Calculate per token log probability loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none") per_token_logps = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) per_token_logps = per_token_logps.view( shift_logits.size(0), shift_logits.size(1) ) # Reshape back to (batch_size, seq_len-1) # Create a mask for the labels that are not -100 loss_mask = shift_labels != -100 # Apply the mask to the per token log probabilities masked_logps = per_token_logps * loss_mask # Calculate the sum or average log probability per sequence sequence_logps = masked_logps.sum(dim=-1) if average_log_prob: # Avoid division by zero for sequences with no valid tokens num_valid_tokens = loss_mask.sum(dim=-1) return sequence_logps / torch.clamp(num_valid_tokens, min=1) else: return sequence_logps