# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Train a network across multiple GPUs.
"""

from collections import OrderedDict
import contextlib
from itertools import chain
import math
import os
import sys

import torch

from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler


class Trainer(object):
    """Main class for data parallel training.

    This class supports synchronous distributed data parallel training,
    where multiple workers each have a full model replica and gradients
    are accumulated across workers before each update. We use
    :class:`~torch.nn.parallel.DistributedDataParallel` to handle
    communication of the gradients across workers.
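
    A rough usage sketch (hypothetical variable names; assumes ``args``,
    ``task``, ``model`` and ``criterion`` have already been built by the
    calling script, e.g. fairseq's train.py)::

        trainer = Trainer(args, task, model, criterion)
        epoch_itr = trainer.get_train_iterator(epoch=1)
        for sample in epoch_itr.next_epoch_itr(shuffle=True):
            log_output = trainer.train_step([sample])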
    """

    def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=None):
        self.args = args
        self.task = task

        # copy model and criterion to current device
        self._criterion = criterion
        self._model = model
        self.cuda = torch.cuda.is_available() and not args.cpu
        if args.fp16:
            self._criterion = self._criterion.half()
            self._model = self._model.half()
        if self.cuda:
            self._criterion = self._criterion.cuda()
            self._model = self._model.cuda()

        self._dummy_batch = dummy_batch
        self._oom_batch = oom_batch or dummy_batch

        self._lr_scheduler = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        self._prev_grad_norm = None
        self._wrapped_criterion = None
        self._wrapped_model = None

        # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes.
        # It is less flexible and syncs only the default stats.
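        # Buffer layout: [0] sample_size, [1] nsentences, [2] loss,
        # [3] nll_loss, [4] ntokens, [5] ooms (accumulated in train_step).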
        self._all_reduce_list = [0.0] * 6
        self.fast_stat_sync = args.fast_stat_sync

        self.init_meters(args)

    def init_meters(self, args):
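        """Initialize the meters used to track training/validation statistics."""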
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['oom'] = AverageMeter()    # out of memory
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()      # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

    @property
    def criterion(self):
        if self._wrapped_criterion is None:
            if (
                utils.has_parameters(self._criterion)
                and self.args.distributed_world_size > 1
                and not self.args.use_bmuf
            ):
                self._wrapped_criterion = models.DistributedFairseqModel(
                    self.args, self._criterion
                )
            else:
                self._wrapped_criterion = self._criterion
        return self._wrapped_criterion

    @property
    def model(self):
        if self._wrapped_model is None:
            if self.args.distributed_world_size > 1 and not self.args.use_bmuf:
                self._wrapped_model = models.DistributedFairseqModel(
                    self.args, self._model,
                )
            else:
                self._wrapped_model = self._model
        return self._wrapped_model

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    @property
    def lr_scheduler(self):
        if self._lr_scheduler is None:
            self._build_optimizer()  # this will initialize self._lr_scheduler
        return self._lr_scheduler

    def _build_optimizer(self):
        params = list(
            filter(
                lambda p: p.requires_grad,
                chain(self.model.parameters(), self.criterion.parameters()),
            )
        )

        if self.args.fp16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                print('| WARNING: your device does NOT support faster training with --fp16, '
                      'please switch to FP32 which is likely to be faster')
            if self.args.memory_efficient_fp16:
                self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
            else:
                self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                print('| NOTICE: your device may support faster training with --fp16')
            self._optimizer = optim.build_optimizer(self.args, params)

        if self.args.use_bmuf:
            self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
        self._lr_scheduler.step_update(0)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state['train_meters'] = self.meters
            checkpoint_utils.save_state(
                filename, self.args, self.get_model().state_dict(), self.get_criterion(),
                self.optimizer, self.lr_scheduler, self.get_num_updates(),
                self._optim_history, extra_state,
            )

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        try:
            from fairseq.fb_pathmgr import fb_pathmgr
            bexists = fb_pathmgr.isfile(filename)
        except (ModuleNotFoundError, ImportError):
            bexists = os.path.exists(filename)

        if bexists:
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state['model'], strict=True, args=self.args)
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(state['criterion'], strict=True)
            except Exception:
                raise Exception(
                    'Cannot load model parameters from checkpoint {}; '
                    'please ensure that the architectures match.'.format(filename)
                )

            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
            last_optim_state = state.get('last_optimizer_state', None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self.set_num_updates(last_optim['num_updates'])

        if extra_state is not None:
            epoch = extra_state['train_iterator']['epoch']
            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                filename, epoch, self.get_num_updates()))

            self.lr_step(epoch)

            if 'train_meters' in extra_state and not reset_meters:
                self.meters.update(extra_state['train_meters'])
                del extra_state['train_meters']

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in self.meters.values():
                    if isinstance(meter, TimeMeter):
                        meter.reset()
        else:
            print('| no existing checkpoint found {}'.format(filename))

        return extra_state

    def get_train_iterator(self, epoch, combine=True, load_dataset=True, data_selector=None, shard_batch_itr=True):
        """Return an EpochBatchIterator over the training set for a given epoch."""
        if load_dataset:
            print('| loading train data for epoch {}'.format(epoch))
            self.task.load_dataset(
                self.args.train_subset,
                epoch=epoch,
                combine=combine,
                data_selector=data_selector,
            )
        return self.task.get_batch_iterator(
            dataset=self.task.dataset(self.args.train_subset),
            max_tokens=self.args.max_tokens,
            max_sentences=self.args.max_sentences,
            max_positions=utils.resolve_max_positions(
                self.task.max_positions(),
                self.model.max_positions(),
            ),
            ignore_invalid_inputs=True,
            required_batch_size_multiple=self.args.required_batch_size_multiple,
            seed=self.args.seed,
            num_shards=self.args.distributed_world_size if shard_batch_itr else 1,
            shard_id=self.args.distributed_rank if shard_batch_itr else 0,
            num_workers=self.args.num_workers,
            epoch=epoch,
        )

    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (
                    self.args.distributed_world_size > 1
                    and hasattr(self.model, 'no_sync')
                    and i < len(samples) - 1
                ):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad
                    )

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)

                    if self.fast_stat_sync:
                        self._all_reduce_list[0] += sample_size
                        self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
                        self._all_reduce_list[2] += logging_output.get('loss', 0.0)
                        self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
                        self._all_reduce_list[4] += logging_output.get('ntokens', 0.0)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    print("| WARNING: attempting to recover from OOM in forward/backward pass",
                          file=sys.stderr)
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

            if self.fast_stat_sync:
                self._all_reduce_list[5] += ooms


        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.fast_stat_sync:
            # rework all_gather_list
            all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
            if self._sync_stats():
                torch.distributed.all_reduce(all_reduce_list_tensor)
            # Normalize loss and nll_loss by "sample_size"
            # and convert to log base 2
            all_reduce_list_tensor[2:4].div_(
                (
                    all_reduce_list_tensor[0:1] *
                    torch.log(torch.cuda.DoubleTensor([2]))
                )
            )
            self._all_reduce_list = all_reduce_list_tensor.tolist()
            logging_output = {}
            [
                sample_size,
                logging_output['nsentences'],
                logging_output['loss'],
                logging_output['nll_loss'],
                logging_output['ntokens'],
                ooms,
            ] = self._all_reduce_list
        elif self._sync_stats():
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
                ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

            if not self.args.use_bmuf:
                assert (
                    all(norm == prev_norms[0] for norm in prev_norms)
                    or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
                ), 'Fatal error: gradients are inconsistent between workers'

        self.meters['oom'].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        if not self.fast_stat_sync:
            # aggregate logging outputs and sample sizes
            logging_output = self.task.aggregate_logging_outputs(
                logging_outputs, self.get_criterion()
            )
            sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
            ).format(self.task.__class__.__name__))

        try:
            # normalize grads by sample size
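            # (DDP averages gradients over the world size, so multiplying by
            # world_size / sample_size effectively re-sums them and divides
            # by the total sample size gathered from all workers)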
            if sample_size > 0:
                self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            if 'train_acc' in self.meters:
                self.meters['train_acc'].update(
                    logging_output.get('acc', 0), sample_size)

            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                    ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                     self.args.empty_cache_freq) == 0 and
                    torch.cuda.is_available() and
                    not self.args.cpu):
                torch.cuda.empty_cache()
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if 'out of memory' in str(e):
                self._log_oom(e)
                print('| ERROR: OOM during optimization, irrecoverable')
            raise e

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.clear_buffered_stats()
        self.meters['train_wall'].stop()

        return logging_output

    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion
                )
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        print('| WARNING: ran out of memory in validation step, retrying batch')
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory
                        if self.cuda:
                            torch.cuda.empty_cache()
                        return self.valid_step(sample, raise_oom=True)
                raise e

            if ignore_results:
                logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(*distributed_utils.all_gather_list(
                [logging_output, sample_size],
            ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.get_criterion()
        )
        sample_size = self.task.grad_denom(
            sample_size, self.get_criterion()
        )

        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)
        self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'valid_acc' in self.meters:
            self.meters['valid_acc'].update(
                logging_output.get('acc', 0), sample_size)

        if 'nll_loss' in logging_output:
            self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)

        return logging_output

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, dummy_batch=True)
        self.zero_grad()

    def handle_ooms(self, number_of_ooms):
        """
        c10d accumulates/syncs gradients across GPUs during the backward pass.
        If some workers hit OOMs, the GPUs may fail to sync, so we manually run
        extra iterations to make sure each GPU performs the same number of
        backward passes.
        """
        for _ in range(number_of_ooms):
            self.train_step([self._oom_batch], True)

    def zero_grad(self):
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        self._all_reduce_list = [0.0] * 6

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        self.lr_scheduler.step(epoch, val_loss)
        # prefer updating the LR based on the number of steps
        return self.lr_step_update()

    def lr_step_update(self):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(self.get_num_updates())

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the (non-wrapped) model instance."""
        return self._model

    def get_criterion(self):
        """Get the (non-wrapped) criterion instance."""
        return self._criterion

    def get_meter(self, name):
        """Get a specific meter by name."""
        if name not in self.meters:
            return None
        return self.meters[name]

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates
        self.lr_step_update()

    def _prepare_sample(self, sample):
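        """Move the sample to GPU (when available) and, under --fp16, cast
        floating point tensors to half precision."""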
        if sample is None or len(sample) == 0:
            return None

        if self.cuda:
            sample = utils.move_to_cuda(sample)

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.half()
            return t

        if self.args.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        return sample

    def _set_seed(self):
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        if self.cuda:
            torch.cuda.manual_seed(seed)

    def _sync_stats(self):
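        """Return True if gradients/stats should be synced across workers on
        this update (always under plain DDP; only every
        ``args.global_sync_iter`` updates under BMUF)."""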
        return (
            self.args.distributed_world_size > 1 and
            (
                (not self.args.use_bmuf) or
                (
                    self.args.use_bmuf
                    and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
                )
            )
        )

    def _log_oom(self, exc):
        msg = '| OOM: Ran out of memory with exception: {}'.format(exc)
        # TODO: this print should really go to a logger; it goes to stderr,
        # which is buffered and in many cases not flushed if another
        # exception happens. NB(jerry): added a flush to mitigate this.
        print(msg, file=sys.stderr)
        if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
            for device_idx in range(torch.cuda.device_count()):
                print(torch.cuda.memory_summary(device=device_idx),
                      file=sys.stderr)
        sys.stderr.flush()