train.py 13.6 KB
Newer Older
Myle Ott's avatar
Myle Ott committed
1
#!/usr/bin/env python3 -u
Sergey Edunov's avatar
Sergey Edunov committed
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
Myle Ott's avatar
Myle Ott committed
8
9
10
"""
Train a new model on a single GPU or across multiple GPUs.
"""
Sergey Edunov's avatar
Sergey Edunov committed
11

12
import collections
Myle Ott's avatar
Myle Ott committed
13
import itertools
14
15
import os
import math
16
17
import random

18
import torch
Sergey Edunov's avatar
Sergey Edunov committed
19

20
from fairseq import distributed_utils, options, progress_bar, tasks, utils
21
from fairseq.data import iterators
22
23
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
Sergey Edunov's avatar
Sergey Edunov committed
24

Myle Ott's avatar
Myle Ott committed
25

Myle Ott's avatar
Myle Ott committed
26
def main(args):
    """Train a model end to end according to the parsed command-line ``args``.

    Sets up the task, model, criterion and :class:`Trainer`, resumes from a
    checkpoint when one is found, then repeats (train one epoch, optionally
    validate, step the LR scheduler, optionally checkpoint) until the learning
    rate drops to ``args.min_lr`` or the epoch/update limits are reached.

    Raises:
        NotImplementedError: if CUDA is not available (CPU training is unsupported).
    """
    # Default per-batch token budget when none was supplied.
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Task (e.g. translation, language modeling) plus its train/valid splits.
    task = tasks.setup_task(args)
    load_dataset_splits(task, ['train', 'valid'])

    model = task.build_model(args)
    criterion = task.build_criterion(args)
    n_params = sum(p.numel() for p in model.parameters())
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {}'.format(n_params))

    # A full-size dummy batch warms the caching allocator and serves as a
    # placeholder for DistributedDataParallel when workers get an uneven
    # number of batches; a one-sentence batch is handed to the Trainer as
    # ``oom_batch`` (presumably for out-of-memory recovery -- see Trainer).
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    train_data = task.dataset('train')
    dummy_batch = train_data.get_dummy_batch(args.max_tokens, max_positions)
    oom_batch = train_data.get_dummy_batch(1, max_positions)

    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Batch iterator over the training subset, sharded by distributed rank.
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Resume if possible; otherwise run one step on the dummy batch.
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])

    # A falsy (0/None) limit means "unbounded".
    epoch_limit = args.max_epoch or math.inf
    update_limit = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    while (lr > args.min_lr
           and epoch_itr.epoch < epoch_limit
           and trainer.get_num_updates() < update_limit):
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # Only the first validation subset's loss drives the LR scheduler and
        # best-checkpoint selection.
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))


