# seq2seq_trainer.py
import copy
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler

from transformers import PreTrainedModel, Trainer, logging
from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import (
    Adafactor,
    AdamW,
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)
from transformers.trainer_pt_utils import get_tpu_sampler


try:
    from .utils import label_smoothed_nll_loss
except ImportError:
    from utils import label_smoothed_nll_loss


logger = logging.get_logger(__name__)

arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    "constant": get_constant_schedule,
    "constant_w_warmup": get_constant_schedule_with_warmup,
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
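# A hedged illustration (values made up) of how the mapping above is consumed by
# `_get_lr_scheduler` below: the `--lr_scheduler` choice selects the factory function, e.g.
#
#   schedule_func = arg_to_scheduler["linear"]
#   scheduler = schedule_func(optimizer, num_warmup_steps=500, num_training_steps=10000)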


class Seq2SeqTrainer(Trainer):
    def __init__(self, config=None, data_args=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if config is None:
            assert isinstance(
                self.model, PreTrainedModel
            ), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
            self.config = self._actual_model(self.model).config
        else:
            self.config = config

        self.data_args = data_args
        self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size

        if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
            assert (
                self.config.pad_token_id is not None
            ), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            if self.args.adafactor:
                self.optimizer = Adafactor(
                    optimizer_grouped_parameters,
                    lr=self.args.learning_rate,
                    scale_parameter=False,
                    relative_step=False,
                )

            else:
                self.optimizer = AdamW(
                    optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
                )

        if self.lr_scheduler is None:
            self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
        else:  # ignoring --lr_scheduler
            logger.warning("A scheduler was passed to `Seq2SeqTrainer`, so `--lr_scheduler` is ignored.")

    def _get_lr_scheduler(self, num_training_steps):
        schedule_func = arg_to_scheduler[self.args.lr_scheduler]
        if self.args.lr_scheduler == "constant":
            scheduler = schedule_func(self.optimizer)
        elif self.args.lr_scheduler == "constant_w_warmup":
            scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps)
        else:
            scheduler = schedule_func(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )
        return scheduler

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
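            # `make_sortish_sampler` is assumed to be provided by the dataset class in the
            # companion utils.py of this examples directory, building a sampler that groups
            # examples of roughly similar length to reduce padding; its return value is not
            # captured by this method.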
            if self.args.sortish_sampler:
                self.train_dataset.make_sortish_sampler(
                    self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
                )

            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def _compute_loss(self, model, inputs):
        inputs = copy.deepcopy(inputs)
        if self.args.label_smoothing == 0:
            if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
                # force training to ignore pad token
                labels = inputs.pop("labels")
                logits = model(**inputs, use_cache=False)[0]

                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
                loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
            else:
                # compute the usual loss via the model (labels stay in `inputs`, so the model returns the loss itself)
                loss, logits = model(**inputs, use_cache=False)[:2]
        else:
            # compute label smoothed loss
            labels = inputs.pop("labels")
            logits = model(**inputs, use_cache=False)[0]
            lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
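            # `label_smoothed_nll_loss` comes from the companion utils.py and is assumed to
            # implement the usual fairseq-style smoothing: the NLL of the target token is
            # blended with a uniform distribution over the vocabulary, weighted by
            # `label_smoothing`, while positions equal to `ignore_index` are skipped.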
            loss, _ = label_smoothed_nll_loss(
                lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
            )
        return loss, logits

    def compute_loss(self, model, inputs):
        loss, _ = self._compute_loss(model, inputs)
        return loss

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
        """
        inputs = self._prepare_inputs(inputs)

        if self.args.predict_with_generate and not self.args.prediction_loss_only:
            gen_kwargs = {
                "max_length": self.data_args.val_max_target_length
                if self.data_args is not None
                else self.config.max_length,
                "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
            }
            generated_tokens = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **gen_kwargs,
            )
            # in case the batch is shorter than max length, the output should be padded
            if self.config.pad_token_id is not None:
                generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

            # compute loss on predict data
        with torch.no_grad():
            loss, logits = self._compute_loss(model, inputs)

        loss = loss.mean().detach()
        if self.args.prediction_loss_only:
            return (loss, None, None)

        logits = generated_tokens if self.args.predict_with_generate else logits

        labels = inputs["labels"]
        if self.config.pad_token_id is not None:
            labels = self._pad_tensors_to_max_len(labels, self.config.max_length)

        return (loss, logits, labels)

    def _pad_tensors_to_max_len(self, tensor, max_length):
        padded_tensor = self.config.pad_token_id * torch.ones(
            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor
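

if __name__ == "__main__":
    # Tiny illustrative smoke test, not part of the trainer's normal use: it replays the
    # padding logic of `_pad_tensors_to_max_len` standalone (assuming a pad_token_id of 0)
    # so the behaviour can be checked without instantiating a model or a Trainer.
    pad_token_id = 0
    short_batch = torch.tensor([[5, 6, 7], [8, 9, 10]])
    max_length = 5
    padded = pad_token_id * torch.ones(
        (short_batch.shape[0], max_length), dtype=short_batch.dtype, device=short_batch.device
    )
    padded[:, : short_batch.shape[-1]] = short_batch
    print(padded)  # each row is right-padded with 0 up to length 5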