# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Train a network across multiple GPUs.
"""

Peng-Jen Chen's avatar
Peng-Jen Chen committed
10
from collections import OrderedDict
Myle Ott's avatar
Myle Ott committed
11
import contextlib
Sergey Edunov's avatar
Sergey Edunov committed
12
from itertools import chain
Myle Ott's avatar
Myle Ott committed
13
import math
Myle Ott's avatar
Myle Ott committed
14
import os
15
import sys
Myle Ott's avatar
Myle Ott committed
16

Myle Ott's avatar
Myle Ott committed
17
18
import torch

Myle Ott's avatar
Myle Ott committed
19
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
Myle Ott's avatar
Myle Ott committed
20
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
Myle Ott's avatar
Myle Ott committed
21
22
23
24
from fairseq.optim import lr_scheduler


class Trainer(object):
    """Main class for data parallel training.

    This class supports synchronous distributed data parallel training,
    where multiple workers each have a full model replica and gradients
    are accumulated across workers before each update. We use
    :class:`~torch.nn.parallel.DistributedDataParallel` to handle
    communication of the gradients across workers.
    """

Myle Ott's avatar
Myle Ott committed
34
    def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=None):
        """Store the training components and initialise bookkeeping state.

        Args:
            args: parsed options; ``cpu``, ``fp16`` and ``fast_stat_sync``
                are read here and the object is kept on ``self.args``.
            task: task object, kept on ``self.task``.
            model: model to train. It is cast to half precision when
                ``args.fp16`` is set and moved to CUDA when CUDA is used.
            criterion: criterion, converted the same way as *model*.
            dummy_batch: optional batch substituted for empty samples.
            oom_batch: optional batch used by :meth:`handle_ooms`; when
                falsy, *dummy_batch* is used instead.
        """
        self.args = args
        self.task = task

        # CUDA is used only when it is available and not disabled by args.cpu
        use_cuda = torch.cuda.is_available() and not args.cpu

        # convert criterion and model: precision first, then device
        if args.fp16:
            criterion = criterion.half()
            model = model.half()
        if use_cuda:
            criterion = criterion.cuda()
            model = model.cuda()
        self._criterion = criterion
        self._model = model
        self.cuda = use_cuda

        self._dummy_batch = dummy_batch
        self._oom_batch = oom_batch or dummy_batch

        # lazily built members (see the corresponding properties)
        self._optimizer = None
        self._lr_scheduler = None
        self._wrapped_model = None
        self._wrapped_criterion = None

        # training progress
        self._num_updates = 0
        self._optim_history = None
        self._prev_grad_norm = None

        # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes.
        # It is less flexible and syncs only the default stats.
        self._all_reduce_list = [0.0] * 6
        self.fast_stat_sync = args.fast_stat_sync

        self.init_meters(args)

    def init_meters(self, args):
        """Create the ordered collection of training/validation meters.

        The ``loss_scale`` meter is only created when ``args.fp16`` is set.
        The result is stored on ``self.meters``.
        """
        meters = OrderedDict()

        # running averages of the losses
        for loss_name in ('train_loss', 'train_nll_loss', 'valid_loss', 'valid_nll_loss'):
            meters[loss_name] = AverageMeter()

        meters['wps'] = TimeMeter()       # words per second
        meters['ups'] = TimeMeter()       # updates per second
        meters['wpb'] = AverageMeter()    # words per batch
        meters['bsz'] = AverageMeter()    # sentences per batch
        meters['gnorm'] = AverageMeter()  # gradient norm
        meters['clip'] = AverageMeter()   # % of updates clipped
        meters['oom'] = AverageMeter()    # out of memory
        if args.fp16:
            meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        meters['wall'] = TimeMeter()      # wall time in seconds
        meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

        self.meters = meters
Myle Ott's avatar
Myle Ott committed
84

Jeff Cai's avatar
Jeff Cai committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    @property
    def criterion(self):
        if self._wrapped_criterion is None:
            if (
                utils.has_parameters(self._criterion)
                and self.args.distributed_world_size > 1
                and not self.args.use_bmuf
            ):
                self._wrapped_criterion = models.DistributedFairseqModel(
                    self.args, self._criterion
                )
            else:
                self._wrapped_criterion = self._criterion
        return self._wrapped_criterion

100
101
102
    @property
    def model(self):
        if self._wrapped_model is None:
Nayan Singhal's avatar
Nayan Singhal committed
103
            if self.args.distributed_world_size > 1 and not self.args.use_bmuf:
104
105
106
107
108
109
                self._wrapped_model = models.DistributedFairseqModel(
                    self.args, self._model,
                )
            else:
                self._wrapped_model = self._model
        return self._wrapped_model
Myle Ott's avatar
Myle Ott committed
110
111
112
113
114
115

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer
Myle Ott's avatar
Myle Ott committed
116

Myle Ott's avatar
Myle Ott committed
117
118
119
    @property
    def lr_scheduler(self):
        if self._lr_scheduler is None:
Myle Ott's avatar
Myle Ott committed
120
            self._build_optimizer()  # this will initialize self._lr_scheduler
Myle Ott's avatar
Myle Ott committed
121
122
        return self._lr_scheduler

Myle Ott's avatar
Myle Ott committed
123
    def _build_optimizer(self):
        """Build the optimizer and the learning rate scheduler.

        Only parameters of the model and criterion with ``requires_grad``
        set are optimized. The optimizer class depends on ``args.fp16`` and
        ``args.memory_efficient_fp16``; with ``args.use_bmuf`` it is further
        wrapped in ``optim.FairseqBMUF``.
        """
        trainable_params = [
            param
            for param in chain(self.model.parameters(), self.criterion.parameters())
            if param.requires_grad
        ]

        if self.args.fp16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                print('| WARNING: your device does NOT support faster training with --fp16, '
                      'please switch to FP32 which is likely to be faster')
            fp16_optimizer_cls = (
                optim.MemoryEfficientFP16Optimizer
                if self.args.memory_efficient_fp16
                else optim.FP16Optimizer
            )
            self._optimizer = fp16_optimizer_cls.build_optimizer(self.args, trainable_params)
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                print('| NOTICE: your device may support faster training with --fp16')
            self._optimizer = optim.build_optimizer(self.args, trainable_params)

        if self.args.use_bmuf:
            self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
        self._lr_scheduler = scheduler
        scheduler.step_update(0)
Myle Ott's avatar
Myle Ott committed
151

Myle Ott's avatar
Myle Ott committed
152
153
    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file.

        Only the master worker writes the file, so a single checkpoint is
        produced. The current meters are stored into *extra_state* under
        ``'train_meters'`` (the dict is modified in place).
        """
        if not distributed_utils.is_master(self.args):
            return

        extra_state['train_meters'] = self.meters
        checkpoint_utils.save_state(
            filename,
            self.args,
            self.get_model().state_dict(),
            self.get_criterion(),
            self.optimizer,
            self.lr_scheduler,
            self.get_num_updates(),
            self._optim_history,
            extra_state,
        )
