"vscode:/vscode.git/clone" did not exist on "1524122532927dfd8ff80b0899344e696a7ab47a"
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a new model on one or across multiple GPUs.
"""

import collections
import math
import os
import random

import torch

from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter


def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'
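    # --max-tokens caps the number of tokens per batch and --max-sentences caps
    # the number of sentences per batch; if both are given, each batch must
    # satisfy both limits.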

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.train_subset, combine=True, epoch=0)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=True, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )
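    # The epoch iterator shards batches across workers: with num_shards equal to
    # the world size, each process (shard_id == its rank) iterates over a
    # disjoint subset of the batches, so every GPU sees different data.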

    # Load the latest checkpoint if one is available
    checkpoint_utils.load_checkpoint(
        args, trainer, epoch_itr, max_positions, task)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
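    # Main training loop: run until the learning rate decays below --min-lr or
    # --max-epoch / --max-update is reached, whichever happens first.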
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(
                args, trainer, epoch_itr, valid_losses[0])

        epoch_itr = checkpoint_utils.reload_train(args, epoch_itr, max_positions, task)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
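    # e.g. if args.update_freq == [4, 2], gradients are accumulated over 4
    # batches in epoch 1 and over 2 batches from epoch 2 onwards.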

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
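    # GroupedIterator yields lists of update_freq batches, so each call to
    # trainer.train_step() below accumulates gradients over that many batches
    # before performing a single optimizer step.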
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()


def get_training_stats(trainer):
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('train_loss')
    if trainer.get_meter('train_nll_loss').count > 0:
        nll_loss = trainer.get_meter('train_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = trainer.get_meter('train_loss')
    stats['ppl'] = utils.get_perplexity(nll_loss.avg)
    stats['wps'] = trainer.get_meter('wps')
    stats['ups'] = trainer.get_meter('ups')
    stats['wpb'] = trainer.get_meter('wpb')
    stats['bsz'] = trainer.get_meter('bsz')
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = trainer.get_meter('gnorm')
    stats['clip'] = trainer.get_meter('clip')
    stats['oom'] = trainer.get_meter('oom')
    if trainer.get_meter('loss_scale') is not None:
        stats['loss_scale'] = trainer.get_meter('loss_scale')
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    stats['train_wall'] = trainer.get_meter('train_wall')
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats['loss'].avg)
    return valid_losses


def get_valid_stats(trainer):
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('valid_loss')
    if trainer.get_meter('valid_nll_loss').count > 0:
        nll_loss = trainer.get_meter('valid_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        nll_loss = stats['loss']
    stats['ppl'] = utils.get_perplexity(nll_loss.avg)
    stats['num_updates'] = trainer.get_num_updates()
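    # checkpoint_utils.save_checkpoint tracks the best validation loss seen so far
    # as an attribute on the function; the hasattr check below covers the first
    # validation pass, before any checkpoint has been saved.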
    if hasattr(checkpoint_utils.save_checkpoint, 'best'):
        stats['best_loss'] = min(
            checkpoint_utils.save_checkpoint.best, stats['loss'].avg)
    return stats


def distributed_main(i, args, start_rank=0):
    args.device_id = i
    if args.distributed_rank is None:  # torch.multiprocessing.spawn
        args.distributed_rank = start_rank + i
    main(args, init_distributed=True)


def cli_main():
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        assert args.distributed_world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_rank = None  # set based on device id
        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
            print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
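        # Spawn one worker process per GPU; torch.multiprocessing.spawn calls
        # distributed_main(i, args) with i = 0 .. nprocs-1, which becomes each
        # worker's device id (and, with start_rank=0, its distributed rank).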
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        # single GPU training
        main(args)


if __name__ == '__main__':
    cli_main()