# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
from transformers import Trainer
from typing_extensions import override

from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from transformers import PreTrainedModel, ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments


logger = logging.get_logger(__name__)


class PairwiseTrainer(Trainer):
    r"""
    Inherits Trainer to compute pairwise loss.
    """

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        self.can_return_loss = True  # override property to return eval_loss
        self.add_callback(FixValueHeadModelCallback)

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.pissa_convert:
            self.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

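            # badam ships its own gradient-clipping routine; bind it to this accelerator instance
            # so clipping stays compatible with BAdam's block-wise parameter updates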
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
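            # create_custom_optimizer may return None; super().create_optimizer() below then
            # falls back to the Trainer default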
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

        Subclass and override to inject custom behavior.

        Note that the first element will be removed from the output tuple.
        See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
        """
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)
        batch_size = inputs["input_ids"].size(0) // 2
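        # the batch packs the n chosen examples first and the n rejected examples last,
        # so splitting at the midpoint recovers the (chosen, rejected) pairs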
        chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
        chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
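        # use the value at each sequence's last non-padding token as its scalar score
        # (this indexing assumes right-padded inputs)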
        chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
        rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
        chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()

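        # pairwise ranking loss: -log(sigmoid(s_chosen - s_rejected)), averaged over the batch,
        # which pushes chosen scores above rejected scores (Bradley-Terry formulation)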
        loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()

        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0

        if return_outputs:
            return loss, (loss, chosen_scores, rejected_scores)
        else:
            return loss

    def save_predictions(self, predict_results: "PredictionOutput") -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that is not contained in Trainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
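        # predictions holds the (chosen_scores, rejected_scores) arrays returned by compute_loss;
        # the leading loss element of its output tuple is stripped by the Trainer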
        chosen_scores, rejected_scores = predict_results.predictions

        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for c_score, r_score in zip(chosen_scores, rejected_scores):
                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))

            writer.write("\n".join(res))
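

# Usage sketch (illustrative): the surrounding reward-modeling workflow is expected to construct
# the objects below; names such as `reward_model`, `training_args`, `pairwise_collator` and
# `train_dataset` are placeholders, not fixed APIs.
#
#     trainer = PairwiseTrainer(
#         model=reward_model,               # a value-head model, e.g. AutoModelForCausalLMWithValueHead
#         args=training_args,               # transformers TrainingArguments
#         finetuning_args=finetuning_args,
#         processor=None,
#         data_collator=pairwise_collator,  # must pack chosen examples before rejected ones
#         train_dataset=train_dataset,
#     )
#     trainer.train()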