Myle Ott's avatar
Myle Ott committed
161

Myle Ott's avatar
Myle Ott committed
162
163
164
165
166
167
168
169
    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
Myle Ott's avatar
Myle Ott committed
170
        """Load all training state from a checkpoint file."""
Myle Ott's avatar
Myle Ott committed
171
172
173
174
175
176
177
178
        extra_state, self._optim_history, last_optim_state = None, [], None

        if os.path.exists(filename):
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state['model'], strict=True)
Jeff Cai's avatar
Jeff Cai committed
179
180
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(state['criterion'], strict=True)
Myle Ott's avatar
Myle Ott committed
181
182
            except Exception:
                raise Exception(
183
184
                    'Cannot load model parameters from checkpoint {}; '
                    'please ensure that the architectures match.'.format(filename)
Myle Ott's avatar
Myle Ott committed
185
186
187
188
                )

            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
Myle Ott's avatar
Myle Ott committed
189
            last_optim_state = state.get('last_optimizer_state', None)
Myle Ott's avatar
Myle Ott committed
190

191
        if last_optim_state is not None and not reset_optimizer:
Myle Ott's avatar
Myle Ott committed
192
            # rebuild optimizer after loading model, since params may have changed
Myle Ott's avatar
Myle Ott committed
193
            self._build_optimizer()
