# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Union

import torch
import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override

from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach


if TYPE_CHECKING:
    from transformers import PreTrainedModel, ProcessorMixin

    from ...hparams import FinetuningArguments


class CustomDPOTrainer(DPOTrainer):
    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        disable_dropout: bool = True,
        **kwargs,
    ):
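        # transformers>=4.46 renamed the Trainer's `tokenizer` argument to `processing_class`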
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.f_divergence_type = "reverse_kl"
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.precompute_ref_log_probs = False
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False
        self._peft_has_been_casted_to_bf16 = False

        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # dpo hyperparams
        self.beta = finetuning_args.pref_beta
        self.loss_type = finetuning_args.pref_loss
        self.ftx_gamma = finetuning_args.pref_ftx
        self.bco_gamma = finetuning_args.pref_bco_weight
        self.label_smoothing = finetuning_args.dpo_label_smoothing
        self.simpo_gamma = finetuning_args.simpo_gamma
        self.ld_alpha = finetuning_args.ld_alpha

        Trainer.__init__(self, model=model, **kwargs)
        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        warnings.simplefilter("ignore")  # remove gc warnings on ref model

        if ref_model is not None:
            if self.is_deepspeed_enabled:
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                self.ref_model.eval()

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if self.bco_gamma >= 1e-6:
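            # RunningMoments keeps a running mean of the rewards, which serves as
            # the reward baseline `delta` in `bco_loss` below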
            from trl.trainer import RunningMoments

            self.running = RunningMoments(self.accelerator)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def get_batch_samples(self, *args, **kwargs):
        r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
        return Trainer.get_batch_samples(self, *args, **kwargs)

    def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
        )
        sft_loss = -chosen_logps
        odds_ratio_loss = -F.logsigmoid(log_odds)
        orpo_loss = sft_loss + self.beta * odds_ratio_loss
        return orpo_loss

    def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""Compute SimPO loss for batched log probabilities of the policy model."""
        pi_logratios = chosen_logps - rejected_logps
        gamma_logratios = self.simpo_gamma / self.beta
        logits = pi_logratios - gamma_logratios
        simpo_loss = -F.logsigmoid(self.beta * logits)
        return simpo_loss

    def bco_loss(
        self,
        chosen_logps: "torch.Tensor",
        rejected_logps: "torch.Tensor",
        reference_chosen_logps: "torch.Tensor",
        reference_rejected_logps: "torch.Tensor",
    ) -> "torch.Tensor":
        chosen_logratios = chosen_logps - reference_chosen_logps
        rejected_logratios = rejected_logps - reference_rejected_logps
        chosen_rewards = self.beta * chosen_logratios
        rejected_rewards = self.beta * rejected_logratios
        rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
        self.running.update(rewards)  # update baseline
        delta = self.running.mean
        bco_loss = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
            -(self.beta * rejected_logratios - delta)
        )
        return bco_loss

    def compute_preference_loss(
        self,
        policy_chosen_logps: "torch.Tensor",
        policy_rejected_logps: "torch.Tensor",
        reference_chosen_logps: Optional["torch.Tensor"],
        reference_rejected_logps: Optional["torch.Tensor"],
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""Compute loss for preference learning."""
        if not self.finetuning_args.use_ref_model:
            if self.loss_type == "orpo":
                losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
            elif self.loss_type == "simpo":
                losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
            else:
                raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")

            chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
            rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
        else:
            losses, chosen_rewards, rejected_rewards = self.dpo_loss(
                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
            )

            if self.bco_gamma > 1e-6:
                bco_losses = self.bco_loss(
                    policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
                )
                losses = (losses + bco_losses * self.bco_gamma) / (1.0 + self.bco_gamma)  # weighted average of the DPO and BCO losses

        return losses, chosen_rewards, rejected_rewards

    @override
    def concatenated_forward(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
    ) -> dict[str, "torch.Tensor"]:
        r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.

        Otherwise the average log probabilities.
        """
        if self.finetuning_args.use_ref_model:
            batch = nested_detach(batch, clone=True)  # detached copy so popping "labels" below does not mutate the reused batch

        labels = batch.pop("labels")  # DPO does not need the model to compute the loss in its forward pass
        all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
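        # length-desensitization (ld_alpha) applies to the policy model only;
        # the reference model uses the unscaled log probabilities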
        all_logps, valid_length = get_batch_logps(
            logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
        )
        if self.loss_type in ["ipo", "orpo", "simpo"]:
            all_logps = all_logps / valid_length

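        # the collator concatenates the preference pair, so the first half of the
        # batch holds the chosen examples and the second half the rejected ones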
        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        chosen_length, _ = valid_length.split(batch_size, dim=0)
        if self.loss_type in ["ipo", "orpo", "simpo"]:
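            # all_logps was already divided by valid_length above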
            chosen_logps_avg = chosen_logps
        else:
            chosen_logps_avg = chosen_logps / chosen_length

        return {
            "chosen_logps": chosen_logps,
            "rejected_logps": rejected_logps,
            "chosen_logits": chosen_logits,
            "rejected_logits": rejected_logits,
            "chosen_logps_avg": chosen_logps_avg,
        }

    @override
    def compute_reference_log_probs(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
    ) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""Compute log probabilities of the reference model."""
        if not self.finetuning_args.use_ref_model:
            return None, None

        if self.ref_model is None:
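            # no standalone reference model (e.g. LoRA training): reuse the policy
            # model with its adapters disabled as the reference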
            ref_model = model
            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
        else:
            ref_model = self.ref_model
            ref_context = nullcontext()

        with torch.no_grad(), ref_context:
            ref_output = self.concatenated_forward(ref_model, batch, is_ref_model=True)
            reference_chosen_logps = ref_output["chosen_logps"]
            reference_rejected_logps = ref_output["rejected_logps"]

        return reference_chosen_logps, reference_rejected_logps

    @override
    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel",
        batch: dict[str, "torch.Tensor"],
        train_eval: Literal["train", "eval"] = "train",
    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}

        model_output = self.concatenated_forward(model, batch)
        policy_chosen_logps = model_output["chosen_logps"]
        policy_rejected_logps = model_output["rejected_logps"]
        policy_chosen_logits = model_output["chosen_logits"]
        policy_rejected_logits = model_output["rejected_logits"]
        policy_chosen_logps_avg = model_output["chosen_logps_avg"]

        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
        losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )
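        # optionally mix an SFT term on the chosen responses into the loss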
        sft_loss = -policy_chosen_logps_avg
        if self.ftx_gamma > 1e-6:
            losses += self.ftx_gamma * sft_loss

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
        metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item()
        if self.loss_type == "orpo":
            metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
            metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()

        return losses.mean(), metrics

    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
        r"""Subclass and override to accept extra kwargs."""
        return super().compute_loss(model, inputs, return_outputs)

    @override
    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
        r"""Log `logs` on the various objects watching training, including stored metrics."""
        # `logs` has either "loss" (training) or "eval_loss" (evaluation)
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        key_list, metric_list = [], []
        for key, metrics in self._stored_metrics[train_eval].items():
            key_list.append(key)
            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).mean().item())

        del self._stored_metrics[train_eval]
        if len(metric_list) < 10:  # pad to a fixed length so every rank all-reduces a tensor of the same size
            for i in range(10 - len(metric_list)):
                key_list.append(f"dummy_{i}")
                metric_list.append(0.0)

        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
        metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
        for key, metric in zip(key_list, metric_list):  # add the averaged metrics, skipping dummy padding
            if not key.startswith("dummy_"):
                logs[key] = metric

        return Trainer.log(self, logs, *args, **kwargs)