#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a new model on one or across multiple GPUs.
"""

import collections
import math
import os
import random
import shutil

import torch

from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter


def main(args, init_distributed=False):
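    """Set up the task, model and trainer, then train until a stopping condition is met."""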
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.train_subset, combine=True, epoch=0)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=True, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr, max_positions, task)

    # Train until the learning rate gets too small or we hit max_epoch / max_update
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        epoch_itr = reload_train(args, epoch_itr, max_positions, task)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))


def reload_train(args, epoch_itr, max_positions, task):
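    """Reload sharded training data (if any) and rebuild the epoch iterator."""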
    # nothing needs to be done when the dataset is not sharded.
    if "data" not in args or len(args.data.split(":")) == 1:
        return epoch_itr
    print("| Reloading shard of train data at epoch: ", epoch_itr.epoch)
    task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch)
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
        epoch=epoch_itr.epoch,
    )
    return epoch_itr


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
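    # each element of the grouped iterator is a list of update_freq batches, so
    # gradients are accumulated over update_freq batches per optimizer step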
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
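        # mid-epoch validation and checkpointing every --save-interval-updates updates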
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()


def get_training_stats(trainer):
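    """Assemble the current training meters into an ordered dict for logging."""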
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('train_loss')
    if trainer.get_meter('train_nll_loss').count > 0:
        nll_loss = trainer.get_meter('train_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = trainer.get_meter('train_loss')
    stats['ppl'] = get_perplexity(nll_loss.avg)
    stats['wps'] = trainer.get_meter('wps')
    stats['ups'] = trainer.get_meter('ups')
    stats['wpb'] = trainer.get_meter('wpb')
    stats['bsz'] = trainer.get_meter('bsz')
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = trainer.get_meter('gnorm')
    stats['clip'] = trainer.get_meter('clip')
    stats['oom'] = trainer.get_meter('oom')
    if trainer.get_meter('loss_scale') is not None:
        stats['loss_scale'] = trainer.get_meter('loss_scale')
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    stats['train_wall'] = trainer.get_meter('train_wall')
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats['loss'].avg)
    return valid_losses


def get_valid_stats(trainer):
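    """Assemble the current validation meters into an ordered dict for logging."""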
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('valid_loss')
    if trainer.get_meter('valid_nll_loss').count > 0:
        nll_loss = trainer.get_meter('valid_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = stats['loss']
    stats['ppl'] = get_perplexity(nll_loss.avg)
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(save_checkpoint, 'best'):
        stats['best_loss'] = min(save_checkpoint.best, stats['loss'].avg)
    return stats


def get_perplexity(loss):
    try:
        return '{:.2f}'.format(math.pow(2, loss))
    except OverflowError:
        return float('inf')


def save_checkpoint(args, trainer, epoch_itr, val_loss):
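    """Write the requested checkpoints (on the master worker only) and prune old ones."""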
    if args.no_save or not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # map each candidate checkpoint filename to the condition under which it is written
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0
    )
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            shutil.copyfile(checkpoints[0], cp)

        write_timer.stop()
        print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
            checkpoints[0], epoch, updates, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)


def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
    """Load a checkpoint and replay dataloader to match."""
    # Only rank 0 should attempt to create the required dir
    if args.distributed_rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)

    if os.path.isabs(args.restore_file):
        checkpoint_path = args.restore_file
    else:
        checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if os.path.isfile(checkpoint_path):
        extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
                                              eval(args.optimizer_overrides))
        if extra_state is not None:
            # replay train iterator to match checkpoint
            epoch_itr_state = extra_state['train_iterator']

            # If the loaded checkpoint is not at epoch 0, reload the train dataset,
            # as it may be sharded.
            if epoch_itr_state['epoch'] != 0:
                epoch_itr = reload_train(args, epoch_itr, max_positions, task)

            epoch_itr.load_state_dict(epoch_itr_state)

            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))

            trainer.lr_step(epoch_itr.epoch)
            trainer.lr_step_update(trainer.get_num_updates())
            if 'best' in extra_state and not args.reset_optimizer:
                save_checkpoint.best = extra_state['best']
        return True
    else:
        print('| no existing checkpoint found {}'.format(checkpoint_path))
    return False


def distributed_main(i, args, start_rank=0):
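    """Entry point for each spawned worker: bind its device and rank, then run main()."""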
    args.device_id = i
    if args.distributed_rank is None:  # torch.multiprocessing.spawn
        args.distributed_rank = start_rank + i
    main(args, init_distributed=True)


def cli_main():
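    """Parse training arguments and dispatch to single-GPU, multi-GPU or multi-node training."""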
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        assert args.distributed_world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_rank = None  # set based on device id
        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
            print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        # single GPU training
        main(args)


if __name__ == '__main__':
    cli_main()