Myle Ott's avatar
Myle Ott committed
194

195
196
            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
Jeff Cai's avatar
Jeff Cai committed
197
            assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
Myle Ott's avatar
Myle Ott committed
198
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
199
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
Myle Ott's avatar
Myle Ott committed
200
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'
201
202
203
204
205

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

Myle Ott's avatar
Myle Ott committed
206
            self.set_num_updates(last_optim['num_updates'])
Myle Ott's avatar
Myle Ott committed
207

Myle Ott's avatar
Myle Ott committed
208
209
210
211
        if extra_state is not None:
            epoch = extra_state['train_iterator']['epoch']
            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                filename, epoch, self.get_num_updates()))
Myle Ott's avatar
Myle Ott committed
212

Myle Ott's avatar
Myle Ott committed
213
214
            self.lr_step(epoch)

Myle Ott's avatar
Myle Ott committed
215
            if 'train_meters' in extra_state and not reset_meters:
Myle Ott's avatar
Myle Ott committed
216
217
218
219
220
221
222
223
224
                self.meters.update(extra_state['train_meters'])
                del extra_state['train_meters']

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in self.meters.values():
                    if isinstance(meter, TimeMeter):
                        meter.reset()
        else:
            print('| no existing checkpoint found {}'.format(filename))
Myle Ott's avatar
Myle Ott committed
225

Myle Ott's avatar
Myle Ott committed
226
227
        return extra_state

228
    def get_train_iterator(self, epoch, combine=True, load_dataset=True):
        """Return an EpochBatchIterator over the training set for a given epoch.

        When *load_dataset* is true, the training subset is (re)loaded through
        the task first. The iterator is sharded by distributed rank.
        """
        train_subset = self.args.train_subset
        if load_dataset:
            print('| loading train data for epoch {}'.format(epoch))
            self.task.load_dataset(train_subset, epoch=epoch, combine=combine)

        train_dataset = self.task.dataset(self.args.train_subset)
        # the effective limit is resolved from both the task and the model
        position_limits = utils.resolve_max_positions(
            self.task.max_positions(),
            self.model.max_positions(),
        )
        return self.task.get_batch_iterator(
            dataset=train_dataset,
            max_tokens=self.args.max_tokens,
            max_sentences=self.args.max_sentences,
            max_positions=position_limits,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=self.args.required_batch_size_multiple,
            seed=self.args.seed,
            num_shards=self.args.distributed_world_size,
            shard_id=self.args.distributed_rank,
            num_workers=self.args.num_workers,
            epoch=epoch,
        )

250
    def train_step(self, samples, dummy_batch=False, raise_oom=False):
Myle Ott's avatar
Myle Ott committed
251
        """Do forward, backward and parameter update."""
Myle Ott's avatar
Myle Ott committed
252
253
254
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

255
        self._set_seed()
256
        self.model.train()
Myle Ott's avatar
Myle Ott committed
257
        self.criterion.train()
258
259
        self.zero_grad()

Myle Ott's avatar
Myle Ott committed
260
261
262
        if not dummy_batch:
            self.meters['train_wall'].start()

Sergey Edunov's avatar
Sergey Edunov committed
263
        # forward and backward pass
264
265
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
Myle Ott's avatar
Myle Ott committed
266
            sample = self._prepare_sample(sample)
267
268
269
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
Myle Ott's avatar
Myle Ott committed
270
                sample = self._prepare_sample(self._dummy_batch)
271
272
273
                ignore_grad = True
            else:
                ignore_grad = False
Myle Ott's avatar
Myle Ott committed
274

