#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a new model on one or across multiple GPUs.
"""

import collections
import itertools
import math
import os
import random
import shutil

import torch

from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter


def main(args, init_distributed=False):
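    """Set up the task, model, criterion and trainer, then run the main training loop."""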
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(args, task)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            model.max_positions(),
        ),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
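    # Group `update_freq` batches per step so the trainer can accumulate
    # gradients over several mini-batches before each parameter update.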
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()


def get_training_stats(trainer):
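    """Collect mid-/end-of-epoch training statistics from the trainer's meters."""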
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('train_loss')
    if trainer.get_meter('train_nll_loss').count > 0:
        nll_loss = trainer.get_meter('train_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = trainer.get_meter('train_loss')
    stats['ppl'] = get_perplexity(nll_loss.avg)
    stats['wps'] = trainer.get_meter('wps')
    stats['ups'] = trainer.get_meter('ups')
    stats['wpb'] = trainer.get_meter('wpb')
    stats['bsz'] = trainer.get_meter('bsz')
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = trainer.get_meter('gnorm')
    stats['clip'] = trainer.get_meter('clip')
    stats['oom'] = trainer.get_meter('oom')
    if trainer.get_meter('loss_scale') is not None:
        stats['loss_scale'] = trainer.get_meter('loss_scale')
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    stats['train_wall'] = trainer.get_meter('train_wall')
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats['loss'].avg)
    return valid_losses


def get_valid_stats(trainer):
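    """Collect validation statistics from the trainer's meters."""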
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('valid_loss')
    if trainer.get_meter('valid_nll_loss').count > 0:
        nll_loss = trainer.get_meter('valid_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = stats['loss']
    stats['ppl'] = get_perplexity(nll_loss.avg)
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(save_checkpoint, 'best'):
        stats['best_loss'] = min(save_checkpoint.best, stats['loss'].avg)
    return stats


def get_perplexity(loss):
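    """Convert a base-2 cross-entropy loss into a perplexity value (inf on overflow)."""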
    try:
        return '{:.2f}'.format(math.pow(2, loss))
    except OverflowError:
        return float('inf')


def save_checkpoint(args, trainer, epoch_itr, val_loss):
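    """Write epoch, update-based, best and last checkpoints as configured, pruning old ones."""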
    if args.no_save or not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
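    # Map each candidate checkpoint filename to the condition under which it
    # should be written on this call; the first eligible file is saved once
    # and the remaining ones are copied from it below.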
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0
    )
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            shutil.copyfile(checkpoints[0], cp)

        write_timer.stop()
        print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
            checkpoints[0], epoch, updates, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)


def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay dataloader to match."""

    # Only rank 0 should attempt to create the required dir
    if args.distributed_rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)

    if os.path.isabs(args.restore_file):
        checkpoint_path = args.restore_file
    else:
        checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if os.path.isfile(checkpoint_path):
        extra_state = trainer.load_checkpoint(
            checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
            eval(args.optimizer_overrides),
        )
        if extra_state is not None:
            # replay train iterator to match checkpoint
            epoch_itr.load_state_dict(extra_state['train_iterator'])

            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))

            trainer.lr_step(epoch_itr.epoch)
            trainer.lr_step_update(trainer.get_num_updates())
            if 'best' in extra_state and not args.reset_optimizer:
                save_checkpoint.best = extra_state['best']
        return True
    else:
        print('| no existing checkpoint found {}'.format(checkpoint_path))
    return False


def load_dataset_splits(args, task):
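    """Load the training split and all validation splits.

    Validation data may be sharded as 'valid', 'valid1', 'valid2', ...;
    loading stops at the first missing shard, while a missing base split
    is re-raised as an error.
    """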
    task.load_dataset(args.train_subset, combine=True)
    for split in args.valid_subset.split(','):
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            try:
                task.load_dataset(split_k, combine=False)
            except FileNotFoundError as e:
                if k > 0:
                    break
                raise e


def distributed_main(i, args, start_rank=0):
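    """Entry point for each spawned worker: bind it to device `i`, fill in its rank if needed, and run main()."""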
    args.device_id = i
    if args.distributed_rank is None:  # torch.multiprocessing.spawn
        args.distributed_rank = start_rank + i
    main(args, init_distributed=True)


def cli_main():
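    """Parse command-line arguments and dispatch to single-GPU, spawned multi-GPU, or distributed training."""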
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        assert args.distributed_world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_rank = None  # set based on device id
        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
            print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        # single GPU training
        main(args)


if __name__ == '__main__':
    cli_main()