# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from torch.utils.data import Dataset
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments, ModelArguments


logger = logging.get_logger(__name__)


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""

    def __init__(
        self,
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        model_args: Optional["ModelArguments"] = None,
        gen_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # Configure FP8 environment if enabled
        if model_args is not None and model_args.fp8:
            configure_fp8_environment(model_args)
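
        # transformers >= 4.46 expects `processing_class` instead of the deprecated `tokenizer` argument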
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        else:
            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args
        if gen_kwargs is not None:
            # https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
            self._gen_kwargs = gen_kwargs

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if finetuning_args.use_dft_loss:
            from ..trainer_utils import dft_loss_func

            self.compute_loss_func = dft_loss_func

        # Verify FP8 status after trainer initialization (accelerator should be available)
        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
            verify_fp8_status(self.accelerator, model_args)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

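    # explicit override hook; currently delegates to the parent implementation unchanged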
    @override
    def compute_loss(self, model, inputs, *args, **kwargs):
        return super().compute_loss(model, inputs, *args, **kwargs)

    @override
    def prediction_step(
        self,
        model: "torch.nn.Module",
        inputs: dict[str, Union["torch.Tensor", Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""Remove the prompt part from the generated tokens.

        Subclass and override to inject custom behavior.
        """
        if self.args.predict_with_generate:  # do not pass labels to the model when generating
            labels = inputs.pop("labels", None)
        else:
            labels = inputs.get("labels")

        loss, generated_tokens, _ = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
        )
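        # `generate()` returns the prompt concatenated with the completion; overwrite the
        # prompt portion with pad tokens so that only the newly generated tokens remain.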
        if generated_tokens is not None and self.args.predict_with_generate:
            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
            generated_tokens = generated_tokens.contiguous()

        return loss, generated_tokens, labels

    def save_predictions(
        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
    ) -> None:
        r"""Save model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

        labels = np.where(
            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
        )
        preds = np.where(
            predict_results.predictions != IGNORE_INDEX,
            predict_results.predictions,
            self.processing_class.pad_token_id,
        )

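        # `prediction_step` replaced the prompt with pad tokens at the front of each
        # sequence, so rotate the leading pads to the end before decoding.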
        for i in range(len(preds)):
            nonpad_indices = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
            if len(nonpad_indices):  # move pad tokens to the end
                preds[i] = np.concatenate((preds[i][nonpad_indices[0] :], preds[i][: nonpad_indices[0]]), axis=-1)

        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

        with open(output_prediction_file, "w", encoding="utf-8") as f:
            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
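

# A minimal usage sketch (illustrative only; assumes `model`, `training_args`,
# `finetuning_args`, `tokenizer`, `train_dataset`, and `data_collator` are built
# elsewhere, e.g. by the loader and hparams modules):
#
#     trainer = CustomSeq2SeqTrainer(
#         model=model,
#         args=training_args,
#         finetuning_args=finetuning_args,
#         processor=None,
#         tokenizer=tokenizer,
#         train_dataset=train_dataset,
#         data_collator=data_collator,
#     )
#     trainer.train()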