Myle Ott's avatar
Myle Ott committed
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (
                    self.args.distributed_world_size > 1
                    and hasattr(self.model, 'no_sync')
                    and i < len(samples) - 1
                ):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

290
            try:
Myle Ott's avatar
Myle Ott committed
291
292
293
294
295
296
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad
                    )
297
298
299
300

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
301
302
303
304
305
306
307

                    if self.fast_stat_sync:
                        self._all_reduce_list[0] += sample_size
                        self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
                        self._all_reduce_list[2] += logging_output.get('loss', 0.0)
                        self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
                        self._all_reduce_list[4] += logging_output.get('ntokens', 0.0)
308
309
            except RuntimeError as e:
                if 'out of memory' in str(e):
310
311
312
313
314
315
                    msg = (
                        '| WARNING: ran out of memory with exception: '
                        + '{};'.format(e)
                        + '\n Skipping batch'
                    )
                    # TODO: print should really go to logger, this print goes
316
317
318
                    # to stderr, which is buffered, which in many cases is not
                    # printed out if another exception happens.
                    # NB(jerry): added a flush to mitigate this
319
                    print(msg, file=sys.stderr)
320
321
322
323
324
325
                    if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
                        for device_idx in range(torch.cuda.device_count()):
                            print(torch.cuda.memory_summary(device=torch.cuda.device(device_idx)),
                                  file=sys.stderr)
                    sys.stderr.flush()

326
327
                    if raise_oom:
                        raise ValueError(msg)
328
329
330
331
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e
Myle Ott's avatar
Myle Ott committed
332

333
334
335
336
            if self.fast_stat_sync:
                self._all_reduce_list[5] += ooms


337
338
339
        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

340
341
        if dummy_batch:
            return None
Myle Ott's avatar
Myle Ott committed
342
343

        # gather logging outputs from all replicas
344
345
346
347
348
349
350
351
352
353
354
355
        if self.fast_stat_sync:
            # rework all_gather_list
            all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
            if self._sync_stats():
                torch.distributed.all_reduce(all_reduce_list_tensor)
            # Normalize loss and nll_loss by "sample_size"
            # and convert to log base 2
            all_reduce_list_tensor[2:4].div_(
                (
                    all_reduce_list_tensor[0:1] *
                    torch.log(torch.cuda.DoubleTensor([2]))
                )
Nayan Singhal's avatar
Nayan Singhal committed
356
            )
357
358
359
360
361
362
363
364
365
366
367
            self._all_reduce_list = all_reduce_list_tensor.tolist()
            logging_output = {}
            [
                sample_size,
                logging_output['nsentences'],
                logging_output['loss'],
                logging_output['nll_loss'],
                logging_output['ntokens'],
                ooms,
            ] = self._all_reduce_list
        elif self._sync_stats():
368
369
370
371
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
                ))
372
373
374
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)
Nayan Singhal's avatar
Nayan Singhal committed
375
376
377
378
379
380

            if not self.args.use_bmuf:
                assert (
                    all(norm == prev_norms[0] for norm in prev_norms)
                    or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
                ), 'Fatal error: gradients are inconsistent between workers'
Myle Ott's avatar
Myle Ott committed
381

382
        self.meters['oom'].update(ooms, len(samples))
383
        if ooms == self.args.distributed_world_size * len(samples):
384
            print('| WARNING: OOM in all workers, skipping update')
Myle Ott's avatar
Myle Ott committed
385
386
387
            self.zero_grad()
            return None

388
389
390
391
392
393
        if not self.fast_stat_sync:
            # aggregate logging outputs and sample sizes
            logging_output = self.task.aggregate_logging_outputs(
                logging_outputs, self.get_criterion()
            )
            sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
394
395
396
397
398

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
Peng-Jen Chen's avatar
Peng-Jen Chen committed
399
            ).format(self.task.__class__.__name__))
Myle Ott's avatar
Myle Ott committed
400
401

        try:
402
            # normalize grads by sample size
Nayan Singhal's avatar
Nayan Singhal committed
403
404
            if sample_size > 0:
                self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))
