#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a new model on one or across multiple GPUs.
"""

import collections
import itertools
import math
import os
import random

import torch

from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter


def main(args, init_distributed=False):
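    """Set up the task, model, criterion and trainer for the given args, then run the training loop."""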
    utils.import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(args, task)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Make a dummy batch that (i) warms the caching allocator and (ii) serves
    # as a placeholder for DistributedDataParallel when there is an uneven
    # number of batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset(args.train_subset).get_dummy_batch(args.max_tokens, max_positions)
    oom_batch = task.dataset(args.train_subset).get_dummy_batch(1, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # Initialize distributed training (after data loading)
    if init_distributed:
        import socket
        args.distributed_rank = distributed_utils.distributed_init(args)
        print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))

    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
            if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()


def get_training_stats(trainer):
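    """Collect the current training meters into an ordered dict of stats for logging."""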
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('train_loss')
    if trainer.get_meter('train_nll_loss').count > 0:
        nll_loss = trainer.get_meter('train_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = trainer.get_meter('train_loss')
    stats['ppl'] = get_perplexity(nll_loss.avg)
    stats['wps'] = trainer.get_meter('wps')
    stats['ups'] = trainer.get_meter('ups')
    stats['wpb'] = trainer.get_meter('wpb')
    stats['bsz'] = trainer.get_meter('bsz')
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = trainer.get_meter('gnorm')
    stats['clip'] = trainer.get_meter('clip')
    stats['oom'] = trainer.get_meter('oom')
    if trainer.get_meter('loss_scale') is not None:
        stats['loss_scale'] = trainer.get_meter('loss_scale')
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    stats['train_wall'] = trainer.get_meter('train_wall')
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats['loss'].avg)
    return valid_losses


def get_valid_stats(trainer):
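    """Collect the current validation meters into an ordered dict of stats for logging."""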
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('valid_loss')
    if trainer.get_meter('valid_nll_loss').count > 0:
        nll_loss = trainer.get_meter('valid_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = stats['loss']
    stats['ppl'] = get_perplexity(nll_loss.avg)
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(save_checkpoint, 'best'):
        stats['best_loss'] = min(save_checkpoint.best, stats['loss'].avg)
    return stats


def get_perplexity(loss):
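    """Convert a base-2 cross-entropy loss into a perplexity value (infinity on overflow)."""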
    try:
        return '{:.2f}'.format(math.pow(2, loss))
    except OverflowError:
        return float('inf')


def save_checkpoint(args, trainer, epoch_itr, val_loss):
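    """Save checkpoints on the master worker according to the save/keep options, tracking the best validation loss."""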
    if args.no_save or not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
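    # map each candidate checkpoint filename to the condition under which it should be written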
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
            end_of_epoch and not args.no_epoch_checkpoints and
            epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
            not end_of_epoch and args.save_interval_updates > 0 and
            updates % args.save_interval_updates == 0
    )
    checkpoint_conds['checkpoint_best.pt'] = (
            val_loss is not None and
            (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
    if len(checkpoints) > 0:
        for cp in checkpoints:
            trainer.save_checkpoint(cp, extra_state)

        write_timer.stop()
        print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
            checkpoints[0], epoch, updates, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)


def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay dataloader to match."""

    # Only rank 0 should attempt to create the required dir
    if args.distributed_rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)

    if os.path.isabs(args.restore_file):
        checkpoint_path = args.restore_file
    else:
        checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if os.path.isfile(checkpoint_path):
        extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
                                              eval(args.optimizer_overrides))
        if extra_state is not None:
            # replay train iterator to match checkpoint
            epoch_itr.load_state_dict(extra_state['train_iterator'])

            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))

            trainer.lr_step(epoch_itr.epoch)
            trainer.lr_step_update(trainer.get_num_updates())
            if 'best' in extra_state and not args.reset_optimizer:
                save_checkpoint.best = extra_state['best']
        return True
    else:
        print('| no existing checkpoint found {}'.format(checkpoint_path))
    return False


def load_dataset_splits(args, task):
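    """Load the training split and each validation split (including numbered variants such as valid1, valid2, if present)."""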
    task.load_dataset(args.train_subset, combine=True)
    for split in args.valid_subset.split(','):
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            try:
                task.load_dataset(split_k, combine=False)
            except FileNotFoundError as e:
                if k > 0:
                    break
                raise e


def distributed_main(i, args):
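    """Per-device entry point for distributed training (called directly or via torch.multiprocessing.spawn)."""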
    args.device_id = i
    if args.distributed_rank is None:  # torch.multiprocessing.spawn
        args.distributed_rank = i
    main(args, init_distributed=True)


def cli_main():
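    """Parse command-line arguments and launch single-GPU, multi-GPU, or distributed training."""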
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_rank = None  # set based on device id
        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
            print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        # single GPU training
        main(args)


if __name__ == '__main__':
    cli_main()