trainer.py 6.93 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
18
19
20
21
22
23
24
25
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import Seq2SeqTrainer
luopl's avatar
luopl committed
26
from typing_extensions import override
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
27

luopl's avatar
luopl committed
28
from ...extras import logging
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
29
from ...extras.constants import IGNORE_INDEX
luopl's avatar
luopl committed
30
31
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
chenych's avatar
chenych committed
32
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
33
34
35


if TYPE_CHECKING:
chenych's avatar
chenych committed
36
    from torch.utils.data import Dataset
luopl's avatar
luopl committed
37
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
38
39
40
41
42
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments


luopl's avatar
luopl committed
43
# Module-level logger obtained from the project's ``extras.logging`` wrapper
# (not the stdlib ``logging``); ``save_predictions`` calls ``info_rank0`` on it.
logger = logging.get_logger(__name__)
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
44
45
46
47
48
49
50


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.

    On top of the parent trainer this class:
    - optionally registers a ``SaveProcessorCallback`` and BAdam support,
    - builds the optimizer/scheduler through the project's custom factories,
    - can disable shuffling of the training data,
    - rescales the loss under gradient accumulation for models that do not accept loss kwargs,
    - strips the prompt from generated tokens and exports predictions as JSONL.
    """

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
        r"""
        Args:
            finetuning_args: fine-tuning arguments; ``use_badam`` and ``disable_shuffling`` are read
                here and it is passed to ``create_custom_optimizer``.
            processor: optional processor; when not None a ``SaveProcessorCallback`` is registered.
            **kwargs: forwarded to ``Seq2SeqTrainer.__init__``. When
                ``is_transformers_version_greater_than("4.46")`` holds, a ``tokenizer`` entry is
                renamed to ``processing_class``.
        """
        if is_transformers_version_greater_than("4.46"):
            # Newer transformers take ``processing_class`` instead of ``tokenizer``. Rename only
            # when the old keyword was used, so callers passing ``processing_class`` directly
            # do not hit a KeyError.
            if "tokenizer" in kwargs:
                kwargs["processing_class"] = kwargs.pop("tokenizer")
        else:
            # Older transformers only know ``tokenizer``; alias it so the rest of this class can
            # always refer to ``self.processing_class``.
            self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")

        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            # BAdam needs its own gradient-clipping routine bound onto the accelerator.
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        r"""
        Builds the optimizer via ``create_custom_optimizer`` when none exists yet, then defers to
        the parent, which only creates its default optimizer if ``self.optimizer`` is still None.
        """
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        r"""
        Calls ``create_custom_scheduler`` for its side effects (its return value is not used),
        then returns the scheduler created by the parent trainer.
        """
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        r"""
        Returns a ``SequentialSampler`` over ``self.train_dataset`` when
        ``finetuning_args.disable_shuffling`` is set; otherwise uses the parent's sampler.

        Any positional/keyword arguments are forwarded to the parent, so the override keeps
        working with transformers versions that pass the training dataset to this hook.
        """
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
        r"""
        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.

        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.

        When ``num_items_in_batch`` is given (and truthy) but the model does not accept loss kwargs
        (``model_accepts_loss_kwargs`` missing or false), the loss is divided by
        ``gradient_accumulation_steps``. With ``return_outputs`` only the first tuple element (the
        loss) is scaled; the remaining elements are passed through unchanged.
        """
        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
            if return_outputs:
                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
            else:
                loss = loss / self.args.gradient_accumulation_steps

        return loss

    @override
    def prediction_step(
        self,
        model: "torch.nn.Module",
        inputs: Dict[str, Union["torch.Tensor", Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        **gen_kwargs,
    ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""
        Removes the prompt part in the generated tokens.

        With ``predict_with_generate`` the ``labels`` are popped from ``inputs`` (mutating it) so
        they are not passed to generation; otherwise they are read but left in place. After the
        parent step, the first ``inputs["input_ids"].size(-1)`` positions of every generated
        sequence are overwritten with the pad token.

        Returns:
            ``(loss, generated_tokens, labels)`` where ``labels`` are the ones taken from ``inputs``.
        """
        if self.args.predict_with_generate:  # do not pass labels to model when generate
            labels = inputs.pop("labels", None)
        else:
            labels = inputs.get("labels")

        loss, generated_tokens, _ = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
        )
        if generated_tokens is not None and self.args.predict_with_generate:
            # Blank out the prompt prefix that precedes the newly generated tokens.
            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
            generated_tokens = generated_tokens.contiguous()

        return loss, generated_tokens, labels

    def save_predictions(
        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
    ) -> None:
        r"""
        Saves model predictions to ``output_dir/generated_predictions.jsonl``.

        A custom behavior that is not contained in Seq2SeqTrainer. Only the world process zero
        writes the file; every other rank returns immediately.

        Each output line is a JSON object with the keys ``prompt`` (decoded ``dataset["input_ids"]``
        with special tokens kept), ``predict`` and ``label``.

        Args:
            dataset: dataset whose ``input_ids`` column is decoded as the prompt.
            predict_results: prediction output; ``predictions`` and ``label_ids`` are read
                (assumed to be 2-D arrays of token ids -- TODO confirm for all callers).
            skip_special_tokens: forwarded to ``batch_decode`` for predictions and labels.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

        pad_token_id = self.processing_class.pad_token_id
        # Replace IGNORE_INDEX entries with the pad token so batch_decode only sees valid token ids.
        labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, pad_token_id)
        preds = np.where(
            predict_results.predictions != IGNORE_INDEX,
            predict_results.predictions,
            pad_token_id,
        )

        for i, pred in enumerate(preds):
            non_pad_indices = np.nonzero(pred != pad_token_id)[0]
            if len(non_pad_indices):  # rotate leading pad tokens to the end of the row
                preds[i] = np.roll(pred, -non_pad_indices[0])

        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

        with open(output_prediction_file, "w", encoding="utf-8") as f:
            f.writelines(
                json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n"
                for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels)
            )