#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
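"""Train a translation model on one or more GPUs.

Example invocation (the data path below is illustrative; see main() for the
full set of options):

    python train.py data-bin/my-corpus --arch fconv --max-tokens 6000 \
        --save-dir checkpoints
"""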

import collections
import os
import torch
import math

from fairseq import data, options, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.multiprocessing_trainer import MultiprocessingTrainer
from fairseq.progress_bar import progress_bar


def main():
    parser = options.get_parser('Trainer')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                              help='maximum number of tokens in a batch')
    dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                              choices=['train', 'valid', 'test'],
                              help='data subset to use for training (train, valid, test)')
    dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                              help='comma separated list of data subsets '
                                   'to use for validation (train, valid, valid1, test, test1)')
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_model_args(parser)

    args = utils.parse_args_and_arch(parser)
    print(args)

    if args.no_progress_bar:
        progress_bar.enabled = False
        progress_bar.print_interval = args.log_interval

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    torch.manual_seed(args.seed)

    # Load dataset
    dataset = data.load_with_check(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in dataset.splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    num_gpus = torch.cuda.device_count()

    print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))

    # Build model
    print('| model {}'.format(args.arch))
    model = utils.build_model(args, dataset)
    criterion = utils.build_criterion(args, dataset)

    # Start multiprocessing
    trainer = MultiprocessingTrainer(args, model, criterion)

    # Load the latest checkpoint if one is available and resume training
    # from the saved epoch and batch offset
    epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))

    # Train until the learning rate gets too small
    val_loss = None
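    # a falsy --max-epoch disables the epoch limit; training then runs until
    # the learning rate decays below --min-lr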
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, epoch, batch_offset, trainer, dataset, num_gpus)

        # evaluate on the validation subset(s)
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
                    trainer.save_checkpoint(args, epoch, 0, val_loss)
                # only use the first validation loss to update the learning rate schedule
                lr = trainer.lr_step(val_loss, epoch)

        epoch += 1
        batch_offset = 0
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

    # Stop multiprocessing
    trainer.stop()


def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
                             max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
                             max_positions=args.max_positions,
                             sample_without_replacement=args.sample_without_replacement,
                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
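        # skip_group_enumerator yields one group of sub-batches per step (one
        # sub-batch per GPU), skipping the first batch_offset groups so that
        # training can resume mid-epoch after loading a checkpoint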
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

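            # each sample is a list with one sub-batch per GPU; aggregate
            # token and sentence counts across the group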
            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
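            # count an update as clipped when the gradient norm reported by
            # the trainer exceeded --clip-norm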
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
            ] + extra_postfix), refresh=False)

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

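        # end-of-epoch summary; ppl is computed as 2 ** loss, i.e. the
        # training loss appears to be a base-2 (bits-per-token) quantity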
        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
            loss_meter.avg, math.pow(2, loss_meter.avg))
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
            round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
            round(bsz_meter.avg), lr, clip_meter.avg * 100)
        fmt += ''.join(
            ' | {} {:.4f}'.format(k, meter.avg)
            for k, meter in extra_meters.items()
        )
        t.write(fmt)


def validate(args, epoch, trainer, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

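    # batch by token count (max_tokens) rather than a fixed sentence count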
    itr = dataset.dataloader(subset, batch_size=None,
                             max_tokens=args.max_tokens,
                             max_positions=args.max_positions,
                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            loss_dict = trainer.valid_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f}'.format(loss_meter.avg)),
            ] + extra_postfix), refresh=False)

        val_loss = loss_meter.avg
        fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
            val_loss, math.pow(2, val_loss))
        fmt += ''.join(
            ' | {} {:.4f}'.format(k, meter.avg)
            for k, meter in extra_meters.items()
        )
        t.write(fmt)

    # return the average validation loss (the caller updates the learning rate)
    return val_loss


if __name__ == '__main__':
    main()