405
406
407

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
408
            self._prev_grad_norm = grad_norm
409
410
411

            # take an optimization step
            self.optimizer.step()
Myle Ott's avatar
Myle Ott committed
412
            self.set_num_updates(self.get_num_updates() + 1)
Myle Ott's avatar
Myle Ott committed
413

414
415
416
            # task specific update per step
            self.task.update_step(self._num_updates)

Myle Ott's avatar
Myle Ott committed
417
            # update meters
418
419
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
Myle Ott's avatar
Myle Ott committed
420
421
422
423
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
424
425
426
427
428
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
429
430
431
432
            if 'train_acc' in self.meters:
                self.meters['train_acc'].update(
                    logging_output.get('acc', 0), sample_size)

433
434
            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
435
436
437
438
439
440
441
442

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                    ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                     self.args.empty_cache_freq) == 0 and
                    torch.cuda.is_available() and
                    not self.args.cpu):
                torch.cuda.empty_cache()
Myle Ott's avatar
Myle Ott committed
443
444
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
445
446
            self.zero_grad()
            logging_output = None
Myle Ott's avatar
Myle Ott committed
447

448
449
450
        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)
Myle Ott's avatar
Myle Ott committed
451

452
        self.clear_buffered_stats()
453
        self.meters['train_wall'].stop()
Myle Ott's avatar
Myle Ott committed
454

455
        return logging_output
Myle Ott's avatar
Myle Ott committed
456

457
    def valid_step(self, sample, raise_oom=False):
Myle Ott's avatar
Myle Ott committed
458
        """Do forward pass in evaluation mode."""
459
        with torch.no_grad():
Myle Ott's avatar
Myle Ott committed
460
            self.model.eval()
Myle Ott's avatar
Myle Ott committed
461
            self.criterion.eval()
Myle Ott's avatar
Myle Ott committed
462

Myle Ott's avatar
Myle Ott committed
463
            sample = self._prepare_sample(sample)
464
            if sample is None:
Myle Ott's avatar
Myle Ott committed
465
                sample = self._prepare_sample(self._dummy_batch)
Myle Ott's avatar
Myle Ott committed
466
467
468
469
                ignore_results = True
            else:
                ignore_results = False

470
            try:
Peng-Jen Chen's avatar
Peng-Jen Chen committed
471
472
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion
473
474
475
476
477
478
                )
            except RuntimeError as e:
                if 'out of memory' in str(e) and not raise_oom:
                    print('| WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
Myle Ott's avatar
Myle Ott committed
479
                            p.grad = None  # free some memory
Myle Ott's avatar
Myle Ott committed
480
481
                    if self.cuda:
                        torch.cuda.empty_cache()
482
483
484
                    return self.valid_step(sample, raise_oom=True)
                else:
                    raise e
Sergey Edunov's avatar
Sergey Edunov committed
485

Myle Ott's avatar
Myle Ott committed
486
487
488
            if ignore_results:
                logging_output, sample_size = {}, 0

489
        # gather logging outputs from all replicas
Sergey Edunov's avatar
Sergey Edunov committed
490
        if self.args.distributed_world_size > 1:
491
492
            logging_output, sample_size = zip(*distributed_utils.all_gather_list(
                [logging_output, sample_size],
Sergey Edunov's avatar
Sergey Edunov committed
493
            ))
494
495
            logging_output = list(logging_output)
            sample_size = list(sample_size)
Sergey Edunov's avatar
Sergey Edunov committed
496
        else:
497
498
            logging_output = [logging_output]
            sample_size = [sample_size]
Myle Ott's avatar
Myle Ott committed
499

500
        # aggregate logging outputs and sample sizes
Peng-Jen Chen's avatar
Peng-Jen Chen committed
501
        logging_output = self.task.aggregate_logging_outputs(
Jeff Cai's avatar
Jeff Cai committed
502
            logging_output, self.get_criterion()
Peng-Jen Chen's avatar
Peng-Jen Chen committed
503
504
        )
        sample_size = self.task.grad_denom(
Jeff Cai's avatar
Jeff Cai committed
505
            sample_size, self.get_criterion()
Peng-Jen Chen's avatar
Peng-Jen Chen committed
506
        )
Myle Ott's avatar
Myle Ott committed
507

508
509
510
        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)
        self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
