#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a new model on one or across multiple GPUs.
"""

import collections
import math
import os
import random

import torch

from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter


def main(args, init_distributed=False):
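    # Load any --user-dir extension modules (custom tasks, models, criteria)
    # so they are registered before the task and model are built.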
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
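    # (combine=True also picks up any extra split files, e.g. train1, train2, ...)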
    task.load_dataset(args.train_subset, combine=True, epoch=0)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=True, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

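        # args.data may hold several colon-separated data paths; each epoch
        # then reads a different shard and the iterator must be rebuilt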
        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
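    # --update-freq may give one value per epoch (the last value is reused once
    # the list is exhausted); each update then accumulates gradients over that
    # many batches before stepping the optimizer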
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
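    # (shuffling is disabled for the first --curriculum epochs;
    # --fix-batches-to-gpus keeps each batch on the same GPU across epochs)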
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
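    # group update_freq consecutive batches so that each train_step() call
    # performs one optimizer update over the whole group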
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
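    # start counting at iterations_in_epoch so that a run resumed mid-epoch
    # does not re-trigger the i == 0 bookkeeping below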
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
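            # the step was skipped, e.g. a dummy batch used to recover
            # from an out-of-memory condition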
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()


def get_training_stats(trainer):
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('train_loss')
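    # nll_loss is only reported by some criteria (e.g. label-smoothed
    # cross-entropy); otherwise perplexity falls back to the training loss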
    if trainer.get_meter('train_nll_loss').count > 0:
        nll_loss = trainer.get_meter('train_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = trainer.get_meter('train_loss')
    stats['ppl'] = utils.get_perplexity(nll_loss.avg)
    stats['wps'] = trainer.get_meter('wps')
    stats['ups'] = trainer.get_meter('ups')
    stats['wpb'] = trainer.get_meter('wpb')
    stats['bsz'] = trainer.get_meter('bsz')
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = trainer.get_meter('gnorm')
    stats['clip'] = trainer.get_meter('clip')
    stats['oom'] = trainer.get_meter('oom')
    if trainer.get_meter('loss_scale') is not None:
        stats['loss_scale'] = trainer.get_meter('loss_scale')
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    stats['train_wall'] = trainer.get_meter('train_wall')
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
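            # respect the stricter of the task's and the model's length limits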
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats['loss'].avg)
    return valid_losses


def get_valid_stats(trainer):
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('valid_loss')
    if trainer.get_meter('valid_nll_loss').count > 0:
        nll_loss = trainer.get_meter('valid_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = stats['loss']
    stats['ppl'] = utils.get_perplexity(nll_loss.avg)
    stats['num_updates'] = trainer.get_num_updates()
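    # save_checkpoint() tracks the best validation loss seen so far as a
    # function attribute; report it once at least one checkpoint was saved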
    if hasattr(checkpoint_utils.save_checkpoint, 'best'):
        stats['best_loss'] = min(
            checkpoint_utils.save_checkpoint.best, stats['loss'].avg)
    return stats


def distributed_main(i, args, start_rank=0):
    args.device_id = i
    if args.distributed_rank is None:  # torch.multiprocessing.spawn
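        # when launched via torch.multiprocessing.spawn, the process index i
        # determines this worker's rank relative to the node's first rank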
        args.distributed_rank = start_rank + i
    main(args, init_distributed=True)


def cli_main():
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
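        # spawn one process per visible GPU unless the caller manages
        # worker processes itself (--distributed-no-spawn)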
        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        assert args.distributed_world_size <= torch.cuda.device_count()
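        # rendezvous over a random localhost port; a collision with another
        # job on the same host is unlikely but possible, so relaunch if init fails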
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_rank = None  # set based on device id
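        # the no_c10d backend all-reduces gradients once per update instead of
        # once per backward pass, which is usually faster with --update-freq > 1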
        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
            print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        # single GPU training
        main(args)


if __name__ == '__main__':
    cli_main()