Myle Ott's avatar
Myle Ott committed
108
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    ``args.update_freq`` is a per-epoch schedule (1-based) of how many batches
    are grouped into each ``trainer.train_step`` call; epochs beyond the end of
    the schedule reuse its last entry. After every step the running stats are
    logged; every ``args.save_interval_updates`` updates the first validation
    subset is evaluated and a checkpoint is saved; the epoch stops early once
    ``args.max_update`` is reached. Finally end-of-epoch stats are printed and
    the per-epoch training meters are reset.
    """
    # Clamp the 1-based epoch to the schedule length (past the end -> last entry).
    schedule = args.update_freq
    update_freq = schedule[min(epoch_itr.epoch, len(schedule)) - 1]

    # next_epoch_itr() advances epoch_itr.epoch, so the progress bar below
    # intentionally reads the updated value.
    batches = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
    grouped = iterators.GroupedIterator(batches, update_freq)
    progress = progress_bar.build_progress_bar(
        args, grouped, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(AverageMeter)
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    # Per-step keys that are not accumulated into extra_meters.
    skip_keys = ('loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size')

    for step, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # Mid-epoch stats: trainer meters plus running averages of extra keys.
        stats = get_training_stats(trainer)
        for key, value in log_output.items():
            if key in skip_keys:
                continue
            meter = extra_meters[key]
            # Loss-like values are weighted by sample size; others are plain means.
            if 'loss' in key:
                meter.update(value, log_output['sample_size'])
            else:
                meter.update(value)
            stats[key] = meter.avg
        progress.log(stats)

        # Drop the first mini-batch from the words-per-second average.
        if step == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        interval = args.save_interval_updates
        if interval > 0 and num_updates > 0 and num_updates % interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # End-of-epoch summary.
    stats = get_training_stats(trainer)
    stats.update((key, meter.avg) for key, meter in extra_meters.items())
    progress.print(stats)

    # Reset the per-epoch training meters (any the trainer lacks are skipped).
    for name in ('train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip'):
        meter = trainer.get_meter(name)
        if meter is not None:
            meter.reset()

170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191

def get_training_stats(trainer):
    """Collect the trainer's running training metrics, in display order, for logging.

    Most values are pre-formatted strings or rounded numbers. ``nll_loss`` is
    included only when the trainer has recorded an NLL loss; otherwise
    perplexity is computed from ``train_loss``. ``loss_scale`` is included only
    when the trainer has such a meter.
    """
    meter_of = trainer.get_meter
    stats = collections.OrderedDict()

    stats['loss'] = '{:.3f}'.format(meter_of('train_loss').avg)
    nll_meter = meter_of('train_nll_loss')
    if nll_meter.count > 0:
        stats['nll_loss'] = '{:.3f}'.format(nll_meter.avg)
        ppl_source = nll_meter.avg
    else:
        ppl_source = meter_of('train_loss').avg
    stats['ppl'] = get_perplexity(ppl_source)

    stats['wps'] = round(meter_of('wps').avg)
    stats['ups'] = '{:.1f}'.format(meter_of('ups').avg)
    stats['wpb'] = round(meter_of('wpb').avg)
    stats['bsz'] = round(meter_of('bsz').avg)
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = '{:.3f}'.format(meter_of('gnorm').avg)
    stats['clip'] = '{:.0%}'.format(meter_of('clip').avg)
    stats['oom'] = meter_of('oom').avg

    scale_meter = meter_of('loss_scale')
    if scale_meter is not None:
        stats['loss_scale'] = '{:.3f}'.format(scale_meter.avg)

    stats['wall'] = round(meter_of('wall').elapsed_time)
    stats['train_wall'] = round(meter_of('train_wall').sum)
    return stats


Myle Ott's avatar
Myle Ott committed
196
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on each validation subset and return their losses.

    Args:
        subsets: names of the dataset splits to evaluate, in order.

    Returns:
        list: the ``valid_loss`` reported by ``get_valid_stats`` for each
        subset, in the same order as ``subsets``.
    """
    # Per-step keys that are not accumulated into the extra meters.
    skip_keys = ('loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size')
    losses = []

    for subset in subsets:
        # Deterministic (unshuffled) pass over this subset, sharded by rank.
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        )
        itr = batch_iterator.next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # Each subset starts from clean validation loss meters.
        for name in ('valid_loss', 'valid_nll_loss'):
            meter = trainer.get_meter(name)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(AverageMeter)

        for sample in progress:
            log_output = trainer.valid_step(sample)
            for key, value in log_output.items():
                if key not in skip_keys:
                    extra_meters[key].update(value)

        stats = get_valid_stats(trainer)
        stats.update((key, meter.avg) for key, meter in extra_meters.items())
        progress.print(stats)

        losses.append(stats['valid_loss'])
    return losses
244
245
246
247
248
249
250
251


def get_valid_stats(trainer):
    """Collect the trainer's validation metrics, in display order, for logging.

    ``valid_nll_loss`` is included only when an NLL loss was recorded;
    otherwise perplexity is computed from ``valid_loss``. Once
    ``save_checkpoint`` has recorded a best loss, ``best`` reports the lower
    of that value and the current validation loss.
    """
    loss_avg = trainer.get_meter('valid_loss').avg
    nll_meter = trainer.get_meter('valid_nll_loss')
    has_nll = nll_meter.count > 0

    stats = collections.OrderedDict()
    stats['valid_loss'] = loss_avg
    if has_nll:
        stats['valid_nll_loss'] = nll_meter.avg
    stats['valid_ppl'] = get_perplexity(nll_meter.avg if has_nll else loss_avg)
    stats['num_updates'] = trainer.get_num_updates()

    # The running best loss lives on the save_checkpoint function object.
    if hasattr(save_checkpoint, 'best'):
        stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
    return stats


def get_perplexity(loss):
    """Return ``2 ** loss`` formatted with two decimals, for logging.

    The base-2 exponent presumes ``loss`` is a base-2 log loss (confirm
    against the criterion). If the power overflows, the float ``inf`` is
    returned instead of a string.
    """
    try:
        ppl = math.pow(2, loss)
    except OverflowError:
        return float('inf')
    return '{:.2f}'.format(ppl)


