seq2seq_trainer.py
import logging
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler

from transformers import Trainer
from transformers.file_utils import is_torch_tpu_available
from transformers.trainer import get_tpu_sampler


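# Support running this file both as a module inside a package and as a standalone script.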
try:
    from .utils import label_smoothed_nll_loss
except ImportError:
    from utils import label_smoothed_nll_loss


logger = logging.getLogger(__name__)


class Seq2SeqTrainer(Trainer):
    def __init__(self, data_args, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.data_args = data_args
        # Generated sequences and labels are padded to this fixed length.
        self.max_gen_length = data_args.val_max_target_length
        # Used both as the loss ignore_index and as the padding value.
        self.pad_token_id = self.model.config.pad_token_id

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            if self.args.sortish_sampler:
                # Batch examples of similar source length together to reduce padding.
                return self.train_dataset.make_sortish_sampler(
                    self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
                )

            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )
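
    # Assumption about the companion utils module: `make_sortish_sampler` returns a
    # sampler that visits examples in approximate order of source length (with some
    # shuffling inside chunks) so batches need less padding, roughly:
    #
    #     sampler = dataset.make_sortish_sampler(batch_size, distributed=False)
    #     # -> a SortishSampler over the dataset's source lengths, or a distributed variant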

    def compute_loss(self, model, inputs):
        # Pop labels so the model does not compute its own loss; computing it here
        # keeps label smoothing under the trainer's control.
        labels = inputs.pop("labels")
        outputs = model(**inputs, use_cache=False)
        logits = outputs[0]
        return self._compute_loss(logits, labels, ignore_index=self.pad_token_id)

    def _compute_loss(self, logits, labels, ignore_index):
        if self.args.label_smoothing == 0:
            # Same behavior as modeling_bart.py
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
            assert logits.shape[-1] == self.model.config.vocab_size
            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index
            )
        return loss
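
    # For reference, a sketch of the fairseq-style `label_smoothed_nll_loss`
    # imported above (see utils for the exact implementation). With lprobs of
    # shape (batch, seq_len, vocab) and target of shape (batch, seq_len):
    #
    #     target = target.unsqueeze(-1)
    #     nll_loss = -lprobs.gather(dim=-1, index=target)
    #     smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    #     pad_mask = target.eq(ignore_index)
    #     nll_loss.masked_fill_(pad_mask, 0.0)
    #     smooth_loss.masked_fill_(pad_mask, 0.0)
    #     eps_i = epsilon / lprobs.size(-1)
    #     loss = (1.0 - epsilon) * nll_loss.sum() + eps_i * smooth_loss.sum()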

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
        """
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            if self.args.predict_with_generate and not self.args.prediction_loss_only:
                generated_tokens = model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    use_cache=True,
                    num_beams=self.data_args.eval_beams,
                    max_length=self.max_gen_length,
                )
                # In case the batch is shorter than max length, the output should be padded.
                generated_tokens = self._pad_tensors_to_max_len(
                    generated_tokens, self.max_gen_length, self.pad_token_id
                )

            labels_out = inputs.get("labels")
            # Call forward again to get the loss; labels are still in `inputs`, so
            # outputs[1] holds the logits.  # TODO: avoidable?
            outputs = model(**inputs, use_cache=False)
            loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id)
            loss = loss.mean().item()
            if self.args.prediction_loss_only:
                return (loss, None, None)

            logits = generated_tokens if self.args.predict_with_generate else outputs[1]

        labels_out = labels_out.detach()
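        # Pad labels to the same fixed length as the generated tokens so the
        # evaluation loop can concatenate predictions and labels across batches.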
        labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id)
        return (loss, logits.detach(), labels)

    def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
        # Build a (batch_size, max_length) tensor filled with the pad token, then
        # copy `tensor` into its leading positions.
        padded_tensor = pad_token_id * torch.ones(
            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor
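
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this file). Assumes a `data_args`
# object exposing `val_max_target_length` and `eval_beams` (as used above),
# plus training args carrying `sortish_sampler`, `label_smoothing`, and
# `predict_with_generate`:
#
#     trainer = Seq2SeqTrainer(
#         data_args=data_args,
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#     )
#     trainer.train()
#     metrics = trainer.evaluate()
# ---------------------------------------------------------------------------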