# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch

from fairseq import tokenizer
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary


class FairseqTask(object):
    """
    Tasks store dictionaries and provide helpers for loading/iterating over
    Datasets, initializing the Model/Criterion and calculating the loss.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        pass

    def __init__(self, args):
        self.args = args
        self.datasets = {}
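
    # Editor's sketch (not part of the original file): the shape of a minimal
    # custom task, shown as a commented example. ``register_task`` is
    # fairseq's task registry decorator; ``MyPairDataset``, the dictionary
    # filenames and the ``os`` import are hypothetical stand-ins.
    #
    #     import os
    #
    #     from fairseq.tasks import register_task
    #
    #     @register_task('tiny_translation')
    #     class TinyTranslationTask(FairseqTask):
    #
    #         @staticmethod
    #         def add_args(parser):
    #             parser.add_argument('data', help='path to the data directory')
    #
    #         @classmethod
    #         def setup_task(cls, args, **kwargs):
    #             src_dict = cls.load_dictionary(os.path.join(args.data, 'dict.src.txt'))
    #             tgt_dict = cls.load_dictionary(os.path.join(args.data, 'dict.tgt.txt'))
    #             return cls(args, src_dict, tgt_dict)
    #
    #         def __init__(self, args, src_dict, tgt_dict):
    #             super().__init__(args)
    #             self.src_dict = src_dict
    #             self.tgt_dict = tgt_dict
    #
    #         def load_dataset(self, split, combine=False, **kwargs):
    #             # must leave a FairseqDataset in self.datasets[split]
    #             self.datasets[split] = MyPairDataset(self.args.data, split)
    #
    #         @property
    #         def source_dictionary(self):
    #             return self.src_dict
    #
    #         @property
    #         def target_dictionary(self):
    #             return self.tgt_dict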

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        return Dictionary.load(filename)

    @classmethod
    def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
        """Build the dictionary

        Args:
            filenames (list): list of filenames
            workers (int): number of concurrent workers
            threshold (int): defines the minimum word count
            nwords (int): defines the total number of words in the final dictionary,
                including special symbols
            padding_factor (int): can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        d = Dictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d
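
    # Editor's sketch of typical usage (the corpus and output filenames are
    # hypothetical):
    #
    #     d = FairseqTask.build_dictionary(
    #         ['train.de-en.en'],   # one or more whitespace-tokenized text files
    #         workers=4,            # count words with 4 worker processes
    #         threshold=5,          # drop words seen fewer than 5 times
    #         padding_factor=8,     # round the final size up to a multiple of 8
    #     )
    #     d.save('dict.en.txt')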

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        return cls(args, **kwargs)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        raise NotImplementedError

    def dataset(self, split):
        """
        Return a loaded dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)

        Returns:
            a :class:`~fairseq.data.FairseqDataset` corresponding to *split*
        """
        from fairseq.data import FairseqDataset
        if split not in self.datasets:
            raise KeyError('Dataset not loaded: ' + split)
        if not isinstance(self.datasets[split], FairseqDataset):
            raise TypeError('Datasets are expected to be of type FairseqDataset')
        return self.datasets[split]

    def get_batch_iterator(
        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        ignore_invalid_inputs=False, required_batch_size_multiple=1,
        seed=1, num_shards=1, shard_id=0, num_workers=0,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise an exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).

        Returns:
            ~fairseq.data.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        assert isinstance(dataset, FairseqDataset)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # filter examples that are too large
        indices = data_utils.filter_by_size(
            indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
        )

        # create mini-batches with given size constraints
        batch_sampler = data_utils.batch_by_size(
            indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # return a reusable, sharded iterator
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
        )
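
    # Editor's sketch of typical usage (assumes `task` is a concrete subclass
    # that has already run `task.load_dataset('train')`):
    #
    #     epoch_itr = task.get_batch_iterator(
    #         dataset=task.dataset('train'),
    #         max_tokens=4000,                     # cap on total tokens per batch
    #         max_positions=task.max_positions(),  # filter over-long examples
    #         ignore_invalid_inputs=True,
    #         seed=1,
    #     )
    #     for batch in epoch_itr.next_epoch_itr(shuffle=True):
    #         pass  # each `batch` is a mini-batch built by dataset.collater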

    def build_model(self, args):
        """
        Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
        task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.models.BaseFairseqModel` instance
        """
        from fairseq import models
        return models.build_model(args, self)

    def build_criterion(self, args):
        """
        Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
        this task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.criterions.FairseqCriterion` instance
        """
        from fairseq import criterions
        return criterions.build_criterion(args, self)

    def build_generator(self, args):
        """Build a :class:`~fairseq.sequence_generator.SequenceGenerator`
        from *args*, or a :class:`~fairseq.sequence_scorer.SequenceScorer`
        when ``--score-reference`` is set."""
        if args.score_reference:
            from fairseq.sequence_scorer import SequenceScorer
            return SequenceScorer(self.target_dictionary)
        else:
            from fairseq.sequence_generator import SequenceGenerator
            return SequenceGenerator(
                self.target_dictionary,
                beam_size=args.beam,
                max_len_a=args.max_len_a,
                max_len_b=args.max_len_b,
                min_len=args.min_len,
                stop_early=(not args.no_early_stop),
                normalize_scores=(not args.unnormalized),
                len_penalty=args.lenpen,
                unk_penalty=args.unkpen,
                sampling=args.sampling,
                sampling_topk=args.sampling_topk,
                temperature=args.temperature,
                diverse_beam_groups=args.diverse_beam_groups,
                diverse_beam_strength=args.diverse_beam_strength,
                match_source_len=args.match_source_len,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
            )

    def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            ignore_grad (bool, optional): multiply the loss by 0, effectively
                skipping this batch's gradient contribution (default: False)

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output
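
    # Editor's sketch of a bare-bones loop built from the helpers above
    # (fairseq's train.py adds distributed training, checkpointing and LR
    # scheduling on top of this; `args`, `task` and `epoch_itr` are assumed
    # to exist already):
    #
    #     from fairseq import optim
    #
    #     model = task.build_model(args)
    #     criterion = task.build_criterion(args)
    #     optimizer = optim.build_optimizer(args, model.parameters())
    #     for num_updates, sample in enumerate(epoch_itr.next_epoch_itr(shuffle=True), 1):
    #         loss, sample_size, logging_output = task.train_step(
    #             sample, model, criterion, optimizer)
    #         optimizer.step()
    #         optimizer.zero_grad()
    #         task.update_step(num_updates)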

    def valid_step(self, sample, model, criterion):
        """Do a forward pass in evaluation mode and return the loss as
        computed by *criterion* for the given *model* and *sample*, without
        a backward pass."""
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output

    def inference_step(self, generator, models, sample, prefix_tokens=None):
        """Run *generator* on *sample* with gradient tracking disabled."""
        with torch.no_grad():
            return generator.generate(models, sample, prefix_tokens=prefix_tokens)
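
    # Editor's sketch of generation with the helpers above (assumes `args`
    # carries the generation flags read by build_generator, and that `model`
    # came from task.build_model):
    #
    #     generator = task.build_generator(args)
    #     for sample in epoch_itr.next_epoch_itr(shuffle=False):
    #         hypos = task.inference_step(generator, [model], sample)
    #         # hypos[i] is a list of hypothesis dicts (with 'tokens', 'score',
    #         # etc.) for the i-th sentence in the batch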

    def update_step(self, num_updates):
        """Task level update when number of update increases. This is called after optimization step and
           learning rate update of each step"""
        pass

    def grad_denom(self, sample_sizes, criterion):
        """Compute the gradient-normalization denominator from *sample_sizes*,
        delegating to *criterion*."""
        return criterion.__class__.grad_denom(sample_sizes)

    def aggregate_logging_outputs(self, logging_outputs, criterion):
        """Aggregate logging outputs from data-parallel training, delegating
        to *criterion*."""
        return criterion.__class__.aggregate_logging_outputs(logging_outputs)

    def max_positions(self):
        """Return the max input length allowed by the task."""
        return None

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError