# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from torch.utils.data import Dataset
    from transformers import ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments


logger = get_logger(__name__)


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
    """

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

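        # PiSSA initializes LoRA adapters from the principal singular values and
        # vectors of the base weights; this callback converts the trained PiSSA
        # adapter back into a standard LoRA adapter at save time.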
        if finetuning_args.pissa_convert:
            self.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version

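            # BAdam updates one block of parameters at a time, so replace accelerate's
            # gradient clipping with the compatible implementation that BAdam ships.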
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
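            # create_custom_optimizer returns None when no custom optimizer is
            # configured; self.optimizer then stays None and the parent class
            # builds the default optimizer below.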
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
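        # create_custom_scheduler only attaches custom scheduling hooks (note the
        # unused return value); the scheduler itself comes from the parent class.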
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def prediction_step(
        self,
        model: "torch.nn.Module",
        inputs: Dict[str, Union["torch.Tensor", Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""
        Removes the prompt part from the generated tokens.

        Subclass and override to inject custom behavior.
        """
        labels = inputs["labels"] if "labels" in inputs else None
        if self.args.predict_with_generate:
            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensors."
            labels = labels.detach().clone() if labels is not None else None  # backup labels
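            # Align the label width with the prompt width: pad short labels on the
            # left, and truncate long ones rather than padding the inputs.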
            prompt_len = inputs["input_ids"].size(-1)
            if labels is not None:
                label_len = inputs["labels"].size(-1)
                if prompt_len > label_len:
                    inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
                if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
                    inputs["labels"] = inputs["labels"][:, :prompt_len]

        loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
        )
        if generated_tokens is not None and self.args.predict_with_generate:
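            # Decoder-only generation returns the prompt followed by the completion;
            # blank out the prompt span so only newly generated tokens are decoded.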
            generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
            generated_tokens = generated_tokens.contiguous()

        return loss, generated_tokens, labels

    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
        r"""
        Pads the tensor to the same length as the target tensor.
        """
        assert self.tokenizer.pad_token_id is not None, "Pad token is required."
        padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
        return padded_tensor.contiguous()  # in contiguous memory

    def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")

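        # IGNORE_INDEX (-100) masks prompt tokens in the loss but cannot be decoded;
        # replace it with the pad token id before decoding.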
        labels = np.where(
            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
        )
        preds = np.where(
            predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
        )

        for i in range(len(preds)):
            non_pad_index = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
            if len(non_pad_index):  # move the leading pad tokens to the end
                preds[i] = np.concatenate((preds[i][non_pad_index[0] :], preds[i][: non_pad_index[0]]), axis=-1)

        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)

        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
                res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))

            writer.write("\n".join(res))
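

# A minimal usage sketch (an assumption, not part of this module): `model`,
# `training_args` (a Seq2SeqTrainingArguments with predict_with_generate=True),
# `finetuning_args`, `tokenizer`, and `eval_dataset` stand in for objects that
# LLaMA-Factory builds elsewhere in its training workflow.
#
#     trainer = CustomSeq2SeqTrainer(
#         model=model,
#         args=training_args,
#         finetuning_args=finetuning_args,
#         processor=None,  # or a ProcessorMixin for multimodal models
#         tokenizer=tokenizer,
#     )
#     predict_results = trainer.predict(eval_dataset, metric_key_prefix="predict")
#     trainer.save_predictions(eval_dataset, predict_results)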