train.py 11.8 KB
Newer Older
Louis Martin's avatar
Louis Martin committed
1
#!/usr/bin/env python3
Sergey Edunov's avatar
Sergey Edunov committed
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import collections
import os
import torch
import math

15
from fairseq import data, options, utils
Sergey Edunov's avatar
Sergey Edunov committed
16
17
18
19
20
21
22
23
24
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.multiprocessing_trainer import MultiprocessingTrainer


def main():
    """Entry point: parse options, build the model and criterion, and train.

    Training runs epoch by epoch, validating after each epoch. It stops when
    the learning rate falls to ``args.min_lr`` or below, or once
    ``args.max_epoch`` epochs have run (no epoch limit when ``args.max_epoch``
    is unset or 0). If ``trainer.load_checkpoint`` returns saved state for
    ``<save_dir>/<restore_file>``, training resumes from the stored epoch and
    batch offset.

    Raises:
        NotImplementedError: if CUDA is not available (CPU training is not
            supported).
    """
    parser = options.get_parser('Trainer')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                              help='maximum number of tokens in a batch')
    dataset_args.add_argument('--max-sentences', type=int, metavar='N',
                              help='maximum number of sentences in a batch')
    dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                              choices=['train', 'valid', 'test'],
                              help='data subset to use for training (train, valid, test)')
    dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                              help='comma separated list of data subsets '
                                   'to use for validation (train, valid, valid1, test, test1)')
    dataset_args.add_argument('--max-sentences-valid', type=int, metavar='N',
                              help='maximum number of sentences in a validation batch')
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_model_args(parser)

    args = utils.parse_args_and_arch(parser)

    if args.no_progress_bar and args.log_format is None:
        args.log_format = 'simple'

    # Validation batches fall back to the training sentence limit.
    if args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences

    # exist_ok avoids the race between an existence check and the mkdir.
    os.makedirs(args.save_dir, exist_ok=True)
    torch.manual_seed(args.seed)

    # Load dataset
    # NOTE(review): only the 'train' and 'valid' splits are requested here, so
    # a --train-subset / --valid-subset naming a split that load_dataset does
    # not return for these (e.g. 'test') will fail later — confirm.
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.source_lang, args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    args.num_gpus = torch.cuda.device_count()

    print(args)
    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
        args.num_gpus, args.max_tokens, args.max_sentences))

    # Build model and criterion
    model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
    criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))

    # The max number of positions can be different for train and valid
    # e.g., RNNs may support more positions at test time than seen in training
    max_positions_train = (
        min(args.max_source_positions, model.max_encoder_positions()),
        min(args.max_target_positions, model.max_decoder_positions())
    )
    max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions())

    # Start multiprocessing
    trainer = MultiprocessingTrainer(args, model, criterion)

    # Load the latest checkpoint if one is available
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    extra_state = trainer.load_checkpoint(checkpoint_path)
    if extra_state is not None:
        epoch = extra_state['epoch']
        batch_offset = extra_state['batch_offset']
        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
        # A zero offset means the stored epoch was completed; start the next one.
        if batch_offset == 0:
            epoch += 1
    else:
        epoch, batch_offset = 1, 0

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, epoch, batch_offset, trainer, dataset, max_positions_train)

        # Evaluate on every validation subset. Only the first subset drives
        # checkpoint selection and the learning-rate schedule.
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
                    save_checkpoint(trainer, args, epoch, 0, val_loss)
                # only use first validation loss to update the learning schedule
                lr = trainer.lr_step(val_loss, epoch)

        epoch += 1
        batch_offset = 0
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

    # Stop multiprocessing
    trainer.stop()


132
133
def get_perplexity(loss, base=2):
    """Return the perplexity ``base ** loss`` rounded to two decimals.

    The default ``base=2`` preserves the original behaviour; it presumes the
    loss is measured in bits (confirm against the criterion). Pass
    ``base=math.e`` for losses measured in nats.

    Args:
        loss: average loss value.
        base: logarithm base the loss is expressed in (default 2).

    Returns:
        The rounded perplexity, or ``float('inf')`` if the power overflows.
    """
    try:
        return round(math.pow(base, loss), 2)
    except OverflowError:
        # Very large losses (e.g. early in training) overflow a float.
        return float('inf')


