# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement a data-parallel PRIME reward model (implicit process reward model)
"""
import itertools
from typing import Iterable

import torch
import torch.distributed
from torch import nn, optim

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm
from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.critic import BasePPOCritic
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
import verl.utils.torch_functional as verl_F

from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis

__all__ = ['DataParallelPRIMERewardModel']


class DataParallelPRIMERewardModel:
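    """Data-parallel PRIME reward model.

    Wraps a causal-LM reward module and an optional reference module (used
    without gradients). The implicit token-level reward is the log-probability
    difference between the two models on the response tokens; ``update_rm``
    trains the reward module against outcome correctness labels with
    CE/DPO-style losses.
    """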

    def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer):
        self.config = config
        self.reward_module = reward_module
        self.ref_module = ref_module
        self.reward_optimizer = reward_optimizer
        self.use_remove_padding = self.config.model.get('use_remove_padding', False)
        print(f'Reward model use_remove_padding={self.use_remove_padding}')

        self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)

    def _forward_micro_batch(self, micro_batch, prompt_length):
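        """Run the reward (and reference) model on one micro-batch.

        Returns ``(token_level_score, q)``: ``q`` is the per-token log-prob
        difference over the response (gradients flow only through the reward
        model), and ``token_level_score`` is the detached process reward laid
        out according to ``config.prime_granularity``.
        """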

        input_ids = micro_batch['input_ids']
        batch_size, seqlen = input_ids.shape
        attention_mask = micro_batch['attention_mask']
        position_ids = micro_batch['position_ids']

        num_actions = micro_batch['input_ids'].shape[-1] - prompt_length  # response length
        max_positions = micro_batch['attention_mask'][:, prompt_length:].sum(-1)  # non-padded response tokens per sample

        if self.use_remove_padding:
            input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
                                                       attention_mask)  # input_ids_rmpad (total_nnz, ...)
            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

            # unpad the position_ids to keep rotary embeddings aligned
            position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                                  indices).transpose(0, 1)

            # shift labels left by one token for computing per-token log-probs
            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)

            # pad and slice the inputs if sp > 1
            if self.ulysses_sequence_parallel_size > 1:
                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                    input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)
                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None,
                                                                            self.ulysses_sequence_parallel_size)
            input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)
            rm_output_logits = self.reward_module(input_ids=input_ids_rmpad,
                                                  attention_mask=None,
                                                  position_ids=position_ids_rmpad,
                                                  use_cache=False).logits.squeeze(
                                                      0)  # logits carry a leading batch dim of 1 for the (1, total_nnz) input; drop it
            rm_log_labels = verl_F.logprobs_from_logits(logits=rm_output_logits, labels=input_ids_rmpad_rolled)
            if self.ulysses_sequence_parallel_size > 1:
                rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size)
            rm_log_labels = pad_input(hidden_states=rm_log_labels.unsqueeze(-1),
                                      indices=indices,
                                      batch=batch_size,
                                      seqlen=seqlen).squeeze(-1)[:, -num_actions - 1:-1]

        else:
            rm_output_logits = self.reward_module(input_ids=micro_batch['input_ids'],
                                                  attention_mask=micro_batch['attention_mask'],
                                                  position_ids=micro_batch['position_ids'],
                                                  use_cache=False).logits
            rm_log_prob = torch.nn.functional.log_softmax(rm_output_logits[:, :-1, :],
                                                          dim=-1)  # (batch_size, seq_length, vocab_size)
            rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze(
                -1)  # (batch, seq_length)

        if self.ref_module is not None:
            # do not have to pad again
            with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding:
                    ref_output_logits = self.ref_module(input_ids=input_ids_rmpad,
                                                        attention_mask=None,
                                                        position_ids=position_ids_rmpad,
                                                        use_cache=False).logits.squeeze(0)
                    ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits,
                                                                 labels=input_ids_rmpad_rolled)
                    ref_log_labels = gather_outpus_and_unpad(ref_log_labels,
                                                             gather_dim=0,
                                                             unpad_dim=0,
                                                             padding_size=pad_size)
                    ref_log_labels = pad_input(hidden_states=ref_log_labels.unsqueeze(-1),
                                               indices=indices,
                                               batch=batch_size,
                                               seqlen=seqlen).squeeze(-1)[:, -num_actions - 1:-1]
                else:
                    ref_output_logits = self.ref_module(input_ids=micro_batch['input_ids'],
                                                        attention_mask=micro_batch['attention_mask'],
                                                        position_ids=micro_batch['position_ids'],
                                                        use_cache=False).logits
                    ref_log_prob = torch.nn.functional.log_softmax(ref_output_logits[:, :-1, :],
                                                                   dim=-1)  # (batch_size, seq_length, vocab_size)
                    ref_log_labels = ref_log_prob.gather(dim=-1,
                                                         index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze(
                                                             -1)  # (batch, seq_length)
        else:
            ref_log_labels = micro_batch['old_log_probs']

        ref_log_labels = ref_log_labels.to(rm_log_labels.dtype)  # .to() is not in-place; keep the cast result
        q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:]  # implicit reward: RM vs. reference log-prob difference on the response

        # trim unnecessary logprobs here
        for i in range(micro_batch['input_ids'].shape[0]):
            q[i, max_positions[i]:] = 0

        # reward computation does not need gradients; only q does
        with torch.no_grad():

            # the generalized estimate of r must be computed before the reward filling below.
            # r is the process reward given to the policy model (equivalently, the reward model's advantage).
            lam = self.config.get('lambda', 0.)
            beta = self.config.model.get('beta_train', 0.05)
            if lam == 0.:
                r = q * beta
            else:
                # the reward coefficient has no effect here
                acc = micro_batch['acc']
                q_ = q * beta
                r = torch.zeros_like(q)
                lastgaelam = 0
                # if we rely on the outcome reward to estimate V, overwrite the last valid token and zero out padding to simplify the scan below
                for i in range(q.shape[0]):
                    if self.config.prime_use_gt:
                        q_[i, max_positions[i] - 1] = acc[i] - q_[i, :max_positions[i] - 1].sum()
                    q_[i, max_positions[i]:] = 0

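                # reverse lambda-weighted scan over response tokens:
                # r_t = q_[t] + lam * r_{t+1}  (GAE-style, with gamma = 1 and no value baseline)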
                for t in reversed(range(num_actions)):
                    delta = q_[:, t]
                    lastgaelam = delta + lam * lastgaelam
                    r[:, t] = lastgaelam

            token_level_score = torch.zeros_like(q)

            if self.config.prime_granularity == 'token':
                for i in range(micro_batch['input_ids'].shape[0]):
                    token_level_score[i, :max_positions[i] - 1] = r[i, :max_positions[i] - 1]
            elif self.config.prime_granularity == 'whole':
                for i in range(micro_batch['input_ids'].shape[0]):
                    token_level_score[i, max_positions[i] - 1] = r[i, :max_positions[i]].sum()  # put the aggregate score on the last response token
            else:
                raise NotImplementedError

        return token_level_score, q

    def _optimizer_step(self):
        assert self.config.model.optim.grad_clip is not None

        if isinstance(self.reward_module, FSDP):
            grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(),
                                                       max_norm=self.config.model.optim.grad_clip)
        self.reward_optimizer.step()
        return grad_norm

    def prime_norm(self, token_level_scores):
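        """Optionally rescale token-level scores.

        With 'batch_norm', scores are divided by the largest absolute reverse
        cumulative sum (return-to-go) in the batch, so that quantity is at
        most ~1 after normalization.
        """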
        if self.config.prime_norm == 'batch_norm':
            reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1])
            token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)
        return token_level_scores

    def compute_rm_score(self, data: DataProto):
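        """Score a batch with the current reward model (no parameter update).

        Returns ``(rm_scores, q, metrics)``: normalized token-level scores,
        detached implicit rewards, and summary statistics.
        """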
        self.reward_module.eval()
        if self.ref_module is not None:
            self.ref_module.eval()
        micro_batch_size = data.meta_info['micro_batch_size']
        select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'acc']
        batch = data.select(batch_keys=select_keys).batch
        use_dynamic_bsz = data.meta_info['use_dynamic_bsz']
        prompt_length = data.batch['input_ids'].shape[-1] - data.batch['responses'].shape[-1]

        if use_dynamic_bsz:
            # split using dynamic bsz
            max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
        else:
            micro_batches = batch.split(micro_batch_size)

        rm_scores_lst = []
        q_lst = []
        for micro_batch in micro_batches:
            with torch.no_grad():
                rm_score, q = self._forward_micro_batch(micro_batch, prompt_length)
            rm_scores_lst.append(rm_score)
            q_lst.append(q)
        rm_scores = torch.concat(rm_scores_lst, dim=0)
        q = torch.concat(q_lst, dim=0)

        rm_scores = self.prime_norm(rm_scores)

        if use_dynamic_bsz:
            indices = list(itertools.chain.from_iterable(indices))
            assert len(indices) == rm_scores.size(0), f"{len(indices)} vs. {rm_scores.size()}"
            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
            rm_scores = rm_scores[revert_indices]

        return rm_scores, q.detach(), {
            'reward_model/reward': rm_scores.sum(dim=-1).mean().item(),
            'reward_model/raw_reward': q.sum(dim=-1).mean().item()
        }

    def update_rm(self, data: DataProto):
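        """Train the reward model, taking one optimizer step per mini-batch.

        The loss is selected by ``config.model.loss_type`` ('ce', 'dpo',
        'bon_acc' or 'bon_rm') and compares the implicit rewards with the
        outcome labels. Returns ``(rm_scores, metrics)``.
        """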
        # make sure we are in training mode
        self.reward_module.train()
        metrics = {}

        beta = self.config.model.get('beta_train', 0.05)

        select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'acc', 'prompts']

        for key in ['Q_bc', 'acc_bc']:
            if key in data.batch.keys():
                select_keys.append(key)

        batch = data.select(batch_keys=select_keys).batch
        # Split to make a mini-batch iterator for updating the reward model
        # See PPO paper for details. https://arxiv.org/abs/1707.06347
        dataloader = batch.split(self.config.mini_batch_size)

        rm_scores_lst = []
        q_lst = []

        for batch_idx, data in enumerate(dataloader):
            # split batch into micro_batches
            mini_batch = data
            if self.config.use_dynamic_bsz:
                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
            else:
                micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu)
                self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu

            self.reward_optimizer.zero_grad()

            for data in micro_batches:
                data = data.cuda()
                attention_mask = data['attention_mask']
                acc = data['acc']

                prompt_ids = data['prompts']
                prompt_length = prompt_ids.shape[-1]

                response_mask = attention_mask[:, prompt_length:]

                rm_score, q = self._forward_micro_batch(data, prompt_length)

                rm_scores_lst.append(rm_score)
                q_lst.append(q.detach())

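                # every loss variant compares the implicit reward q (scaled by beta) with the
                # binary outcome label `acc`; the dpo/bon variants additionally use batch
                # statistics (Q_bc, acc_bc) collected before this update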
                if self.config.model.loss_type == 'ce':
                    dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta)
                elif self.config.model.loss_type == 'dpo':
                    # the dpo implementation is detached: it needs the average win/lose rewards computed before the update
                    dpo_loss = compute_detach_dpo_loss_rm(q,
                                                          acc,
                                                          Q_bc=data['Q_bc'],
                                                          acc_bc=data['acc_bc'],
                                                          response_mask=response_mask,
                                                          beta=beta)
                elif self.config.model.loss_type == 'bon_acc':
                    # reweight each sample toward the BoN (best-of-n) distribution, then update the reward model
                    dpo_loss = compute_detach_dpo_loss_rm(q,
                                                          acc,
                                                          Q_bc=data['Q_bc'],
                                                          acc_bc=data['acc_bc'],
                                                          response_mask=response_mask,
                                                          beta=beta,
                                                          bon_mode='bon_acc')
                elif self.config.model.loss_type == 'bon_rm':
                    dpo_loss = compute_detach_dpo_loss_rm(q,
                                                          acc,
                                                          Q_bc=data['Q_bc'],
                                                          acc_bc=data['acc_bc'],
                                                          response_mask=response_mask,
                                                          beta=beta,
                                                          bon_mode='bon_rm')
                else:
                    raise NotImplementedError

                if self.config.use_dynamic_bsz:
                    # relative to the dynamic bsz: scale by the size of this micro-batch
                    loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size)
                else:
                    loss = dpo_loss / self.gradient_accumulation

                loss.backward()

                data = {'reward_model/dpo_loss': dpo_loss.detach().item()}
                append_to_dict(metrics, data)

            grad_norm = self._optimizer_step()
            data = {'reward_model/grad_norm': grad_norm.detach().item()}
            append_to_dict(metrics, data)
        self.reward_optimizer.zero_grad()

        rm_scores = torch.cat(rm_scores_lst, dim=0)
        q = torch.concat(q_lst, dim=0)

        rm_scores = self.prime_norm(rm_scores)

        metrics.update({
            'reward_model/reward': rm_scores.sum(dim=-1).mean().item(),
            'reward_model/raw_reward': q.sum(dim=-1).mean().item()
        })

        return rm_scores, metrics