# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Train a network across multiple GPUs.
"""

from collections import OrderedDict
import contextlib
from itertools import chain
import math
import os
import sys

import torch

from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler


class Trainer(object):
    """Main class for data parallel training.

    This class supports synchronous distributed data parallel training,
    where multiple workers each have a full model replica and gradients
    are accumulated across workers before each update. We use
    :class:`~torch.nn.parallel.DistributedDataParallel` to handle
    communication of the gradients across workers.
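
    Example (a minimal sketch, assuming *args*, *task*, *model* and
    *criterion* were built with the usual fairseq helpers, e.g.
    ``tasks.setup_task(args)`` and ``task.build_model(args)``)::

        trainer = Trainer(args, task, model, criterion)
        epoch_itr = trainer.get_train_iterator(epoch=1)
        for sample in epoch_itr.next_epoch_itr(shuffle=True):
            log_output = trainer.train_step([sample])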
    """

    def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=None):
        self.args = args
        self.task = task

        # copy model and criterion to current device
        self._criterion = criterion
        self._model = model
        self.cuda = torch.cuda.is_available() and not args.cpu
        if args.fp16:
            self._criterion = self._criterion.half()
            self._model = self._model.half()
        if self.cuda:
            self._criterion = self._criterion.cuda()
            self._model = self._model.cuda()

        self._dummy_batch = dummy_batch
        self._oom_batch = oom_batch or dummy_batch

        self._lr_scheduler = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        self._prev_grad_norm = None
        self._wrapped_criterion = None
        self._wrapped_model = None

        self.init_meters(args)

    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['oom'] = AverageMeter()    # out of memory
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()      # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

    @property
    def criterion(self):
        if self._wrapped_criterion is None:
            if (
                utils.has_parameters(self._criterion)
                and self.args.distributed_world_size > 1
                and not self.args.use_bmuf
            ):
                self._wrapped_criterion = models.DistributedFairseqModel(
                    self.args, self._criterion
                )
            else:
                self._wrapped_criterion = self._criterion
        return self._wrapped_criterion

    @property
    def model(self):
        if self._wrapped_model is None:
            if self.args.distributed_world_size > 1 and not self.args.use_bmuf:
                self._wrapped_model = models.DistributedFairseqModel(
                    self.args, self._model,
                )
            else:
                self._wrapped_model = self._model
        return self._wrapped_model

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    @property
    def lr_scheduler(self):
        if self._lr_scheduler is None:
            self._build_optimizer()  # this will initialize self._lr_scheduler
        return self._lr_scheduler

    def _build_optimizer(self):
        params = list(
            filter(
                lambda p: p.requires_grad,
                chain(self.model.parameters(), self.criterion.parameters()),
            )
        )

        if self.args.fp16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                print('| WARNING: your device does NOT support faster training with --fp16, '
                      'please switch to FP32 which is likely to be faster')
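            # NOTE: --memory-efficient-fp16 avoids keeping a separate FP32
            # copy of the model params, trading a little numerical stability
            # for a much smaller memory footprint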
            if self.args.memory_efficient_fp16:
                self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
            else:
                self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                print('| NOTICE: your device may support faster training with --fp16')
            self._optimizer = optim.build_optimizer(self.args, params)

        if self.args.use_bmuf:
            self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
        self._lr_scheduler.step_update(0)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state['train_meters'] = self.meters
            checkpoint_utils.save_state(
                filename, self.args, self.get_model().state_dict(), self.get_criterion(),
                self.optimizer, self.lr_scheduler, self.get_num_updates(),
                self._optim_history, extra_state,
            )

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        if os.path.exists(filename):
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state['model'], strict=True)
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(state['criterion'], strict=True)
            except Exception:
                raise Exception(
                    'Cannot load model parameters from checkpoint {}; '
                    'please ensure that the architectures match.'.format(filename)
                )

            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
            last_optim_state = state.get('last_optimizer_state', None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self.set_num_updates(last_optim['num_updates'])

        if extra_state is not None:
            epoch = extra_state['train_iterator']['epoch']
            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                filename, epoch, self.get_num_updates()))

            self.lr_step(epoch)

            if 'train_meters' in extra_state and not reset_meters:
                self.meters.update(extra_state['train_meters'])
                del extra_state['train_meters']

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in self.meters.values():
                    if isinstance(meter, TimeMeter):
                        meter.reset()
        else:
            print('| no existing checkpoint found {}'.format(filename))

        return extra_state

    def get_train_iterator(self, epoch, combine=True):
        """Return an EpochBatchIterator over the training set for a given epoch."""
        print('| loading train data for epoch {}'.format(epoch))
        self.task.load_dataset(self.args.train_subset, epoch=epoch, combine=combine)
        return self.task.get_batch_iterator(
            dataset=self.task.dataset(self.args.train_subset),
            max_tokens=self.args.max_tokens,
            max_sentences=self.args.max_sentences,
            max_positions=utils.resolve_max_positions(
                self.task.max_positions(),
                self.model.max_positions(),
            ),
            ignore_invalid_inputs=True,
            required_batch_size_multiple=self.args.required_batch_size_multiple,
            seed=self.args.seed,
            num_shards=self.args.distributed_world_size,
            shard_id=self.args.distributed_rank,
            num_workers=self.args.num_workers,
            epoch=epoch,
        )

    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backward pass.
                """
                if (
                    self.args.distributed_world_size > 1
                    and hasattr(self.model, 'no_sync')
                    and i < len(samples) - 1
                ):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad
                    )

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    msg = (
                        '| WARNING: ran out of memory with exception: '
                        + '{};'.format(e)
                        + '\n Skipping batch'
                    )
                    # TODO: print should really go to logger; a plain print
                    # goes to stdout, which is buffered and in many cases is
                    # never flushed if another exception happens
                    # print(msg)
                    print(msg, file=sys.stderr)
                    if raise_oom:
                        raise ValueError(msg)
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
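        # (with BMUF, workers only synchronize every args.global_sync_iter
        # updates, so we likewise skip the gather on non-sync updates)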
        if self.args.distributed_world_size > 1 and (
            (not self.args.use_bmuf)
            or (
                self.args.use_bmuf
                and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
            )
        ):
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
                ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

            if not self.args.use_bmuf:
                assert (
                    all(norm == prev_norms[0] for norm in prev_norms)
                    or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
                ), 'Fatal error: gradients are inconsistent between workers'

        self.meters['oom'].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.get_criterion()
        )
        sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
            ).format(self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            if sample_size > 0:
                self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            if 'train_acc' in self.meters:
                self.meters['train_acc'].update(
                    logging_output.get('acc', 0), sample_size)

            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output

    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion
                )
            except RuntimeError as e:
                if 'out of memory' in str(e) and not raise_oom:
                    print('| WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
                            p.grad = None  # free some memory
                    if self.cuda:
                        torch.cuda.empty_cache()
                    return self.valid_step(sample, raise_oom=True)
                else:
                    raise e

            if ignore_results:
                logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(*distributed_utils.all_gather_list(
                [logging_output, sample_size],
            ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.get_criterion()
        )
        sample_size = self.task.grad_denom(
            sample_size, self.get_criterion()
        )

        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)
        self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'valid_acc' in self.meters:
            self.meters['valid_acc'].update(
                logging_output.get('acc', 0), sample_size)

        if 'nll_loss' in logging_output:
            self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)

        return logging_output

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, dummy_batch=True)
        self.zero_grad()

    def handle_ooms(self, number_of_ooms):
        """
        c10d accumulates/syncs gradients between GPUs during the backward
        pass. In case of OOMs, GPUs may fail to sync, so we manually run
        extra iterations to make sure each GPU performs the same number of
        iterations.
        """
        for _ in range(number_of_ooms):
            self.train_step([self._oom_batch], True)

    def zero_grad(self):
        self.optimizer.zero_grad()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        self.lr_scheduler.step(epoch, val_loss)
        # prefer updating the LR based on the number of steps
        return self.lr_step_update()

    def lr_step_update(self):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(self.get_num_updates())

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the (non-wrapped) model instance."""
        return self._model

    def get_criterion(self):
        """Get the (non-wrapped) criterion instance."""
        return self._criterion

    def get_meter(self, name):
        """Get a specific meter by name."""
        if name not in self.meters:
            return None
        return self.meters[name]

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates
        self.lr_step_update()

    def _prepare_sample(self, sample):
        if sample is None or len(sample) == 0:
            return None

        if self.cuda:
            sample = utils.move_to_cuda(sample)

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.half()
            return t

        if self.args.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        return sample

    def _set_seed(self):
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        if self.cuda:
            torch.cuda.manual_seed(seed)