core_algos.py 8.44 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import torch


class AdaptiveKLController:
    """
    Proportional KL-coefficient controller, following the adaptive scheme in:
    https://arxiv.org/pdf/1909.08593.pdf

    The coefficient is nudged up when the observed KL exceeds the target and
    down when it falls below, with the relative error clipped to +/-20%.
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        # Current KL coefficient; rescaled in place by `update`.
        self.value = init_kl_coef
        # KL value the controller steers toward.
        self.target = target_kl
        # Denominator of the step-size ratio n_steps / horizon used in `update`.
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        """Rescale `self.value` according to the gap between `current_kl` and the target.

        Args:
            current_kl: Observed KL divergence.
            n_steps: Number of steps the observation covers; the correction is
                scaled by `n_steps / self.horizon`.
        """
        relative_gap = current_kl / self.target - 1
        bounded_gap = np.clip(relative_gap, -0.2, 0.2)
        scale = 1 + bounded_gap * n_steps / self.horizon
        self.value *= scale


class FixedKLController:
    """KL controller whose coefficient stays constant for the whole run."""

    def __init__(self, kl_coef):
        # Exposed under the same attribute name as AdaptiveKLController.value
        # so both controllers can be used interchangeably.
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        """Ignore the observation; a fixed controller never changes its coefficient."""
        return None


def get_kl_controller(kl_ctrl):
    """Build a KL controller from a config object.

    Args:
        kl_ctrl: Config exposing `type` ("fixed" or "adaptive") and `kl_coef`;
            the "adaptive" type additionally reads `target_kl` and `horizon`.

    Returns:
        FixedKLController for "fixed", AdaptiveKLController for "adaptive".

    Raises:
        AssertionError: If `type` is "adaptive" and `horizon` is not > 0.
        NotImplementedError: For any other `type` value.
    """
    if kl_ctrl.type == "fixed":
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)
    if kl_ctrl.type == "adaptive":
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept so existing callers see the same exception type.
        if not kl_ctrl.horizon > 0:
            raise AssertionError(f"horizon must be larger than 0. Got {kl_ctrl.horizon}")
        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
    raise NotImplementedError(f"Unsupported kl_ctrl.type: {kl_ctrl.type!r}. Expected 'fixed' or 'adaptive'.")


def compute_onlinedpo_pref(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Computes preferences between pairs of sequences based on summed rewards
    and returns a mask aligned with the interleaved batch.

    Assumes inputs are interleaved: [Resp1_Prompt0, Resp2_Prompt0, Resp1_Prompt1, Resp2_Prompt1, ...]

    Args:
        token_level_rewards: Tensor of shape [batch_size * 2, seq_len]
        response_mask: Tensor of shape [batch_size * 2, seq_len]

    Returns:
        torch.Tensor: A boolean mask of shape [batch_size * 2], where True indicates
                      the corresponding entry is the chosen response for its pair.
                      Example: [True, False, False, True, ...] means for prompt 0,
                               response 1 was chosen; for prompt 1, response 2 was chosen.
                      On equal scores the first response of the pair is chosen.

    Raises:
        ValueError: If the batch dimension is odd or the two inputs differ in shape.
    """
    if token_level_rewards.shape[0] % 2 != 0 or response_mask.shape[0] % 2 != 0:
        raise ValueError(
            f"Input tensor batch dimension must be even for pair comparison, got shapes: "
            f"{token_level_rewards.shape}, {response_mask.shape}"
        )
    if token_level_rewards.shape != response_mask.shape:
        raise ValueError(f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}")

    # Sequence score = sum of rewards over the tokens kept by the mask.
    scores = (token_level_rewards * response_mask).sum(dim=-1)  # [batch_size * 2]

    # Adjacent entries form a pair; the even-size check above makes this reshape valid.
    score_pairs = scores.view(-1, 2)  # [batch_size, 2]

    # torch.argmax returns the first maximal index on ties, so equal scores prefer response 1.
    winner_indices = torch.argmax(score_pairs, dim=1)  # [batch_size], values in {0, 1}

    # One-hot the winner inside each pair, then flatten back to the interleaved layout.
    pair_slots = torch.arange(2, device=scores.device)
    output_preference_mask = (winner_indices.unsqueeze(1) == pair_slots).reshape(-1)  # [batch_size * 2]

    return output_preference_mask


