# Copyright 2022 The HuggingFace Team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file should be used by trainers with different distributed strategies
to implement PPO.
"""

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import TYPE_CHECKING, Tuple

import numpy as np
import torch
import torch.nn.functional as F

from ..utils import torch_functional as VF


if TYPE_CHECKING:
    from .config import AlgorithmConfig


class KLController(ABC):
    kl_coef: float
    """KL coefficient."""

    @abstractmethod
    def update(self, current_kl: float, n_steps: int) -> None:
        """Update kl_coef according to current KL."""
        ...


class AdaptiveKLController(KLController):
    """Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf

    Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L54"""

    def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
        self.kl_coef = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        target = self.target
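        # Proportional controller from Ziegler et al. (2019): raise kl_coef when the observed KL
        # exceeds the target and lower it otherwise, with the error clipped to +/-20% so a single
        # noisy KL estimate cannot swing the coefficient too far in one update.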
        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.kl_coef *= mult


class FixedKLController(KLController):
    """Fixed KL controller.

    Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L72"""

    def __init__(self, init_kl_coef: float):
        self.kl_coef = init_kl_coef

    def update(self, current_kl: float, n_steps: int) -> None:
        pass


def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
    """Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L319"""
    if algorithm_config.kl_type == "fixed":
        kl_ctrl = FixedKLController(init_kl_coef=algorithm_config.kl_coef)
    elif algorithm_config.kl_type == "adaptive":
        assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
        kl_ctrl = AdaptiveKLController(
            init_kl_coef=algorithm_config.kl_coef,
            target_kl=algorithm_config.kl_target,
            horizon=algorithm_config.kl_horizon,
        )
    else:
        raise ValueError(f"Unknown kl type: {algorithm_config.kl_type}.")

    return kl_ctrl


@torch.no_grad()
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: float,
    lam: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Adapted from https://github.com/huggingface/trl/blob/v0.16.0/trl/trainer/ppo_trainer.py#L513

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        values: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length). Tokens after the EOS token have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)

    """
    lastgaelam = 0
    advantages_reversed = []
    gen_len = token_level_rewards.shape[-1]
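    # Backward GAE recursion: delta_t = r_t + gamma * V_{t+1} - V_t and
    # A_t = delta_t + gamma * lam * A_{t+1}, computed from the last response token to the first.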
    for t in reversed(range(gen_len)):
        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
        delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)

    advantages = torch.stack(advantages_reversed[::-1], dim=1)
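    # Returns serve as critic targets (advantage + value); advantages are whitened over valid
    # tokens to stabilize the scale of the policy gradient.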
    returns = advantages + values
    advantages = VF.masked_whiten(advantages, response_mask)
    return advantages, returns


# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
@torch.no_grad()
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for GRPO, operating only on Outcome reward
    (with only one scalar reward for each response).

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(torch.Tensor)`
            shape: (bs,). Group index identifying responses sampled from the same prompt.
        eps: `(float)`
            small constant added to the group std to avoid division by zero.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)

    """
    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean, id2std = {}, {}
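    # Group responses by prompt index and normalize each outcome reward by its group's mean and
    # std; the resulting scalar advantage is broadcast over all response tokens.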

    bsz = scores.shape[0]
    for i in range(bsz):
        id2score[index[i]].append(scores[i])

    for idx in id2score:
        assert len(id2score[idx]) > 1, "GRPO needs rollout.n > 1."
        id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
        id2std[idx] = torch.std(torch.tensor(id2score[idx]))

    for i in range(bsz):
        scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)

    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns


@torch.no_grad()
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(torch.Tensor)`
            shape: (bs,). Group index identifying responses sampled from the same prompt.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)

    """
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2sum = {}
    bsz = scores.shape[0]
    for i in range(bsz):
        id2score[index[i]].append(scores[i])

    for idx in id2score:
        id2sum[idx] = torch.sum(torch.tensor(id2score[idx]))

    for i in range(bsz):
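        # Leave-one-out baseline: the average reward of the other rollouts for the same prompt.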
        sample_num = len(id2score[index[i]])
        assert sample_num > 1, "RLOO needs rollout.n > 1."
        baseline = (id2sum[index[i]] - scores[i]) / (sample_num - 1)
        scores[i] = scores[i] - baseline

    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns


@torch.no_grad()
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for REINFORCE++.
    This implementation is based on the paper: https://arxiv.org/abs/2501.03262

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        gamma: `(float)`
            discount factor used in RL

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)

    """
    returns = torch.zeros_like(token_level_rewards)
    running_return = 0
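    # Accumulate discounted returns from right to left; multiplying by the mask resets the
    # running return outside the valid response so padding tokens receive no credit.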
    for t in reversed(range(token_level_rewards.shape[1])):
        running_return = token_level_rewards[:, t] + gamma * running_return
        returns[:, t] = running_return
        # Reset after EOS
        running_return = running_return * response_mask[:, t]

    advantages = VF.masked_whiten(returns, response_mask)
    return advantages, returns


@torch.no_grad()
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for ReMax, operating only on Outcome reward
    (with only one scalar reward for each response).
    This implementation is based on the paper: https://arxiv.org/abs/2310.10505

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)

    """
    scores = token_level_rewards.sum(dim=-1) - reward_baselines
    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns


def compute_rewards(
    token_level_scores: torch.Tensor,
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    kl_ratio: float,
) -> torch.Tensor:
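    # Token-level reward: raw score minus a per-token KL penalty toward the reference policy.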
    kl = log_probs - ref_log_probs
    return token_level_scores - kl * kl_ratio


def compute_policy_loss(
    old_log_probs: torch.Tensor,
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    clip_ratio_low: float,
    clip_ratio_high: float,
    clip_ratio_dual: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the policy loss.

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568

    Args:
        old_log_probs: `(torch.Tensor)`
            shape: (bs, response_length)
        log_probs: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        clip_ratio_low: (float)
            The lower clip range used in PPO. See https://arxiv.org/abs/1707.06347
        clip_ratio_high: (float)
            The higher clip range used in DAPO. See https://arxiv.org/pdf/2503.14476
        clip_ratio_dual: (float)
            The dual clip range used in Dual-clip PPO. See https://arxiv.org/pdf/1912.09729

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac_higher: (float)
            a float number indicating the fraction of policy gradient loss being clipped to a higher value
        pg_clipfrac_lower: (float)
            a float number indicating the fraction of policy gradient loss being clipped to a lower value
        ppo_kl: (float)
            a float number indicating the mean KL divergence between the old policy and the new policy

    """
    negative_approx_kl = log_probs - old_log_probs
    ratio = torch.exp(negative_approx_kl)
    # clamp the log-ratio before exp (clip in log space) to avoid inf/nan in the clipped ratio
    # see: https://github.com/pytorch/pytorch/issues/10729
    clipped_ratio = torch.exp(
        torch.clamp(negative_approx_kl, np.log(1.0 - clip_ratio_low), np.log(1.0 + clip_ratio_high))
    )
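    # Dual-clip PPO: on top of the standard PPO surrogate (ratio vs. clipped ratio), a third term
    # caps the loss at clip_ratio_dual * |A| for negative advantages, avoiding excessively large
    # losses when the probability ratio blows up.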

    pg_loss = -advantages * ratio
    pg_loss2 = -advantages * clipped_ratio
    pg_loss3 = -advantages * clip_ratio_dual

    clipped_pg_loss_higher = torch.max(pg_loss, pg_loss2)  # clip if pg_loss < pg_loss2
    pg_clipfrac_higher = (pg_loss < pg_loss2).float()
    clipped_pg_loss_lower = torch.min(clipped_pg_loss_higher, pg_loss3)  # clip if pg_loss > pg_loss3 and adv < 0
    final_pg_loss = torch.where(advantages < 0, clipped_pg_loss_lower, clipped_pg_loss_higher)
    pg_clipfrac_lower = (clipped_pg_loss_higher > pg_loss3).float() * (advantages < 0).float()

    final_pg_loss = VF.masked_mean(final_pg_loss, response_mask)
    pg_clipfrac_higher = VF.masked_mean(pg_clipfrac_higher, response_mask)
    pg_clipfrac_lower = VF.masked_mean(pg_clipfrac_lower, response_mask)
    ppo_kl = VF.masked_mean(-negative_approx_kl, response_mask)
    return final_pg_loss, pg_clipfrac_higher, pg_clipfrac_lower, ppo_kl


def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    action_mask: torch.Tensor,
    cliprange_value: float,
) -> Tuple[torch.Tensor, float]:
    """Compute the value loss.

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556

    Args:
        vpreds (`torch.FloatTensor`):
            Predicted values of the value head, shape (`batch_size`, `response_length`)
        returns: (`torch.FloatTensor`):
            Ground truth returns, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Old values of value head, shape (`batch_size`, `response_length`)
        action_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange_value: (float)
            The clip range for value net used in PPO. See https://arxiv.org/abs/1707.06347

    Returns:
        vf_loss: a scalar (`torch.FloatTensor`):
            value function loss
        vf_clipfrac: a float
            The ratio of vf being clipped

    """
    vpredclipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
    vf_loss1 = torch.square(vpreds - returns)
    vf_loss2 = torch.square(vpredclipped - returns)
    vf_loss = 0.5 * VF.masked_mean(torch.max(vf_loss1, vf_loss2), action_mask)  # clip if vf_loss1 < vf_loss2
    vf_clipfrac = VF.masked_mean((vf_loss1 < vf_loss2).float(), action_mask)
    return vf_loss, vf_clipfrac


def compute_kl(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
    """Compute KL divergence given log_probs and ref_log_probs.

    Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150

    Args:
        log_probs: torch.Tensor
        ref_log_probs: torch.Tensor
        kl_penalty: str, one of "kl", "abs", "mse", "low_var_kl" or "full"

    Returns:
        kl_div: torch.Tensor

    """
    log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float()
    if kl_penalty == "kl":
        return log_probs - ref_log_probs

    if kl_penalty == "abs":
        return (log_probs - ref_log_probs).abs()

    if kl_penalty == "mse":
        return 0.5 * (log_probs - ref_log_probs).square()

    # J. Schulman. Approximating kl divergence, 2020.
    # URL http://joschu.net/blog/kl-approx.html
    if kl_penalty == "low_var_kl":
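        # k3 estimator: exp(r) - r - 1 with r = ref_log_probs - log_probs; non-negative and
        # unbiased for KL(policy || ref), clamped below for numerical stability.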
        kl = ref_log_probs - log_probs
        kld = (kl.exp() - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)

    raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")