#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a new model on one or across multiple GPUs.
"""

import collections
import itertools
import math
import os
import random

import torch

from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter


def main(args, init_distributed=False):
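    """Set up the task, data, model and trainer, then run the training loop."""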
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(args, task)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            model.max_positions(),
        ),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
            if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()


def get_training_stats(trainer):
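    """Collect the training meters to report (loss, ppl, speed, lr, etc.)."""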
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('train_loss')
    if trainer.get_meter('train_nll_loss').count > 0:
        nll_loss = trainer.get_meter('train_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = trainer.get_meter('train_loss')
    stats['ppl'] = get_perplexity(nll_loss.avg)
    stats['wps'] = trainer.get_meter('wps')
    stats['ups'] = trainer.get_meter('ups')
    stats['wpb'] = trainer.get_meter('wpb')
    stats['bsz'] = trainer.get_meter('bsz')
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = trainer.get_meter('gnorm')
    stats['clip'] = trainer.get_meter('clip')
    stats['oom'] = trainer.get_meter('oom')
    if trainer.get_meter('loss_scale') is not None:
        stats['loss_scale'] = trainer.get_meter('loss_scale')
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    stats['train_wall'] = trainer.get_meter('train_wall')
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats['loss'].avg)
    return valid_losses


def get_valid_stats(trainer):
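    """Collect the validation meters to report (loss, ppl, best loss so far)."""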
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('valid_loss')
    if trainer.get_meter('valid_nll_loss').count > 0:
        nll_loss = trainer.get_meter('valid_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = stats['loss']
    stats['ppl'] = get_perplexity(nll_loss.avg)
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(save_checkpoint, 'best'):
        stats['best_loss'] = min(save_checkpoint.best, stats['loss'].avg)
    return stats


def get_perplexity(loss):
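    """Convert a base-2 cross-entropy loss to perplexity, returning inf on overflow."""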
    try:
        return '{:.2f}'.format(math.pow(2, loss))
    except OverflowError:
        return float('inf')


def save_checkpoint(args, trainer, epoch_itr, val_loss):
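    """Save epoch/update/best/last checkpoints on the master and prune old ones."""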
    if args.no_save or not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
            end_of_epoch and not args.no_epoch_checkpoints and
            epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
            not end_of_epoch and args.save_interval_updates > 0 and
            updates % args.save_interval_updates == 0
    )
    checkpoint_conds['checkpoint_best.pt'] = (
            val_loss is not None and
            (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
    if len(checkpoints) > 0:
        for cp in checkpoints:
            trainer.save_checkpoint(cp, extra_state)

        write_timer.stop()
        print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
            checkpoints[0], epoch, updates, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)


def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay dataloader to match."""

    # Only rank 0 should attempt to create the required dir
    if args.distributed_rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)

    if os.path.isabs(args.restore_file):
        checkpoint_path = args.restore_file
    else:
        checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if os.path.isfile(checkpoint_path):
        extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
                                              eval(args.optimizer_overrides))
        if extra_state is not None:
            # replay train iterator to match checkpoint
            epoch_itr.load_state_dict(extra_state['train_iterator'])

            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))

            trainer.lr_step(epoch_itr.epoch)
            trainer.lr_step_update(trainer.get_num_updates())
            if 'best' in extra_state and not args.reset_optimizer:
                save_checkpoint.best = extra_state['best']
        return True
    else:
        print('| no existing checkpoint found {}'.format(checkpoint_path))
    return False


def load_dataset_splits(args, task):
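    """Load the train split, plus each validation split and any numbered shards
    (e.g. valid, valid1, valid2, ...) until one is not found."""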
    task.load_dataset(args.train_subset, combine=True)
    for split in args.valid_subset.split(','):
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            try:
                task.load_dataset(split_k, combine=False)
            except FileNotFoundError as e:
                if k > 0:
                    break
                raise e


def distributed_main(i, args, start_rank=0):
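    """Per-process entry point for spawned workers: set the device id and rank, then train."""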
    args.device_id = i
    if args.distributed_rank is None:  # torch.multiprocessing.spawn
        args.distributed_rank = start_rank + i
    main(args, init_distributed=True)


def cli_main():
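    """Parse training arguments and dispatch to distributed, multi-GPU, or single-GPU training."""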
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        assert args.distributed_world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_rank = None  # set based on device id
        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
            print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        # single GPU training
        main(args)


if __name__ == '__main__':
    cli_main()