139
def train(args, epoch, batch_offset, trainer, dataset, max_positions):
    """Train the model for one epoch.

    Args:
        args: parsed option namespace. Read here: ``seed``, ``workers``,
            ``train_subset``, ``max_tokens``, ``max_sentences``,
            ``sample_without_replacement``, ``curriculum``, ``num_gpus``,
            ``sentence_avg``, ``clip_norm`` and ``save_interval``.
        epoch: current epoch number (``main`` starts counting at 1). It also
            seeds the epoch's RNG and controls curriculum sorting.
        batch_offset: number of batch groups to skip at the start of the
            epoch; ``main`` sets it from a resumed checkpoint.
        trainer: the multiprocessing trainer that runs the training steps.
        dataset: object providing ``train_dataloader``.
        max_positions: (source, target) position limits, forwarded to
            ``dataset.train_dataloader``.
    """
    # Per-epoch seed, so the epoch's randomness is a function of
    # (args.seed, epoch) only.
    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset, num_workers=args.workers,
        max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions, seed=seed, epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        # curriculum: the first `args.curriculum` epochs are sorted by source size
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    nll_loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    extra_meters = collections.defaultdict(AverageMeter)

    # Read once for display; main() steps the schedule after validation.
    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, args.num_gpus, batch_offset):
            # `sample` is a sequence of sub-batches (presumably one per GPU).
            loss_dict = trainer.train_step(sample)
            loss = loss_dict.pop('loss')  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)

            if 'nll_loss' in loss_dict:
                nll_loss_meter.update(loss_dict['nll_loss'], ntokens)

            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('loss', loss_meter),
                ('wps', round(wps_meter.avg)),
                ('wpb', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:.0%}'.format(clip_meter.avg)),
            ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                # Mid-epoch checkpoint: there is no validation loss yet. Pass
                # it explicitly, because save_checkpoint takes it as a
                # positional parameter.
                save_checkpoint(trainer, args, epoch, i + 1, val_loss=None)

        t.print(collections.OrderedDict([
            ('train loss', round(loss_meter.avg, 2)),
            ('train ppl', get_perplexity(nll_loss_meter.avg
                                         if nll_loss_meter.count > 0
                                         else loss_meter.avg)),
            ('s/checkpoint', round(wps_meter.elapsed_time)),
            ('words/s', round(wps_meter.avg)),
            ('words/batch', round(wpb_meter.avg)),
            ('bsz', round(bsz_meter.avg)),
            ('lr', lr),
            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))
Sergey Edunov's avatar
Sergey Edunov committed
215
216


217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
def save_checkpoint(trainer, args, epoch, batch_offset, val_loss=None):
    """Write training checkpoints into ``args.save_dir``.

    End of epoch (``batch_offset == 0``):
      * writes ``checkpoint{epoch}.pt``, unless ``args.no_epoch_checkpoints``;
      * writes ``checkpoint_best.pt`` if ``val_loss`` is the lowest value seen
        by this function in the current process.

    Mid-epoch (any other ``batch_offset``):
      * writes ``checkpoint{epoch}_{batch_offset}.pt``, unless
        ``args.no_epoch_checkpoints``.

    ``checkpoint_last.pt`` is written in every case.

    Args:
        trainer: object with a ``save_checkpoint(filename, extra_state)``
            method.
        args: namespace providing ``save_dir`` and ``no_epoch_checkpoints``.
        epoch: current epoch number.
        batch_offset: number of batch groups completed in ``epoch``; 0 marks
            an end-of-epoch checkpoint.
        val_loss: validation loss. It is required for end-of-epoch checkpoints
            and may be omitted (``None``) for mid-epoch ones.

    Raises:
        ValueError: if ``batch_offset == 0`` and ``val_loss`` is None. The
            check runs before any file is written.
    """
    if batch_offset == 0 and val_loss is None:
        raise ValueError('val_loss is required for end-of-epoch checkpoints')

    extra_state = {
        'epoch': epoch,
        'batch_offset': batch_offset,
        'val_loss': val_loss,
    }

    if batch_offset == 0:
        if not args.no_epoch_checkpoints:
            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
            trainer.save_checkpoint(epoch_filename, extra_state)

        # The best loss is stored as a function attribute. It is therefore
        # per-process and is not restored when training resumes from disk.
        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
            save_checkpoint.best = val_loss
            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
            trainer.save_checkpoint(best_filename, extra_state)
    elif not args.no_epoch_checkpoints:
        epoch_filename = os.path.join(
            args.save_dir, 'checkpoint{}_{}.pt'.format(epoch, batch_offset))
        trainer.save_checkpoint(epoch_filename, extra_state)

    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
    trainer.save_checkpoint(last_filename, extra_state)


243
def validate(args, epoch, trainer, dataset, max_positions, subset):
    """Evaluate the model on one validation subset and return its average loss.

    Args:
        args: parsed option namespace. Read here: ``max_tokens``,
            ``max_sentences_valid``, ``skip_invalid_size_inputs_valid_test``
            and ``num_gpus``.
        epoch: current epoch number, used for progress-bar labelling.
        trainer: the multiprocessing trainer that runs the validation steps.
        dataset: object providing ``eval_dataloader``.
        max_positions: (source, target) position limits, forwarded to
            ``dataset.eval_dataloader``.
        subset: name of the data split to evaluate on.

    Returns:
        The average of the per-step ``loss`` values returned by
        ``trainer.valid_step``, weighted by each step's ``ntokens``.
    """
    itr = dataset.eval_dataloader(
        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
    )
    loss_meter = AverageMeter()
    nll_loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(AverageMeter)

    prefix = 'valid on \'{}\' subset'.format(subset)
    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, args.num_gpus):
            loss_dict = trainer.valid_step(sample)
            ntokens = sum(s['ntokens'] for s in sample)
            loss = loss_dict.pop('loss')  # don't include in extra_meters or extra_postfix

            if 'nll_loss' in loss_dict:
                nll_loss_meter.update(loss_dict['nll_loss'], ntokens)

            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('valid loss', round(loss_meter.avg, 2)),
            ] + extra_postfix))

        t.print(collections.OrderedDict([
            ('valid loss', round(loss_meter.avg, 2)),
            ('valid ppl', get_perplexity(nll_loss_meter.avg
                                         if nll_loss_meter.count > 0
                                         else loss_meter.avg)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))

    # Return the average validation loss; main() feeds it to the LR schedule.
    return loss_meter.avg
Sergey Edunov's avatar
Sergey Edunov committed
291
292
293
294


if __name__ == '__main__':
    # Script entry point: parse command-line options and run training.
    main()