# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler

from transformers import PreTrainedModel, Trainer, logging
from transformers.integrations import is_fairscale_available
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
from transformers.optimization import (
    Adafactor,
    AdamW,
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)
from transformers.trainer_pt_utils import get_tpu_sampler
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_tpu_available


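# fairscale's OSS optimizer wrapper is only used below when training with `--sharded_ddp`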
if is_fairscale_available():
    from fairscale.optim import OSS


logger = logging.get_logger(__name__)

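# Maps the value of the `--lr_scheduler` argument to the corresponding scheduler factory (see `_get_lr_scheduler`).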
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    "constant": get_constant_schedule,
    "constant_w_warmup": get_constant_schedule_with_warmup,
}


class Seq2SeqTrainer(Trainer):
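    """
    A :class:`~transformers.Trainer` subclass for sequence-to-sequence fine-tuning: it adds optional label smoothing,
    a sortish/distributed train sampler and generation-based evaluation (``predict_with_generate``).
    """
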
    def __init__(self, config=None, data_args=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if config is None:
            assert isinstance(self.model, PreTrainedModel), (
                "If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is"
                f" {self.model.__class__}"
            )
            self.config = self.model.config
        else:
            self.config = config

        self.data_args = data_args
        self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size

        if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
            assert self.config.pad_token_id is not None, (
                "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss"
                " calculation or doing label smoothing."
            )

        if self.config.pad_token_id is None and self.config.eos_token_id is not None:
            logger.warning(
                f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for"
                " padding.."
            )

        if self.args.label_smoothing == 0:
            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
        else:
            # dynamically import label_smoothed_nll_loss
            from utils import label_smoothed_nll_loss

            self.loss_fn = label_smoothed_nll_loss

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Set up the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method.
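
        Example (a minimal sketch; ``model``, ``training_args`` and ``train_dataset`` are assumed to exist already
        and the hyper-parameter values are purely illustrative)::

            optimizer = AdamW(model.parameters(), lr=3e-5)
            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=1000)
            trainer = Seq2SeqTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                optimizers=(optimizer, scheduler),
            )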
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            if self.args.adafactor:
                optimizer_cls = Adafactor
                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
            else:
                optimizer_cls = AdamW
                optimizer_kwargs = {
                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
                    "eps": self.args.adam_epsilon,
                }
            optimizer_kwargs["lr"] = self.args.learning_rate
            if self.sharded_ddp:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        if self.lr_scheduler is None:
            self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
        else:  # ignoring --lr_scheduler
            logger.warning("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.")

    def _get_lr_scheduler(self, num_training_steps):
        schedule_func = arg_to_scheduler[self.args.lr_scheduler]
        if self.args.lr_scheduler == "constant":
            scheduler = schedule_func(self.optimizer)
        elif self.args.lr_scheduler == "constant_w_warmup":
            scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps)
        else:
            scheduler = schedule_func(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )
        return scheduler

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            if self.args.sortish_sampler:
                self.train_dataset.make_sortish_sampler(
                    self.args.per_device_train_batch_size,
                    distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
                )

            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def _compute_loss(self, model, inputs, labels):
        if self.args.label_smoothing == 0:
            if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
                # force training to ignore pad token
                logits = model(**inputs, use_cache=False)[0]
                loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
            else:
                # compute the usual loss via the model
                loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
        else:
            # compute label smoothed loss
            logits = model(**inputs, use_cache=False)[0]
            lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
            loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id)
        return loss, logits

    def compute_loss(self, model, inputs):
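        # `labels` are popped from `inputs` so that `_compute_loss` can decide whether to feed them to the model
        # directly or to compute the (optionally label-smoothed) loss from the logits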
        labels = inputs.pop("labels")
        loss, _ = self._compute_loss(model, inputs, labels)
        return loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
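
        Example (a minimal sketch; the ``trainer`` instance and the token ids shown are purely illustrative)::

            inputs = {
                "input_ids": torch.tensor([[0, 31414, 232, 2]]),
                "attention_mask": torch.tensor([[1, 1, 1, 1]]),
                "labels": torch.tensor([[0, 9064, 2]]),
            }
            loss, logits, labels = trainer.prediction_step(trainer.model, inputs, prediction_loss_only=False)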
        """
        inputs = self._prepare_inputs(inputs)

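        # generation kwargs: prefer the values from `data_args` when provided, otherwise fall back to the model
        # config defaults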
        gen_kwargs = {
            "max_length": self.data_args.val_max_target_length
            if self.data_args is not None
            else self.config.max_length,
            "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
        }

        if self.args.predict_with_generate and not self.args.prediction_loss_only:
            generated_tokens = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **gen_kwargs,
            )
            # in case the batch is shorter than max length, the output should be padded
            if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
                generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

        labels = inputs.pop("labels")
        with torch.no_grad():
            # compute the loss on the evaluation/prediction data
            loss, logits = self._compute_loss(model, inputs, labels)

        loss = loss.mean().detach()
        if self.args.prediction_loss_only:
            return (loss, None, None)

        logits = generated_tokens if self.args.predict_with_generate else logits

        if labels.shape[-1] < gen_kwargs["max_length"]:
            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

        return (loss, logits, labels)

    def _pad_tensors_to_max_len(self, tensor, max_length):
        # If the PAD token is not defined, at least the EOS token has to be defined
        pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id

        if pad_token_id is None:
            raise ValueError(
                "Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be"
                f" padded to `max_length`={max_length}"
            )

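        # build a (batch_size, max_length) tensor filled with the pad token, then copy the original ids into its
        # first `tensor.shape[-1]` positions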
        padded_tensor = pad_token_id * torch.ones(
            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor