# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import itertools
import math
from collections import defaultdict

import numpy as np
import torch

from recipe.spin.core_algos import compute_online_dpo_loss, get_batch_logps
from verl import DataProto
from verl.utils.device import get_device_name
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.workers.actor import DataParallelPPOActor

__all__ = ["DataParallelPPOActor", "SPINDataParallelPPOActor"]


class SPINDataParallelPPOActor(DataParallelPPOActor):
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute the log probability of the responses given input_ids, attention_mask and position_ids

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``responses``:  tensor of shape [batch_size, response_length]. torch.int64.

        Returns:
            torch.Tensor: the log_prob tensor
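
        Example:
            An illustrative sketch (how the ``DataProto`` is built is up to the caller;
            the tensor and ``meta_info`` keys below match what this method reads)::

                data = DataProto.from_dict(
                    tensors={
                        "input_ids": input_ids,            # [bsz, prompt_len + response_len]
                        "attention_mask": attention_mask,  # [bsz, prompt_len + response_len]
                        "position_ids": position_ids,      # [bsz, prompt_len + response_len]
                        "responses": responses,            # [bsz, response_len]
                    }
                )
                data.meta_info.update(
                    {"micro_batch_size": 4, "temperature": 1.0, "use_dynamic_bsz": False}
                )
                log_probs = actor.compute_log_prob(data)   # [bsz, response_len]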
        """
        # set to eval
        self.actor_module.eval()

        micro_batch_size = data.meta_info["micro_batch_size"]
        temperature = data.meta_info["temperature"]  # temperature must be set in data.meta_info to avoid a silent error
        use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]

        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
        batch = data.select(batch_keys=select_keys).batch
        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
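        # Three micro-batching strategies:
        #   1. multi-modal inputs: chunk through DataProto so the non-tensor image inputs
        #      stay aligned with their samples,
        #   2. dynamic batch size: pack samples up to a per-micro-batch token budget
        #      (max_token_len scaled by the Ulysses sequence-parallel size),
        #   3. otherwise: a plain fixed-size split of the batch TensorDict.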

        if has_multi_modal_inputs:
            num_micro_batches = data.batch.batch_size[0] // micro_batch_size
            non_tensor_select_keys = ["multi_modal_inputs"]
            micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
        elif use_dynamic_bsz:
            # split using dynamic bsz
            max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
        else:
            micro_batches = batch.split(micro_batch_size)

        log_probs_lst = []
        for micro_batch in micro_batches:
            if isinstance(micro_batch, DataProto):
                micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}

            with torch.no_grad():
                _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
            log_probs_lst.append(log_probs)
        log_probs = torch.concat(log_probs_lst, dim=0)

        if use_dynamic_bsz:
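            # rearrange_micro_batches reordered samples to balance token counts, so the
            # concatenated log_probs are in permuted order. get_reverse_idx builds the
            # inverse permutation (e.g. indices [2, 0, 1] -> reverse [1, 2, 0]) so that
            # indexing with it restores the original sample order.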
            indices = list(itertools.chain.from_iterable(indices))
            assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
            log_probs = log_probs[revert_indices]

        return log_probs

    def update_policy_dpo_with_ref(self, data: DataProto):
        """
        Performs the DPO update step using pre-calculated reference log probs
        from an external, periodically updated reference model.
        """
        self.actor_module.train()  # Ensure training mode

        # --- Retrieve necessary data ---
        try:
            # Expects batch prepared by fit_dpo loop, including reference log probs
            batch_td = data.batch
            chosen_labels = batch_td["chosen_labels"]
            rejected_labels = batch_td["rejected_labels"]
            # ... other needed tensors like chosen/rejected input_ids, attention_mask, position_ids ...

            # === Get PRE-CALCULATED reference log probs from input data ===
            reference_chosen_logps = batch_td["reference_chosen_logps"]  # Should be sequence-level logps
            reference_rejected_logps = batch_td["reference_rejected_logps"]  # Should be sequence-level logps
            # ============================================================

            # Get DPO hyperparameters: beta from the actor config, the rest from meta_info
            beta = self.config.get("dpo_beta", 0.1)
            loss_type = data.meta_info.get("dpo_loss_type", "sigmoid")
            label_smoothing = data.meta_info.get("dpo_label_smoothing", 0.0)
            # reference_free should stay False here, since reference log probs are provided
            reference_free = data.meta_info.get("reference_free", False)

        except KeyError as e:
            print(f"ERROR: Missing required key for DPO update (in update_policy_dpo): {e}")
            print(f"Available keys in data.batch: {list(batch_td.keys())}")  # Debug print
            return {}  # Return empty metrics on error
        except Exception as e_data:
            print(f"ERROR accessing data for DPO update (in update_policy_dpo): {e_data}")
            return {}

        # --- Micro-batching Setup ---
        micro_batch_size = self.config.get("ppo_micro_batch_size_per_gpu")
        if micro_batch_size is None:
            # Fall back to a micro batch size of 1 if not configured; alternatively this
            # could raise a ValueError to force the config to be set explicitly.
            micro_batch_size = 1
            print(f"Warning: 'ppo_micro_batch_size_per_gpu' not set, defaulting to {micro_batch_size}")

        # Ensure chosen_input_ids exists before getting shape
        if "chosen_input_ids" not in batch_td:
            print("ERROR: 'chosen_input_ids' not found in batch_td for DPO update.")
            return {}
        bsz = batch_td["chosen_input_ids"].shape[0]

        if bsz == 0:
            print("Warning: DPO batch size is 0 in update_policy_dpo. Skipping update.")
            return {"actor/dpo_loss": 0.0, "actor/grad_norm": 0.0}  # Return zero metrics if batch is empty

        num_micro_batches = math.ceil(bsz / micro_batch_size)
        gradient_accumulation_steps = num_micro_batches

        # --- Metrics Accumulation ---
        total_loss = 0.0
        accumulated_metrics = defaultdict(list)
        metrics = {}  # Final metrics dict

        # --- Zero Gradients ---
        self.actor_optimizer.zero_grad(set_to_none=True)

        # --- Micro-batch Loop ---
        for i in range(num_micro_batches):
            start_idx = i * micro_batch_size
            end_idx = min(start_idx + micro_batch_size, bsz)
            if start_idx >= end_idx:
                continue

            # Slice the full DPO batch into micro-batches
            # Important: Slice ALL required tensors, including labels and inputs
            micro_batch_chosen_labels = chosen_labels[start_idx:end_idx]
            micro_batch_rejected_labels = rejected_labels[start_idx:end_idx]
            micro_batch_chosen_inputs = {
                "input_ids": batch_td["chosen_input_ids"][start_idx:end_idx],
                "attention_mask": batch_td["chosen_attention_mask"][start_idx:end_idx],
            }
            if "chosen_position_ids" in batch_td:
                micro_batch_chosen_inputs["position_ids"] = batch_td["chosen_position_ids"][start_idx:end_idx]

            micro_batch_rejected_inputs = {
                "input_ids": batch_td["rejected_input_ids"][start_idx:end_idx],
                "attention_mask": batch_td["rejected_attention_mask"][start_idx:end_idx],
            }
            if "rejected_position_ids" in batch_td:
                micro_batch_rejected_inputs["position_ids"] = batch_td["rejected_position_ids"][start_idx:end_idx]

            # Determine autocast dtype
            autocast_dtype = torch.bfloat16  # Or get dynamically from config/FSDP settings
            # --- Autocast Forward Pass ---
            with torch.autocast(device_type=get_device_name(), dtype=autocast_dtype):
                # --- Step 1: Forward pass for CURRENT policy log probs (with grad) ---
                policy_chosen_outputs = self.actor_module(**micro_batch_chosen_inputs, use_cache=False)
                policy_rejected_outputs = self.actor_module(**micro_batch_rejected_inputs, use_cache=False)

                # --- Step 2: Calculate CURRENT policy log probs using get_batch_logps ---
                policy_chosen_logps = get_batch_logps(
                    policy_chosen_outputs.logits, micro_batch_chosen_labels, average_log_prob=False
                )
                policy_rejected_logps = get_batch_logps(
                    policy_rejected_outputs.logits, micro_batch_rejected_labels, average_log_prob=False
                )

                # --- Step 3: Retrieve PRE-CALCULATED reference log probs (NO grad needed) ---
                # Slice the full batch reference logps for the current micro-batch
                micro_ref_chosen_logps = reference_chosen_logps[start_idx:end_idx]
                micro_ref_rejected_logps = reference_rejected_logps[start_idx:end_idx]
                # (no reference-model forward pass happens here; unlike the actor-as-reference
                # variant, the reference logps above were computed ahead of time)

                # --- Step 4: Calculate DPO Logits and Loss ---
                pi_logratios = policy_chosen_logps - policy_rejected_logps
                ref_logratios = micro_ref_chosen_logps - micro_ref_rejected_logps  # Uses pre-calculated values
                logits = pi_logratios - ref_logratios  # DPO logits
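                # Standard DPO formulation (a sketch; the exact variant, including label
                # smoothing and alternative loss_types, is handled inside
                # compute_online_dpo_loss):
                #   sigmoid loss = -log sigmoid(beta * [(pi_c - pi_r) - (ref_c - ref_r)])
                # `logits` above is kept only for metrics; the loss helper derives the same
                # quantity from the four sequence-level log-prob tensors it receives.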

                loss = compute_online_dpo_loss(
                    policy_chosen_logps=policy_chosen_logps,  # Has grad
                    policy_rejected_logps=policy_rejected_logps,  # Has grad
                    reference_chosen_logps=micro_ref_chosen_logps,  # No grad (from input)
                    reference_rejected_logps=micro_ref_rejected_logps,  # No grad (from input)
                    beta=beta,
                    label_smoothing=label_smoothing,
                    loss_type=loss_type,
                    reference_free=reference_free,  # Should be False now
                )

                # --- Scale loss for gradient accumulation ---
                scaled_loss = loss / gradient_accumulation_steps
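                # Dividing by the number of micro-batches means the gradients accumulated
                # across this loop sum to (approximately) the gradient of the mean loss
                # over the full DPO batch (exact when every micro-batch has the same size).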

                # --- Accumulate Metrics ---
                total_loss += loss.item()  # Unscaled loss
                accumulated_metrics["actor/dpo_loss_batch"].append(loss.item())
                accumulated_metrics["actor/dpo_logits_batch"].append(logits.mean().item())
                # Accumulate policy and reference log probs/ratios if needed for debugging
                accumulated_metrics["actor/policy_chosen_logps_batch"].append(policy_chosen_logps.mean().item())
                accumulated_metrics["actor/policy_rejected_logps_batch"].append(policy_rejected_logps.mean().item())
                accumulated_metrics["actor/reference_chosen_logps_batch"].append(micro_ref_chosen_logps.mean().item())
                accumulated_metrics["actor/reference_rejected_logps_batch"].append(
                    micro_ref_rejected_logps.mean().item()
                )

            # --- Backward Pass (outside autocast) ---
            # Check if loss requires grad before backward
            if scaled_loss.requires_grad:
                scaled_loss.backward()
            else:
                print(f"Warning: Scaled loss at micro-batch {i} does not require grad. Skipping backward.")

        # --- End Micro-batch Loop ---

        # --- Optimizer Step (after accumulating gradients for all micro-batches) ---
        grad_norm = self._optimizer_step()

        # --- Populate Final Metrics ---
        if num_micro_batches > 0 and bsz > 0:  # Check if any processing happened
            metrics["actor/dpo_loss"] = total_loss / num_micro_batches
            metrics["actor/grad_norm"] = (
                grad_norm.item() if torch.is_tensor(grad_norm) and torch.isfinite(grad_norm) else float("inf")
            )
            # Average other accumulated metrics
            for key, val_list in accumulated_metrics.items():
                if val_list:
                    metrics[key.replace("_batch", "")] = np.mean(val_list)

            # Calculate accuracy / rewards / margins based on averaged logprobs if desired
            if (
                "actor/policy_chosen_logps" in metrics
                and "actor/policy_rejected_logps" in metrics
                and "actor/reference_chosen_logps" in metrics
                and "actor/reference_rejected_logps" in metrics
            ):
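                # DPO implicit reward: r(x, y) = beta * (log pi(y|x) - log pi_ref(y|x));
                # the chosen/rejected rewards and their margin below are batch-mean
                # estimates built from the averaged log probs.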
                policy_ratio_mean = metrics["actor/policy_chosen_logps"] - metrics["actor/policy_rejected_logps"]
                ref_ratio_mean = metrics["actor/reference_chosen_logps"] - metrics["actor/reference_rejected_logps"]
                logits_mean = policy_ratio_mean - ref_ratio_mean
                metrics["actor/rewards_chosen"] = beta * (
                    metrics["actor/policy_chosen_logps"] - metrics["actor/reference_chosen_logps"]
                )
                metrics["actor/rewards_rejected"] = beta * (
                    metrics["actor/policy_rejected_logps"] - metrics["actor/reference_rejected_logps"]
                )
                metrics["actor/rewards_accuracies"] = float(logits_mean > 0)  # Mean accuracy proxy
                metrics["actor/rewards_margins"] = metrics["actor/rewards_chosen"] - metrics["actor/rewards_rejected"]

        else:  # Handle case where no micro-batches were run (e.g., bsz=0)
            metrics["actor/dpo_loss"] = 0.0
            metrics["actor/grad_norm"] = 0.0
            # Initialize other metrics to 0 or NaN as appropriate
            for key in accumulated_metrics.keys():
                metrics[key.replace("_batch", "")] = 0.0
            metrics["actor/rewards_chosen"] = 0.0
            metrics["actor/rewards_rejected"] = 0.0
            metrics["actor/rewards_accuracies"] = 0.0
            metrics["actor/rewards_margins"] = 0.0

        return metrics  # Return aggregated metrics
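

# Illustrative usage sketch (comments only, not executed): how a SPIN-style trainer might
# drive this actor. Names such as `actor_config`, `module`, `optimizer`, `gen_batch`, and
# `dpo_batch` are hypothetical; the real driver is the SPIN recipe's fit/fit_dpo loop.
#
#   actor = SPINDataParallelPPOActor(config=actor_config, actor_module=module,
#                                    actor_optimizer=optimizer)
#
#   gen_batch.meta_info.update(
#       {"micro_batch_size": 8, "temperature": 1.0, "use_dynamic_bsz": False}
#   )
#   old_log_probs = actor.compute_log_prob(gen_batch)
#
#   # dpo_batch must already contain the chosen/rejected tensors plus the pre-computed
#   # reference_chosen_logps / reference_rejected_logps from the reference model.
#   metrics = actor.update_policy_dpo_with_ref(dpo_batch)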