trainer.py
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Optional, Union

import torch
from transformers import Trainer
from typing_extensions import override

from ...extras import logging
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from transformers import PreTrainedModel, ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments


logger = logging.get_logger(__name__)


class PairwiseTrainer(Trainer):
    r"""Inherits Trainer to compute pairwise loss."""

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
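        # transformers>=4.46 deprecates the `tokenizer` argument in favor of `processing_class`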
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(**kwargs)
        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
        self.finetuning_args = finetuning_args
        self.can_return_loss = True  # override property to return eval_loss
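        # callback that fixes how the value-head model checkpoint is saved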
        self.add_callback(FixValueHeadModelCallback)

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

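            # BAdam updates only a subset of parameters per step, so swap in the
            # old-style gradient clipping implementation it ships with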
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler()

    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
        r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.

        Subclass and override to inject custom behavior.

        Note that the first element will be removed from the output tuple.
        See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
        """
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)
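        # `values` holds the token-level value head outputs; the batch packs
        # the n chosen examples first, followed by the n rejected ones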
        batch_size = inputs["input_ids"].size(0) // 2
        chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
        chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
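        # read the score at the last non-padding token of each sequence
        # (index = attention_mask.sum() - 1, which presumes right-padded inputs)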
        chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
        rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
        chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()

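        # Bradley-Terry pairwise loss: -log(sigmoid(chosen_score - rejected_score))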
        loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
        if return_outputs:
            return loss, (loss, chosen_scores, rejected_scores)
        else:
            return loss

    def save_predictions(self, predict_results: "PredictionOutput") -> None:
        r"""Save model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
        chosen_scores, rejected_scores = predict_results.predictions

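        # write one JSON object per line: {"chosen": <score>, "rejected": <score>}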
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: list[str] = []
            for c_score, r_score in zip(chosen_scores, rejected_scores):
                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))

            writer.write("\n".join(res))