# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Train a network across multiple GPUs.
"""

from collections import OrderedDict
from itertools import chain
import math
import os
import sys

import torch

from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler


class Trainer(object):
    """Main class for data parallel training.

    This class supports synchronous distributed data parallel training,
    where multiple workers each have a full model replica and gradients
    are accumulated across workers before each update. We use
    :class:`~torch.nn.parallel.DistributedDataParallel` to handle
    communication of the gradients across workers.
    """

    def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
        self.args = args
        self.task = task

        # copy model and criterion to current device
        self.criterion = criterion
        self._model = model
        self.cuda = torch.cuda.is_available() and not args.cpu
        if args.fp16:
            self._model = self._model.half()
        if self.cuda:
            self.criterion = self.criterion.cuda()
            self._model = self._model.cuda()

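        # *dummy_batch* is replayed (with its gradients ignored) whenever a
        # worker receives an empty sample, and *oom_batch* is replayed after an
        # OOM so that all workers perform the same number of backward passes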
        self._dummy_batch = dummy_batch
        self._oom_batch = oom_batch

        self._lr_scheduler = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        self._prev_grad_norm = None
        self._wrapped_model = None

        self.init_meters(args)

    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['oom'] = AverageMeter()    # out of memory
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()      # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

    @property
    def model(self):
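        # Lazily wrap the base model: with more than one worker it is wrapped
        # in DistributedFairseqModel, otherwise it is returned unwrapped.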
        if self._wrapped_model is None:
            if self.args.distributed_world_size > 1:
                self._wrapped_model = models.DistributedFairseqModel(
                    self.args, self._model,
                )
            else:
                self._wrapped_model = self._model
        return self._wrapped_model

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    @property
    def lr_scheduler(self):
        if self._lr_scheduler is None:
            self._build_optimizer()  # this will initialize self._lr_scheduler
        return self._lr_scheduler

    def _build_optimizer(self):
        params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
        if self.args.fp16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                print('| WARNING: your device does NOT support faster training with --fp16, '
                      'please switch to FP32 which is likely to be faster')
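            # Note: FP16Optimizer keeps a separate FP32 master copy of the
            # parameters, while MemoryEfficientFP16Optimizer
            # (--memory-efficient-fp16) avoids that copy to save memory.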
            if self.args.memory_efficient_fp16:
                self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
            else:
                self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                print('| NOTICE: your device may support faster training with --fp16')
            self._optimizer = optim.build_optimizer(self.args, params)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state['train_meters'] = self.meters
            checkpoint_utils.save_state(
                filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
            )

    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        if os.path.exists(filename):
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state['model'], strict=True)
            except Exception:
                raise Exception(
                    'Cannot load model parameters from checkpoint, '
                    'please ensure that the architectures match.'
                )

            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
            last_optim_state = state['last_optimizer_state']

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self._num_updates = last_optim['num_updates']

        if extra_state is not None and 'train_meters' in extra_state:
            self.meters.update(extra_state['train_meters'])
            del extra_state['train_meters']

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in self.meters.values():
                if isinstance(meter, TimeMeter):
                    meter.reset()

        return extra_state

    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update."""
        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            try:
                if self.args.distributed_world_size > 1:
                    # Whenever *samples* contains more than one mini-batch, we
                    # want to accumulate gradients locally and only call
                    # all-reduce in the last backwards pass. Currently the
                    # *accumulate_grads* flag is only supported by
                    # LegacyDistributedDataParallel.
                    if i < len(samples) - 1:
                        self.model.accumulate_grads = True
                    else:
                        self.model.accumulate_grads = False

                # forward and backward
                loss, sample_size, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer,
                    ignore_grad
                )

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    msg = (
                        '| WARNING: ran out of memory with exception: '
                        + '{};'.format(e)
                        + '\n Skipping batch'
                    )
                    # TODO: print should really go to a logger; this print goes
                    # to stdout, which is buffered and in many cases not
                    # flushed if another exception happens
                    # print(msg)
                    print(msg, file=sys.stderr)
                    if raise_oom:
                        raise ValueError(msg)
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
                ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)
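            # sanity check: the previous update's gradient norm should be
            # identical on every worker (all-inf/nan norms, e.g. after an FP16
            # overflow, are tolerated); a mismatch means the replicas diverged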
            assert (
                all(norm == prev_norms[0] for norm in prev_norms)
                or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
            ), 'Fatal error: gradients are inconsistent between workers'

        self.meters['oom'].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.criterion
        )
        sample_size = self.task.grad_denom(sample_sizes, self.criterion)

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
            ).format(self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))
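            # (gradients are averaged across workers during the all-reduce, so
            # multiplying by the world size recovers the summed gradients
            # before dividing by the aggregated sample size)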

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self._num_updates += 1

            # update learning rate
            self.lr_scheduler.step_update(self._num_updates)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            if 'train_acc' in self.meters:
                self.meters['train_acc'].update(
                    logging_output.get('acc', 0), sample_size)

            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

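        # record the current dynamic loss scale so it can be reported alongside
        # the other training stats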
        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output

    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion
                )
            except RuntimeError as e:
                if 'out of memory' in str(e) and not raise_oom:
                    print('| WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
                            p.grad = None  # free some memory
                    if self.cuda:
                        torch.cuda.empty_cache()
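                    # retry once after clearing the CUDA cache; if it OOMs
                    # again, raise_oom=True lets the exception propagate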
                    return self.valid_step(sample, raise_oom=True)
                else:
                    raise e

            if ignore_results:
                logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(*distributed_utils.all_gather_list(
                [logging_output, sample_size],
            ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.criterion
        )
        sample_size = self.task.grad_denom(
            sample_size, self.criterion
        )

        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)
        self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'valid_acc' in self.meters:
            self.meters['valid_acc'].update(
                logging_output.get('acc', 0), sample_size)

        if 'nll_loss' in logging_output:
            self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)

        return logging_output

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, dummy_batch=True)
        self.zero_grad()

    def handle_ooms(self, number_of_ooms):
        """
        c10d accumulates/syncs gradients between GPUs during the backward pass.
        In case of OOMs, GPUs may fail to sync, so we manually run extra
        iterations to make sure each GPU performs the same number of
        forward/backward passes.
        """
        for _ in range(number_of_ooms):
            self.train_step([self._oom_batch], True)

    def zero_grad(self):
        self.optimizer.zero_grad()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the (non-wrapped) model instance."""
        return self._model

    def get_meter(self, name):
        """Get a specific meter by name."""
        if name not in self.meters:
            return None
        return self.meters[name]

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
        if sample is None or len(sample) == 0:
            return None
        if self.cuda:
            sample = utils.move_to_cuda(sample)
        return sample

    def _set_seed(self):
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        if self.cuda:
            torch.cuda.manual_seed(seed)