def compute_online_dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float,
    label_smoothing: float = 0.0,
    loss_type: str = "sigmoid",
    reference_free: bool = False,
) -> torch.Tensor:
    """Compute the mean DPO-style preference loss over chosen/rejected pairs.

    All log-prob inputs are combined elementwise, so they only need mutually
    broadcastable shapes; the per-element losses are averaged at the end.

    Args:
        policy_chosen_logps: Policy log-probabilities of the chosen responses.
        policy_rejected_logps: Policy log-probabilities of the rejected responses.
        reference_chosen_logps: Reference-model log-probabilities of the chosen
            responses. Not read when `reference_free` is True (may be None).
        reference_rejected_logps: Reference-model log-probabilities of the rejected
            responses. Not read when `reference_free` is True (may be None).
        beta: Temperature scaling the log-ratio margin.
        label_smoothing: Weight given to the flipped-preference term (sigmoid loss only).
        loss_type: "sigmoid" (standard DPO), "ipo", or "hinge".
        reference_free: If True, the reference log-ratio is treated as zero.

    Returns:
        Scalar tensor: mean loss over all elements.

    Raises:
        ValueError: If `loss_type` is not one of the supported values.
    """
    import torch.nn.functional as F

    pi_logratios = policy_chosen_logps - policy_rejected_logps

    if reference_free:
        # Skip the reference model entirely; callers may pass None for its log-probs.
        ref_logratios = torch.zeros_like(pi_logratios)
    else:
        ref_logratios = reference_chosen_logps - reference_rejected_logps

    logits = pi_logratios - ref_logratios

    if loss_type == "sigmoid":
        losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing
    elif loss_type == "ipo":
        losses = (logits - 1 / (2 * beta)) ** 2
    elif loss_type == "hinge":
        # SLiC-style hinge on the scaled margin; previously advertised in the error message but missing.
        losses = torch.relu(1 - beta * logits)
    else:
        raise ValueError(f"Unsupported loss_type: {loss_type}. Choose 'sigmoid', 'ipo', or 'hinge'.")

    return losses.mean()


def get_batch_logps(
    logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False
) -> torch.FloatTensor:
    """
    Score each sequence by the log-probability its logits assign to its labels.

    Args:
        logits: Model logits, shape (batch_size, sequence_length, vocab_size).
        labels: Target token ids, shape (batch_size, sequence_length). Positions
                equal to -100 are excluded from the score.
        average_log_prob: If True, return the mean over scored tokens; otherwise the sum.

    Returns:
        Tensor of shape (batch_size,) with the summed or averaged log-probabilities.

    Raises:
        ValueError: If `logits.shape[:-1]` differs from `labels.shape`.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits and labels must have the same shape[:-1]")

    ignore_index = -100
    vocab_size = logits.size(-1)

    # The logit at position t predicts label t+1: drop the last logit and the first label.
    targets = labels.contiguous().to(logits.device)[..., 1:].contiguous()
    predictions = logits[..., :-1, :].contiguous()
    n_rows, n_steps = predictions.size(0), predictions.size(1)

    # Negated cross-entropy is the per-token log-probability of the target; ignored slots come back as 0.
    token_logps = -torch.nn.functional.cross_entropy(
        predictions.view(-1, vocab_size),
        targets.view(-1),
        ignore_index=ignore_index,
        reduction="none",
    )
    token_logps = token_logps.view(n_rows, n_steps)

    scored = targets != ignore_index
    totals = (token_logps * scored).sum(dim=-1)

    if not average_log_prob:
        return totals
    # Clamp the denominator so rows with no scored tokens divide by 1 rather than 0.
    return totals / torch.clamp(scored.sum(dim=-1), min=1)