511
512
513
514
        if 'valid_acc' in self.meters:
            self.meters['valid_acc'].update(
                logging_output.get('acc', 0), sample_size)

515
516
        if 'nll_loss' in logging_output:
            self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
Myle Ott's avatar
Myle Ott committed
517

518
        return logging_output
Myle Ott's avatar
Myle Ott committed
519

Myle Ott's avatar
Myle Ott committed
520
521
    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
522
        self.train_step(dummy_batch, dummy_batch=True)
Myle Ott's avatar
Myle Ott committed
523
524
        self.zero_grad()

525
526
527
528
529
530
531
532
533
    def handle_ooms(self, number_of_ooms):
        """
        c10d accumulates/syncs gradients between gpus during backward pass.
        In case of OOMs, gpus may fail to sync, so we manually iterate
        extra to make sure each gpu makes same number of iterations.
        """
        for _ in range(number_of_ooms):
            self.train_step([self._oom_batch], True)

Myle Ott's avatar
Myle Ott committed
534
535
536
    def zero_grad(self):
        self.optimizer.zero_grad()

537
538
539
    def clear_buffered_stats(self):
        self._all_reduce_list = [0.0] * 6

Myle Ott's avatar
Myle Ott committed
540
541
    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
Nayan Singhal's avatar
Nayan Singhal committed
542
        self.lr_scheduler.step(epoch, val_loss)
Myle Ott's avatar
Myle Ott committed
543
544
        # prefer updating the LR based on the number of steps
        return self.lr_step_update()
Myle Ott's avatar
Myle Ott committed
545

Myle Ott's avatar
Myle Ott committed
546
    def lr_step_update(self):
Myle Ott's avatar
Myle Ott committed
547
        """Update the learning rate after each update."""
Myle Ott's avatar
Myle Ott committed
548
        return self.lr_scheduler.step_update(self.get_num_updates())
Myle Ott's avatar
Myle Ott committed
549

Myle Ott's avatar
Myle Ott committed
550
551
552
553
554
    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
555
556
        """Get the (non-wrapped) model instance."""
        return self._model
Myle Ott's avatar
Myle Ott committed
557

Jeff Cai's avatar
Jeff Cai committed
558
559
560
561
    def get_criterion(self):
        """Get the (non-wrapped) criterion instance."""
        return self._criterion

Myle Ott's avatar
Myle Ott committed
562
563
564
565
566
567
568
569
570
571
    def get_meter(self, name):
        """Get a specific meter by name."""
        if name not in self.meters:
            return None
        return self.meters[name]

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

Myle Ott's avatar
Myle Ott committed
572
573
574
575
576
    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates
        self.lr_step_update()

Myle Ott's avatar
Myle Ott committed
577
    def _prepare_sample(self, sample):
Myle Ott's avatar
Myle Ott committed
578
579
        if sample is None or len(sample) == 0:
            return None
alexeib's avatar
alexeib committed
580

Myle Ott's avatar
Myle Ott committed
581
582
        if self.cuda:
            sample = utils.move_to_cuda(sample)
alexeib's avatar
alexeib committed
583
584
585
586
587
588

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.half()
            return t

Myle Ott's avatar
Myle Ott committed
589
590
591
592
        if self.args.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        return sample
593
594
595
596
597
598
599
600

    def _set_seed(self):
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        if self.cuda:
            torch.cuda.manual_seed(seed)
601
602
603
604
605
606
607
608
609
610
611
612

    def _sync_stats(self):
        return (
            self.args.distributed_world_size > 1 and
            (
                (not self.args.use_bmuf) or
                (
                    self.args.use_bmuf
                    and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
                )
            )
        )