"tests/test_tokenization_utils.py" did not exist on "3d5f2913864de28a57a339c4c0c9f7b6000a7d03"
seq2seq_trainer.py 7.87 KB
Newer Older
Suraj Patil's avatar
Suraj Patil committed
1
2
3
4
5
6
7
8
import logging
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler

from transformers import Trainer
from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import (
    Adafactor,
    AdamW,
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)
from transformers.trainer_pt_utils import get_tpu_sampler


try:
    from .utils import label_smoothed_nll_loss
except ImportError:
    from utils import label_smoothed_nll_loss


logger = logging.getLogger(__name__)

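# Map the --lr_scheduler command-line choice to the matching schedule factory from transformers.optimization.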
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    "constant": get_constant_schedule,
    "constant_w_warmup": get_constant_schedule_with_warmup,
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())


class Seq2SeqTrainer(Trainer):
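    """
    A :class:`~transformers.Trainer` subclass for sequence-to-sequence fine-tuning that adds label smoothing,
    selectable learning-rate schedules, optional sortish sampling of the training data, and generation-based
    evaluation.
    """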
    def __init__(self, config, data_args, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.data_args = data_args
        self.max_gen_length = data_args.val_max_target_length
        self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the Trainer's init through :obj:`optimizers`, or subclass and override this method.
        """
        if self.optimizer is None:
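            # Do not apply weight decay to bias and LayerNorm parameters; all other parameters are decayed.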
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            if self.args.adafactor:
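                # Adafactor is configured with an external learning rate (relative_step/scale_parameter disabled).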
                self.optimizer = Adafactor(
                    optimizer_grouped_parameters,
                    lr=self.args.learning_rate,
                    scale_parameter=False,
                    relative_step=False,
                )

            else:
                self.optimizer = AdamW(
                    optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
                )

        if self.lr_scheduler is None:
            self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
        else:  # ignoring --lr_scheduler
            logger.warning("A scheduler was passed to `Seq2SeqTrainer`, so the `--lr_scheduler` arg is ignored.")

    def _get_lr_scheduler(self, num_training_steps):
        schedule_func = arg_to_scheduler[self.args.lr_scheduler]
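        # The constant schedule takes no warmup argument, constant-with-warmup takes only warmup steps,
        # and all remaining schedules need both warmup steps and the total number of training steps.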
        if self.args.lr_scheduler == "constant":
            scheduler = schedule_func(self.optimizer)
        elif self.args.lr_scheduler == "constant_w_warmup":
            scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps)
        else:
            scheduler = schedule_func(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )
        return scheduler

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
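        # Pick a sampler: none for iterable datasets, a TPU-aware sampler on TPU, otherwise random or distributed.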
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            if self.args.sortish_sampler:
                self.train_dataset.make_sortish_sampler(
                    self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
                )

            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def compute_loss(self, model, inputs):
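        # Pop the labels so the model does not compute its own loss; it is computed below with optional label smoothing.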
        labels = inputs.pop("labels")
        outputs = model(**inputs, use_cache=False)
        logits = outputs[0]
        return self._compute_loss(logits, labels)

    def _compute_loss(self, logits, labels):
        if self.args.label_smoothing == 0:
            # Same behavior as modeling_bart.py
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            assert logits.shape[-1] == self.vocab_size
            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
        else:
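            # Label smoothing: compute the smoothed negative log-likelihood over log-probabilities.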
            lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
            )
        return loss

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
        """
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            if self.args.predict_with_generate and not self.args.prediction_loss_only:
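                # Generate predictions so that generation-based metrics can be computed during evaluation.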
                generated_tokens = model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    use_cache=True,
                    num_beams=self.data_args.eval_beams,
                    max_length=self.max_gen_length,
                )
                # in case the batch is shorter than max length, the output should be padded
                generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)

            labels_out = inputs.get("labels")
            # Run a second forward pass to compute the loss  # TODO: avoidable?
            outputs = model(**inputs, use_cache=False)
            loss = self._compute_loss(outputs[1], labels_out)
            loss = loss.mean().detach()
            if self.args.prediction_loss_only:
                return (loss, None, None)

            logits = generated_tokens if self.args.predict_with_generate else outputs[1]

        labels_out = labels_out.detach()
        labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
        return (loss, logits.detach(), labels)

    def _pad_tensors_to_max_len(self, tensor, max_length):
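        # Right-pad with the pad token id so that tensors from different batches can be concatenated for metrics.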
        padded_tensor = self.config.pad_token_id * torch.ones(
            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor
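

# Example usage (a sketch, not part of this module): the accompanying fine-tuning script builds the model,
# datasets, data collator and argument dataclasses, then constructs the trainer roughly like this.
# The variable names below are illustrative assumptions.
#
#     trainer = Seq2SeqTrainer(
#         config=config,
#         data_args=data_args,
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#         data_collator=data_collator,
#     )
#     trainer.train()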