#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a new model on one or across multiple GPUs.
"""

import collections
import itertools
import os
import math
import random

import torch

from fairseq import distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import import_user_module


def main(args, init_distributed=False):
    import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Initialize distributed training (after data loading)
    if init_distributed:
        import socket
        args.distributed_rank = distributed_utils.distributed_init(args)
        print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Make a dummy batch to (i) warm up the caching allocator and (ii) serve as
    # a placeholder for DistributedDataParallel when there is an uneven number
    # of batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
    oom_batch = task.dataset('train').get_dummy_batch(1, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""

    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
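    # Each element yielded by the GroupedIterator is a list of up to
    # `update_freq` samples; trainer.train_step() accumulates gradients over
    # that list before performing a single parameter update, so e.g.
    # --update-freq 2 roughly doubles the effective batch size.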
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()


def get_training_stats(trainer):
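    """Collect the current training meters into an ordered dict of formatted stats."""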
    stats = collections.OrderedDict()
    stats['loss'] = '{:.3f}'.format(trainer.get_meter('train_loss').avg)
    if trainer.get_meter('train_nll_loss').count > 0:
        nll_loss = trainer.get_meter('train_nll_loss').avg
        stats['nll_loss'] = '{:.3f}'.format(nll_loss)
    else:
        nll_loss = trainer.get_meter('train_loss').avg
    stats['ppl'] = get_perplexity(nll_loss)
    stats['wps'] = round(trainer.get_meter('wps').avg)
    stats['ups'] = '{:.1f}'.format(trainer.get_meter('ups').avg)
    stats['wpb'] = round(trainer.get_meter('wpb').avg)
    stats['bsz'] = round(trainer.get_meter('bsz').avg)
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
    stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
    stats['oom'] = trainer.get_meter('oom').avg
    if trainer.get_meter('loss_scale') is not None:
        stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    stats['train_wall'] = round(trainer.get_meter('train_wall').sum)
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats)

        valid_losses.append(stats['valid_loss'])
    return valid_losses


def get_valid_stats(trainer):
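    """Collect the current validation meters into an ordered dict of stats."""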
    stats = collections.OrderedDict()
    stats['valid_loss'] = trainer.get_meter('valid_loss').avg
    if trainer.get_meter('valid_nll_loss').count > 0:
        nll_loss = trainer.get_meter('valid_nll_loss').avg
        stats['valid_nll_loss'] = nll_loss
    else:
        nll_loss = trainer.get_meter('valid_loss').avg
    stats['valid_ppl'] = get_perplexity(nll_loss)
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(save_checkpoint, 'best'):
        stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
    return stats


def get_perplexity(loss):
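    """Return perplexity as 2**loss, formatted as a string, or infinity on overflow."""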
    try:
        return '{:.2f}'.format(math.pow(2, loss))
    except OverflowError:
        return float('inf')


def save_checkpoint(args, trainer, epoch_itr, val_loss):
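    """Write epoch/update/best/last checkpoints and prune stale ones (master process only)."""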
    if args.no_save or not distributed_utils.is_master(args):
        return
    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0
    )
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
    if len(checkpoints) > 0:
        for cp in checkpoints:
            trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint\d+\.pt')
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)


def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay dataloader to match."""
    os.makedirs(args.save_dir, exist_ok=True)
    if os.path.isabs(args.restore_file):
        checkpoint_path = args.restore_file
    else:
        checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if os.path.isfile(checkpoint_path):
        extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
                                              eval(args.optimizer_overrides))
        if extra_state is not None:
            # replay train iterator to match checkpoint
            epoch_itr.load_state_dict(extra_state['train_iterator'])

            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))

            trainer.lr_step(epoch_itr.epoch)
            trainer.lr_step_update(trainer.get_num_updates())
            if 'best' in extra_state:
                save_checkpoint.best = extra_state['best']
        return True
    else:
        print('| no existing checkpoint found {}'.format(checkpoint_path))
    return False


def load_dataset_splits(task, splits):
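    """Load the requested splits; 'train' combines all shards into one dataset, while
    other splits are loaded as e.g. valid, valid1, valid2, ... until one is missing.
    """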
    for split in splits:
        if split == 'train':
            task.load_dataset(split, combine=True)
        else:
            for k in itertools.count():
                split_k = split + (str(k) if k > 0 else '')
                try:
                    task.load_dataset(split_k, combine=False)
                except FileNotFoundError as e:
                    if k > 0:
                        break
                    raise e


def distributed_main(i, args):
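    """Entry point for a single distributed worker: bind it to device `i` and run main()."""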
    args.device_id = i
    if args.distributed_rank is None:  # torch.multiprocessing.spawn
        args.distributed_rank = i
    main(args, init_distributed=True)


def cli_main():
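    """Parse command-line arguments and launch distributed, multi-process, or single-GPU training."""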
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_rank = None  # set based on device id
        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
            print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
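        # torch.multiprocessing.spawn calls distributed_main(i, args) once per
        # process, passing the process index i, which becomes the device id
        # (and, since distributed_rank is None here, also the rank).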
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        # single GPU training
        main(args)


if __name__ == '__main__':
    cli_main()