Myle Ott's avatar
Myle Ott committed
268
269
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Write whichever checkpoint files are due at this point of training.

    Nothing is written with ``args.no_save`` or on non-master workers.
    Otherwise, the files below (all under ``args.save_dir``) are written when
    their condition holds:

    * ``checkpoint{epoch}.pt``: at the end of an epoch, every
      ``args.save_interval`` epochs, unless ``args.no_epoch_checkpoints``;
    * ``checkpoint_{epoch}_{updates}.pt``: mid-epoch, every
      ``args.save_interval_updates`` updates;
    * ``checkpoint_best.pt``: when ``val_loss`` beats the best loss so far;
    * ``checkpoint_last.pt``: always.

    The trainer state is serialized once and the file is copied to the other
    names. The best validation loss is tracked on the function attribute
    ``save_checkpoint.best`` and stored in the checkpoint's extra state.
    Mid-epoch, update checkpoints beyond the newest
    ``args.keep_interval_updates`` are deleted.

    Args:
        val_loss: validation loss for this checkpoint, or ``None`` when no
            validation result is available.
    """
    import shutil  # only needed here, to duplicate the serialized checkpoint

    if args.no_save or not distributed_utils.is_master(args):
        return

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0
    )
    # Must compare against the best loss *before* it is updated below.
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # always written

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state['best'] = save_checkpoint.best

    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
    if checkpoints:
        # Serialize the (potentially large) trainer state only once; the
        # remaining files are byte copies of it rather than re-serializations.
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            shutil.copyfile(checkpoints[0], cp)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in checkpoints[args.keep_interval_updates:]:
            os.remove(old_chk)
310
311


Myle Ott's avatar
Myle Ott committed
312
313
def load_checkpoint(args, trainer, epoch_itr):
    """Load ``args.restore_file`` (relative to ``args.save_dir``) if it exists.

    Creates ``args.save_dir`` when missing. If the checkpoint carries extra
    state, the data iterator is replayed to the saved position, the LR
    scheduler is stepped to the restored epoch and update count, and any
    stored best loss is restored onto ``save_checkpoint.best``.

    Returns:
        bool: ``True`` if a checkpoint file was found and loaded, else ``False``.
    """
    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if not os.path.isfile(checkpoint_path):
        return False

    # NOTE(review): eval() runs arbitrary Python taken from the command line.
    # ast.literal_eval would cover dict literals but would reject non-literal
    # expressions that are currently accepted, so behavior is kept as-is.
    optimizer_overrides = eval(args.optimizer_overrides)
    extra_state = trainer.load_checkpoint(
        checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler, optimizer_overrides,
    )
    if extra_state is None:
        return True

    # Replay the train iterator so training resumes where the checkpoint left off.
    epoch_itr.load_state_dict(extra_state['train_iterator'])
    print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
        checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))

    trainer.lr_step(epoch_itr.epoch)
    trainer.lr_step_update(trainer.get_num_updates())
    if 'best' in extra_state:
        save_checkpoint.best = extra_state['best']
    return True
332

333

Alexei Baevski's avatar
Alexei Baevski committed
334
def load_dataset_splits(task, splits):
    """Load each named split into ``task``.

    ``'train'`` is loaded once with ``combine=True``. Every other split is
    loaded with ``combine=False`` as a numbered series ``split``, ``split1``,
    ``split2``, ... stopping at the first missing one. The unnumbered base
    split must exist.

    Raises:
        FileNotFoundError: if the base (unnumbered) file of a non-train split
            is missing; errors from loading ``'train'`` propagate unchanged.
    """
    for split in splits:
        if split == 'train':
            task.load_dataset(split, combine=True)
            continue

        for shard in itertools.count():
            name = split + str(shard) if shard else split
            try:
                task.load_dataset(name, combine=False)
            except FileNotFoundError:
                # A missing numbered shard ends the series; a missing base is an error.
                if not shard:
                    raise
                break
Sergey Edunov's avatar
Sergey Edunov committed
347

Myle Ott's avatar
Myle Ott committed
348

Sergey Edunov's avatar
Sergey Edunov committed
349
if __name__ == '__main__':
    args = options.parse_args_and_arch(options.get_training_parser())

    if args.distributed_port > 0 or args.distributed_init_method is not None:
        # Explicit distributed port / init method given: hand off to distributed_train.
        from distributed_train import main as distributed_main

        distributed_main(args)
    elif args.distributed_world_size > 1:
        # Single node with several GPUs: configure a localhost TCP rendezvous
        # on a random port and hand off to multiprocessing_train.
        from multiprocessing_train import main as multiprocessing_main

        # NOTE(review): this replaces any user-requested world size with the
        # number of visible GPUs -- confirm that is intended.
        args.distributed_world_size = torch.cuda.device_count()
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_port = port + 1

        multiprocessing_main(args)
    else:
        # Single-process training.
        main(args)