# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
from transformers import Trainer
from typing_extensions import override

from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from transformers import PreTrainedModel, ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments


logger = logging.get_logger(__name__)


class PairwiseTrainer(Trainer):
    r"""
    Inherits Trainer to compute pairwise loss.
    """

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
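        # transformers>=4.46 renamed the Trainer's `tokenizer` argument to `processing_class`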
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(**kwargs)
        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
        self.finetuning_args = finetuning_args
        self.can_return_loss = True  # override property to return eval_loss
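        # patch checkpoint saving for models that carry a value head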
        self.add_callback(FixValueHeadModelCallback)

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

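        # BAdam provides its own gradient clipping implementation, patched onto the accelerator here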
        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
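            # create_custom_optimizer returns None when no custom optimizer (e.g. GaLore or BAdam)
            # is configured, in which case super().create_optimizer() builds the default one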
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
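        # keep the original dataset order when shuffling is disabled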
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler()

    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

        Subclass and override to inject custom behavior.

        Note that the first element will be removed from the output tuple.
        See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
        """
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)
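        # the value head emits one scalar per token, so values has shape (2 * batch_size, seq_len)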
        batch_size = inputs["input_ids"].size(0) // 2
        chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
        chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
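        # score each sequence by the value at its last non-padded token;
        # attention_mask.sum(-1) - 1 is that token's index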
        chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
        rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
        chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()

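        # pairwise ranking (Bradley-Terry) loss: -log(sigmoid(chosen_score - rejected_score))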
        loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()

        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0-4.46.1

        if return_outputs:
            return loss, (loss, chosen_scores, rejected_scores)
        else:
            return loss

    def save_predictions(self, predict_results: "PredictionOutput") -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
        chosen_scores, rejected_scores = predict_results.predictions

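        # write one JSON object per line, e.g. {"chosen": 1.23, "rejected": -0.45}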
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for c_score, r_score in zip(chosen_scores, rejected_scores):
                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))

            writer.write("\n".join(res))
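

# A minimal usage sketch (illustrative only): `reward_model`, `training_args`,
# `finetuning_args`, `tokenizer`, `pairwise_dataset`, and `pairwise_collator`
# are assumed to exist, not defined in this module. The model is expected to
# carry a value head (e.g. trl's AutoModelForCausalLMWithValueHead), and the
# collator must stack chosen examples before rejected ones within each batch.
#
#   trainer = PairwiseTrainer(
#       model=reward_model,
#       args=training_args,
#       finetuning_args=finetuning_args,
#       processor=None,
#       tokenizer=tokenizer,
#       train_dataset=pairwise_dataset,
#       data_collator=pairwise_collator,
#   )
#   trainer.train()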