# Copyright 2022 The HuggingFace Team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file are meant to be used by trainers with different distributed
strategies to implement PPO.
"""

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import TYPE_CHECKING, Tuple

import numpy as np
import torch
import torch.nn.functional as F

from ..utils import torch_functional as VF


if TYPE_CHECKING:
    from .config import AlgorithmConfig


class KLController(ABC):
    @abstractmethod
    def update(self, current_kl: float, n_steps: int) -> None: ...


class AdaptiveKLController(KLController):
    """Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf"""

    def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
        self.value = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        target = self.target
        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult


class FixedKLController(KLController):
    """Fixed KL controller."""

    def __init__(self, init_kl_coef: float):
        self.value = init_kl_coef

    def update(self, current_kl: float, n_steps: int) -> None:
        pass


def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
    """Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L319"""
    if algorithm_config.kl_type == "fixed":
        kl_ctrl = FixedKLController(init_kl_coef=algorithm_config.kl_coef)
    elif algorithm_config.kl_type == "adaptive":
        assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
        kl_ctrl = AdaptiveKLController(
            init_kl_coef=algorithm_config.kl_coef,
            target_kl=algorithm_config.kl_target,
            horizon=algorithm_config.kl_horizon,
        )
    else:
        raise ValueError(f"Unknown kl type: {algorithm_config.kl_type}.")

    return kl_ctrl
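
# Usage sketch (illustrative only; the trainer loop and the names `algorithm_config`,
# `observed_kl` and `batch_size` are placeholders, not part of this module):
#
#     kl_ctrl = get_kl_controller(algorithm_config)  # e.g. kl_type="adaptive"
#     for batch in dataloader:
#         ...  # rollout, then measure the mean actor/reference KL of the batch
#         kl_ctrl.update(current_kl=observed_kl, n_steps=batch_size)
#         kl_coef = kl_ctrl.value  # pass as `kl_ratio` to compute_rewards below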


@torch.no_grad()
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    eos_mask: torch.Tensor,
    gamma: float,
    lam: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Adapted from https://github.com/huggingface/trl/blob/v0.16.0/trl/trainer/ppo_trainer.py#L513

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        values: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length). [EOS] mask; tokens after [EOS] have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)

    """
    lastgaelam = 0
    advantages_reversed = []
    gen_len = token_level_rewards.shape[-1]
    for t in reversed(range(gen_len)):
        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
        delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)

    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = (advantages + values) * eos_mask
    advantages = VF.masked_whiten(advantages, eos_mask) * eos_mask
    return advantages, returns
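
# Usage sketch (illustrative; `gamma=1.0` and `lam=0.95` are example hyperparameters,
# not defaults of this module):
#
#     advantages, returns = compute_gae_advantage_return(
#         token_level_rewards, values, eos_mask, gamma=1.0, lam=0.95
#     )
#     # `returns` supervise the critic (compute_value_loss), while `advantages`
#     # weight the clipped surrogate objective (compute_policy_loss).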


# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
@torch.no_grad()
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for GRPO, operating only on outcome reward
    (with only one scalar reward for each response).
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(torch.Tensor)`
            shape: (bs,). Prompt index used to group responses sampled from the same prompt.
        eps: `(float)`
            small constant added to the group standard deviation for numerical stability

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean, id2std = {}, {}

    bsz = scores.shape[0]
    for i in range(bsz):
        id2score[index[i]].append(scores[i])

    for idx in id2score:
        if len(id2score[idx]) == 1:
            id2mean[idx] = torch.tensor(0.0)
            id2std[idx] = torch.tensor(1.0)
        elif len(id2score[idx]) > 1:
            id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
            id2std[idx] = torch.std(torch.tensor(id2score[idx]))
        else:
            raise ValueError(f"no score in prompt index: {idx}")

    for i in range(bsz):
        scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)

    scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
    return scores, scores
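
# Worked example (illustrative): three responses sampled from the same prompt form one group.
#
#     outcome rewards          : [1.0, 0.0, 0.0]
#     group mean / unbiased std: 0.333 / 0.577
#     normalized advantages    : [+1.155, -0.577, -0.577]  # broadcast over tokens, masked by eos_mask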


@torch.no_grad()
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(torch.Tensor)`
            shape: (bs,). Prompt index used to group responses sampled from the same prompt.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}
    bsz = scores.shape[0]
    for i in range(bsz):
        id2score[index[i]].append(scores[i])

    for idx in id2score:
        if len(id2score[idx]) == 1:
            id2mean[idx] = torch.tensor(0.0)
        elif len(id2score[idx]) > 1:
            id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
        else:
            raise ValueError(f"no score in prompt index: {idx}.")

    for i in range(bsz):
        response_num = len(id2score[index[i]])
        if response_num > 1:
            scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (
                response_num - 1
            )

    scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
    return scores, scores
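
# Worked example (illustrative): with rewards [1.0, 0.0, 0.0] for three responses to the same
# prompt, the rescaling above is equivalent to subtracting the leave-one-out baseline
# (the mean reward of the *other* responses in the group):
#
#     response 1: 1.0 - (0.0 + 0.0) / 2 = +1.0
#     response 2: 0.0 - (1.0 + 0.0) / 2 = -0.5
#     response 3: 0.0 - (1.0 + 0.0) / 2 = -0.5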


@torch.no_grad()
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for REINFORCE++.
    This implementation is based on the paper: https://arxiv.org/abs/2501.03262
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        gamma: `(float)`
            discount factor used in RL

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    returns = torch.zeros_like(token_level_rewards)
    running_return = 0
    for t in reversed(range(token_level_rewards.shape[1])):
        running_return = token_level_rewards[:, t] + gamma * running_return
        returns[:, t] = running_return
        # Reset after EOS
        running_return = running_return * eos_mask[:, t]

    advantages = VF.masked_whiten(returns, eos_mask)
    advantages *= eos_mask
    returns *= eos_mask
    return advantages, returns
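
# Worked example (illustrative): with gamma=0.9 and token-level rewards [0, 0, 1] on a fully
# unmasked response, the discounted returns-to-go are [0.81, 0.9, 1.0]; the advantages are
# the batch-whitened returns, masked by eos_mask.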


@torch.no_grad()
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for ReMax, operating only on outcome reward
    (with only one scalar reward for each response).
    This implementation is based on the paper: https://arxiv.org/abs/2310.10505
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    response_length = token_level_rewards.shape[-1]
    # scores = token_level_rewards.sum(dim=-1)
    returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) * eos_mask
    advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
    return advantages, returns
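
# Usage sketch (illustrative): `reward_baselines` holds the outcome reward of a baseline
# rollout for the same prompt (e.g. greedy decoding, as in the ReMax paper):
#
#     advantages, returns = compute_remax_outcome_advantage(
#         token_level_rewards, reward_baselines, eos_mask
#     )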


def compute_rewards(
    token_level_scores: torch.Tensor,
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    kl_ratio: float,
) -> torch.Tensor:
    """Penalize the token-level scores with the KL divergence between the actor and the reference policy."""
    kl = log_probs - ref_log_probs
    return token_level_scores - kl * kl_ratio
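
# Usage sketch (illustrative): fold the KL penalty into the token-level reward before
# computing advantages, reusing the coefficient maintained by the KL controller above:
#
#     token_level_rewards = compute_rewards(
#         token_level_scores, log_probs, ref_log_probs, kl_ratio=kl_ctrl.value
#     )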


def compute_policy_loss(
    old_log_probs: torch.Tensor,
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    eos_mask: torch.Tensor,
    cliprange: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the policy loss.

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568

    Args:
        old_log_probs: `(torch.Tensor)`
            shape: (bs, response_length)
        log_probs: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange: (float)
            The clip range used in PPO. See https://arxiv.org/abs/1707.06347

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac: `a scalar torch.Tensor`
            the fraction of tokens whose policy gradient loss was clipped
        ppo_kl: `a scalar torch.Tensor`
            masked mean of -(log_probs - old_log_probs), an estimate of the KL between the old and current policy
    """
    negative_approx_kl = log_probs - old_log_probs
    # for the clipped ratio, clamp the log-ratio before exp to avoid inf/nan
    # see: https://github.com/pytorch/pytorch/issues/10729
    ratio = torch.exp(negative_approx_kl)
    clipped_ratio = torch.exp(torch.clamp(negative_approx_kl, np.log(1.0 - cliprange), np.log(1.0 + cliprange)))
    ppo_kl = VF.masked_mean(-negative_approx_kl, eos_mask)

    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * clipped_ratio

    pg_loss = VF.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
    pg_clipfrac = VF.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
    return pg_loss, pg_clipfrac, ppo_kl
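
# Usage sketch (illustrative; `cliprange=0.2` is an example value). The masked mean of
# max(-A * ratio, -A * clipped_ratio) above is the standard PPO clipped surrogate loss:
#
#     pg_loss, pg_clipfrac, ppo_kl = compute_policy_loss(
#         old_log_probs, log_probs, advantages, eos_mask, cliprange=0.2
#     )
#     pg_loss.backward()  # typically combined with a value loss and the KL-penalized reward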


def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    eos_mask: torch.Tensor,
    cliprange_value: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the value loss.

    Copied from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556

    Args:
        vpreds (`torch.FloatTensor`):
            Predicted values of the value head, shape (`batch_size`, `response_length`)
        returns (`torch.FloatTensor`):
            Ground truth returns, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Old values of value head, shape (`batch_size`, `response_length`)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange_value: (float)
            The clip range for value net used in PPO. See https://arxiv.org/abs/1707.06347

    Returns:
        vf_loss: a scalar (`torch.FloatTensor`):
            value function loss
        vf_clipfrac: a scalar (`torch.FloatTensor`):
            the fraction of value predictions that were clipped
    """
    vpredclipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
    vf_losses1 = torch.square(vpreds - returns)
    vf_losses2 = torch.square(vpredclipped - returns)
    vf_loss = 0.5 * VF.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
    vf_clipfrac = VF.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
    return vf_loss, vf_clipfrac
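
# Usage sketch (illustrative; `cliprange_value=0.5` and `vf_coef` are example / placeholder
# values, not defaults of this module):
#
#     vf_loss, vf_clipfrac = compute_value_loss(
#         vpreds, returns, values, eos_mask, cliprange_value=0.5
#     )
#     total_loss = pg_loss + vf_coef * vf_loss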


def kl_penalty(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
    """Compute KL divergence given log_probs and ref_log_probs.
    Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150

    Args:
        log_probs: torch.Tensor
        ref_log_probs: torch.Tensor
        kl_penalty: str, one of "kl", "abs", "mse", "low_var_kl" or "full"

    Returns:
        kl_div: torch.Tensor
    """
    log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float()
    if kl_penalty == "kl":
        return log_probs - ref_log_probs

    if kl_penalty == "abs":
        return (log_probs - ref_log_probs).abs()

    if kl_penalty == "mse":
        return 0.5 * (log_probs - ref_log_probs).square()

    # J. Schulman. Approximating kl divergence, 2020.
    # URL http://joschu.net/blog/kl-approx.html
    if kl_penalty == "low_var_kl":
        kl = ref_log_probs - log_probs
        kld = (kl.exp() - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)

    raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")
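

# Worked example (illustrative): for a single token with log_probs = -1.0 and
# ref_log_probs = -1.5, i.e. log_probs - ref_log_probs = 0.5, the penalties are:
#
#     "kl"         -> 0.5
#     "abs"        -> 0.5
#     "mse"        -> 0.5 * 0.5 ** 2           = 0.125
#     "low_var_kl" -> exp(-0.5) - (-0.5) - 1  ~= 0.107  # k3 estimator, always non-negative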