"vscode:/vscode.git/clone" did not exist on "08d18a47f245e6befac29df9dc4346ccb7177c2f"
singleprocess_train.py 11 KB
Newer Older
Myle Ott's avatar
Myle Ott committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import collections
import itertools
import os
import math
import torch

from fairseq import criterions, data, models, options, progress_bar
Myle Ott's avatar
Myle Ott committed
16
from fairseq.fp16_trainer import FP16Trainer
Myle Ott's avatar
Myle Ott committed
17
from fairseq.trainer import Trainer
Myle Ott's avatar
Myle Ott committed
18
from fairseq.meters import AverageMeter, StopwatchMeter
Myle Ott's avatar
Myle Ott committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


def main(args):
    """Train a model in a single process until a stopping criterion is met.

    Loads the dataset (binarized files if present, raw text otherwise), builds
    the model, criterion and trainer (FP16 when ``args.fp16``), optionally
    resumes from ``os.path.join(args.save_dir, args.restore_file)``, and then
    alternates train / validate / checkpoint epochs until the learning rate
    drops to ``args.min_lr``, ``args.max_epoch`` epochs have run, or
    ``args.max_update`` updates have been performed.

    Args:
        args: parsed command-line options from ``options.parse_args_and_arch``.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(
            args.data, splits, args.source_lang, args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(
            args.data, splits, args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    # Build model and criterion
    model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
    criterion = criterions.build_criterion(args, dataset.src_dict, dataset.dst_dict)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, model, criterion)
    else:
        trainer = Trainer(args, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader; each next() on it yields the batches for one
    # epoch (see the training loop below)
    train_dataloader = dataset.train_dataloader_generator(
        args.train_subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=(
            min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
            min(args.max_target_positions, trainer.get_model().max_decoder_positions())
        ),
        seed=args.seed,
        sample_without_replacement=args.sample_without_replacement,
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )

    # Load the latest checkpoint if one is available
    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    epoch = 1
    if os.path.isfile(checkpoint_path):
        extra_state = trainer.load_checkpoint(checkpoint_path)
        if extra_state is not None:
            epoch = extra_state['epoch']
            print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
            trainer.lr_step(epoch)
            # advance the per-epoch dataloader past the epochs already
            # completed before the checkpoint was written
            for _ in range(epoch):
                next(train_dataloader)
            epoch += 1

    # Send a dummy batch to warm the caching allocator
    dummy_batch = data.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small, max_epoch epochs have
    # run, or max_update updates have been performed
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, trainer, next(train_dataloader), epoch)

        # evaluate on validate set; stays None on epochs where validation is
        # skipped because of --validate-interval
        first_val_loss = None
        if epoch % args.validate_interval == 0:
            for k, subset in enumerate(args.valid_subset.split(',')):
                val_loss = validate(args, trainer, dataset, subset, epoch)
                if k == 0:
                    first_val_loss = val_loss

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch, first_val_loss)

        # save checkpoint
        if not args.no_save and epoch % args.save_interval == 0:
            save_checkpoint(trainer, args, epoch, first_val_loss)

        epoch += 1

        if trainer.get_num_updates() >= max_update:
            break
    train_meter.stop()

    print('| done training in {:.1f} seconds'.format(train_meter.sum))


129
def train(args, trainer, itr, epoch):
    """Train the model for one epoch.

    Batches are accumulated according to ``--update-freq``: parameters are
    updated every ``update_freq`` batches and always on the last batch of the
    epoch. Training stops early once the trainer reaches ``args.max_update``
    updates.

    Args:
        args: parsed command-line options (reads ``seed``, ``update_freq``
            and ``max_update``; also passed to the progress bar).
        trainer: Trainer/FP16Trainer that runs ``train_step``.
        itr: sized iterable of training batches for this epoch.
        epoch: 1-based epoch number; used for seeding, for choosing the
            per-epoch ``update_freq`` and passed to the progress bar.
    """
    # Set seed based on args.seed and the epoch number so that we get
    # reproducible results when resuming from checkpoints
    seed = args.seed + epoch
    torch.manual_seed(seed)

    # reset training meters so the reported stats cover this epoch only;
    # 'gnorm' is reset too, otherwise it would be averaged across all epochs
    # while every other reported stat is per-epoch
    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    # update parameters every N batches; args.update_freq holds one value per
    # epoch and its last entry is reused for all later epochs
    if epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    extra_meters = collections.defaultdict(AverageMeter)
    max_update = args.max_update or math.inf
    num_batches = len(itr)
    progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
    first_update = True
    for i, sample in enumerate(progress):
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample, update_params=False)
            continue
        log_output = trainer.train_step(sample, update_params=True)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # presumably already tracked by the trainer's train_* meters
            if 'loss' in k:
                # weight extra losses by the number of samples they cover
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first update in words-per-second calculation. This is
        # keyed on the first update rather than on i == 0: with
        # update_freq > 1, batch 0 is only buffered (it hits `continue`
        # above), so an i == 0 check would never fire.
        if first_update:
            trainer.get_meter('wps').reset()
            first_update = False

        if trainer.get_num_updates() >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)


def get_training_stats(trainer):
    """Collect the training statistics reported during and after each epoch.

    All values come from the trainer's meters, and most are pre-formatted as
    strings for display. ``nll_loss`` is included only when the
    ``train_nll_loss`` meter has been updated at least once. Otherwise
    perplexity is computed from the plain training loss. ``loss_scale`` is
    included only when the trainer exposes such a meter (presumably FP16
    training; confirm against the trainer classes).

    Returns:
        collections.OrderedDict mapping stat name to display value.
    """
    stats = collections.OrderedDict()
    stats['loss'] = '{:.3f}'.format(trainer.get_meter('train_loss').avg)
    if trainer.get_meter('train_nll_loss').count > 0:
        nll_loss = trainer.get_meter('train_nll_loss').avg
        stats['nll_loss'] = '{:.3f}'.format(nll_loss)
    else:
        nll_loss = trainer.get_meter('train_loss').avg
    stats['ppl'] = get_perplexity(nll_loss)
    stats['wps'] = round(trainer.get_meter('wps').avg)
    stats['ups'] = '{:.1f}'.format(trainer.get_meter('ups').avg)
    stats['wpb'] = round(trainer.get_meter('wpb').avg)
    stats['bsz'] = round(trainer.get_meter('bsz').avg)
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
    # fraction of updates whose gradient was clipped, shown as a percentage
    stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
    stats['oom'] = trainer.get_meter('oom').avg
    if trainer.get_meter('loss_scale') is not None:
        stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    return stats


def validate(args, trainer, dataset, subset, epoch):
    """Evaluate the model on one validation subset and return the average loss.

    Args:
        args: parsed command-line options (reads ``max_tokens``,
            ``max_sentences_valid``, ``skip_invalid_size_inputs_valid_test``,
            ``distributed_rank`` and ``distributed_world_size``; also passed
            to the progress bar).
        trainer: Trainer/FP16Trainer used to run ``valid_step``.
        dataset: dataset object providing ``eval_dataloader``.
        subset: name of the split to evaluate (one entry of ``--valid-subset``).
        epoch: current epoch number, passed to the progress bar.

    Returns:
        The average of the trainer's ``valid_loss`` meter over this subset.
    """
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args, itr, epoch,
        prefix='valid on \'{}\' subset'.format(subset),
        no_progress_bar='simple'
    )

    # reset validation loss meters so the result reflects this subset only
    for k in ['valid_loss', 'valid_nll_loss']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(AverageMeter)
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    return stats['valid_loss']


def get_valid_stats(trainer):
    """Collect the validation statistics reported by ``validate``.

    The result always holds ``valid_loss`` and ``valid_ppl``. It also holds
    ``valid_nll_loss`` when the trainer's ``valid_nll_loss`` meter has been
    updated at least once. Perplexity is computed from that NLL loss when it
    is available, and from the plain validation loss otherwise.

    Returns:
        collections.OrderedDict of raw (unformatted) loss values plus the
        perplexity string.
    """
    loss_avg = trainer.get_meter('valid_loss').avg
    result = collections.OrderedDict([('valid_loss', loss_avg)])

    nll_meter = trainer.get_meter('valid_nll_loss')
    if nll_meter.count > 0:
        ppl_basis = nll_meter.avg
        result['valid_nll_loss'] = ppl_basis
    else:
        ppl_basis = trainer.get_meter('valid_loss').avg

    result['valid_ppl'] = get_perplexity(ppl_basis)
    return result


def get_perplexity(loss):
    """Return ``2 ** loss`` formatted with two decimals.

    ``loss`` is exponentiated in base 2. If the result overflows a float,
    ``'inf'`` is returned. The return type is therefore always ``str``,
    whereas previously a bare ``float('inf')`` was returned in that case.
    """
    try:
        return '{:.2f}'.format(math.pow(2, loss))
    except OverflowError:
        # keep the same (string) type as the normal path
        return 'inf'


282
def save_checkpoint(trainer, args, epoch, val_loss=None):
    """Write the checkpoints for ``epoch`` into ``args.save_dir``.

    The following files are written:

    * ``checkpoint{epoch}.pt``, unless ``args.no_epoch_checkpoints`` is set;
    * ``checkpoint_best.pt``, when ``val_loss`` is given and is lower than
      the best loss seen so far in this process (tracked in the
      ``save_checkpoint.best`` function attribute);
    * ``checkpoint_last.pt``, always.

    Each file gets the same ``extra_state`` dict holding ``epoch`` and
    ``val_loss``.

    Args:
        trainer: object whose ``save_checkpoint(filename, extra_state)``
            writes a checkpoint.
        args: options providing ``save_dir`` and ``no_epoch_checkpoints``.
        epoch: epoch number being saved.
        val_loss: validation loss for this epoch, or None when validation
            was skipped (``main`` passes None on epochs not selected by
            ``--validate-interval``).
    """
    extra_state = {
        'epoch': epoch,
        'val_loss': val_loss,
    }

    if not args.no_epoch_checkpoints:
        epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
        trainer.save_checkpoint(epoch_filename, extra_state)

    # An epoch without a validation loss cannot be ranked, so it leaves
    # checkpoint_best.pt alone instead of failing an assertion and aborting
    # training.
    is_best = val_loss is not None and (
        not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best
    )
    if is_best:
        save_checkpoint.best = val_loss
        best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
        trainer.save_checkpoint(best_filename, extra_state)

    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
    trainer.save_checkpoint(last_filename, extra_state)


if __name__ == '__main__':
    # Build the training CLI, parse it (including architecture-specific
    # options) and hand the result straight to main().
    main(options.parse_args_and_arch(options.get_training_parser()))