# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch

from fairseq import tokenizer
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary

class FairseqTask(object):
    """Base class for all fairseq tasks.

    Tasks store dictionaries and provide helpers for loading/iterating over
    Datasets, initializing the Model/Criterion and calculating the loss.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        pass

    def __init__(self, args):
        # Parsed command-line arguments for this task.
        self.args = args
        # Maps split name (e.g. 'train') -> loaded dataset for that split.
        self.datasets = {}

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        return Dictionary.load(filename)

    @classmethod
    def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
        """Build the dictionary

        Args:
            filenames (list): list of filenames
            workers (int): number of concurrent workers
            threshold (int): defines the minimum word count
            nwords (int): defines the total number of words in the final dictionary,
                including special symbols
            padding_factor (int): can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        dictionary = Dictionary()
        for path in filenames:
            Dictionary.add_file_to_dictionary(path, dictionary, tokenizer.tokenize_line, workers)
        dictionary.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        return cls(args)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        raise NotImplementedError

    def dataset(self, split):
        """
        Return a loaded dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)

        Returns:
            a :class:`~fairseq.data.FairseqDataset` corresponding to *split*

        Raises:
            KeyError: if *split* has not been loaded via :func:`load_dataset`
            TypeError: if the stored split is not a FairseqDataset
        """
        from fairseq.data import FairseqDataset
        if split not in self.datasets:
            raise KeyError('Dataset not loaded: ' + split)
        loaded = self.datasets[split]
        if not isinstance(loaded, FairseqDataset):
            raise TypeError('Datasets are expected to be of type FairseqDataset')
        return loaded

    def get_batch_iterator(
        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        ignore_invalid_inputs=False, required_batch_size_multiple=1,
        seed=1, num_shards=1, shard_id=0, num_workers=0,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        assert isinstance(dataset, FairseqDataset)

        # Order examples by size, deterministically for a given seed.
        with data_utils.numpy_seed(seed):
            order = dataset.ordered_indices()

        # Drop examples the model cannot handle (or raise, depending on
        # ignore_invalid_inputs).
        order = data_utils.filter_by_size(
            order,
            dataset.size,
            max_positions,
            raise_exception=(not ignore_invalid_inputs),
        )

        # Group the surviving indices into mini-batches that respect the
        # token/sentence caps.
        batches = data_utils.batch_by_size(
            order,
            dataset.num_tokens,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # Wrap everything in a reusable, sharded epoch iterator.
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batches,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
        )

    def build_model(self, args):
        """
        Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
        task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.models.BaseFairseqModel` instance
        """
        from fairseq import models
        return models.build_model(args, self)

    def build_criterion(self, args):
        """
        Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
        this task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.criterions.FairseqCriterion` instance
        """
        from fairseq import criterions
        return criterions.build_criterion(args, self)

    def build_generator(self, args):
        """Build a sequence scorer (when ``--score-reference`` is set) or a
        beam-search sequence generator configured from *args*."""
        if args.score_reference:
            from fairseq.sequence_scorer import SequenceScorer
            return SequenceScorer(self.target_dictionary)

        from fairseq.sequence_generator import SequenceGenerator
        return SequenceGenerator(
            self.target_dictionary,
            beam_size=args.beam,
            max_len_a=args.max_len_a,
            max_len_b=args.max_len_b,
            min_len=args.min_len,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups,
            diverse_beam_strength=args.diverse_beam_strength,
            match_source_len=args.match_source_len,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
        )

    def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            # Zero the loss so backward contributes no gradient (used e.g.
            # for dummy batches).
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        """Forward *sample* through *model* under no_grad and return the
        (loss, sample_size, logging_output) triple from *criterion*."""
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output

    def inference_step(self, generator, models, sample, prefix_tokens=None):
        """Generate hypotheses for *sample* with *generator* (no gradients)."""
        with torch.no_grad():
            return generator.generate(models, sample, prefix_tokens=prefix_tokens)

    def grad_denom(self, sample_sizes, criterion):
        """Delegate gradient-denominator computation to the criterion class."""
        return criterion.__class__.grad_denom(sample_sizes)

    def aggregate_logging_outputs(self, logging_outputs, criterion):
        """Delegate aggregation of per-step logging outputs to the criterion
        class."""
        return criterion.__class__.aggregate_logging_outputs(logging_outputs)

    def max_positions(self):
        """Return the max input length allowed by the task."""
        # None means the task imposes no limit of its own.
        return None

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError