# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from torch.utils.data import Dataset
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments


logger = logging.get_logger(__name__)


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""

    def __init__(
        self,
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        gen_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
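        # transformers>=4.46 expects `processing_class` instead of the deprecated `tokenizer` argument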
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        else:
            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args
        if gen_kwargs is not None:
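            # generation kwargs used by `evaluate` and `predict` when `predict_with_generate` is enabled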
            # https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
            self._gen_kwargs = gen_kwargs

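        # save the processor (e.g. a multimodal processor) alongside the model checkpoints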
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
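            # BAdam relies on the legacy gradient clipping implementation; patch the accelerator and register its callback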
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
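        # keep the original sample order when shuffling is disabled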
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def compute_loss(self, model, inputs, *args, **kwargs):
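        # no custom loss logic here; defer to the parent implementation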
        return super().compute_loss(model, inputs, *args, **kwargs)

    @override
    def prediction_step(
        self,
        model: "torch.nn.Module",
        inputs: dict[str, Union["torch.Tensor", Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""Remove the prompt part in the generated tokens.

        Subclass and override to inject custom behavior.
        """
        if self.args.predict_with_generate:  # do not pass labels to model when generate
            labels = inputs.pop("labels", None)
        else:
            labels = inputs.get("labels")

        loss, generated_tokens, _ = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
        )
        if generated_tokens is not None and self.args.predict_with_generate:
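            # the generated sequence includes the prompt; overwrite the prompt part with pad tokens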
            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
            generated_tokens = generated_tokens.contiguous()

        return loss, generated_tokens, labels

    def save_predictions(
        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
    ) -> None:
        r"""Save model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
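
        Each line of `generated_predictions.jsonl` is a JSON object with
        `prompt`, `predict`, and `label` fields.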
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

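        # replace IGNORE_INDEX with the pad token id so that label and prediction ids can be decoded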
        labels = np.where(
            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
        )
        preds = np.where(
            predict_results.predictions != IGNORE_INDEX,
            predict_results.predictions,
            self.processing_class.pad_token_id,
        )

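        # the prompt part of each generation was masked with pad tokens in `prediction_step`;
        # rotate the leading pads to the end so decoding starts at the first generated token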
        for i in range(len(preds)):
            pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
            if len(pad_len):  # move pad token to last
                preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

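        # prompts are decoded with special tokens kept, regardless of `skip_special_tokens`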
        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

        with open(output_prediction_file, "w", encoding="utf-8") as f:
            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")