# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Train a network across multiple GPUs.
"""

from collections import OrderedDict
import contextlib
from itertools import chain
import math
import os
import sys

import torch

from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler


class Trainer(object):
    """Main class for data parallel training.
Myle Ott's avatar
Myle Ott committed
26

27
28
29
30
31
    This class supports synchronous distributed data parallel training,
    where multiple workers each have a full model replica and gradients
    are accumulated across workers before each update. We use
    :class:`~torch.nn.parallel.DistributedDataParallel` to handle
    communication of the gradients across workers.
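
    Example (an illustrative sketch, not part of fairseq's documented API;
    ``args``, ``task``, ``model`` and ``criterion`` are assumed to come from
    the usual fairseq setup code, and ``batch_iterator`` is a hypothetical
    iterable that yields lists of mini-batches)::

        trainer = Trainer(args, task, model, criterion)
        for samples in batch_iterator:
            logging_output = trainer.train_step(samples)
        trainer.lr_step(epoch)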
    """

    def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=None):
        self.args = args
        self.task = task

        # copy model and criterion to current device
        self.criterion = criterion
        self._model = model
        self.cuda = torch.cuda.is_available() and not args.cpu
        if args.fp16:
            self._model = self._model.half()
        if self.cuda:
            self.criterion = self.criterion.cuda()
            self._model = self._model.cuda()

        self._dummy_batch = dummy_batch
        self._oom_batch = oom_batch or dummy_batch

        self._lr_scheduler = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        self._prev_grad_norm = None
        self._wrapped_model = None

        self.init_meters(args)

    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['oom'] = AverageMeter()    # out of memory
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()      # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

    @property
    def model(self):
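        # forward/backward should go through this (possibly DDP-wrapped)
        # property; get_model() below returns the underlying module, which
        # is what save_checkpoint() serializes.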
        if self._wrapped_model is None:
            if self.args.distributed_world_size > 1 and not self.args.use_bmuf:
                self._wrapped_model = models.DistributedFairseqModel(
                    self.args, self._model,
                )
            else:
                self._wrapped_model = self._model
        return self._wrapped_model

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    @property
    def lr_scheduler(self):
        if self._lr_scheduler is None:
            self._build_optimizer()  # this will initialize self._lr_scheduler
        return self._lr_scheduler

    def _build_optimizer(self):
        params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
        if self.args.fp16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                print('| WARNING: your device does NOT support faster training with --fp16, '
                      'please switch to FP32 which is likely to be faster')
            if self.args.memory_efficient_fp16:
                self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
            else:
                self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                print('| NOTICE: your device may support faster training with --fp16')
            self._optimizer = optim.build_optimizer(self.args, params)

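        # BMUF ("block-wise model-update filtering") wraps the base optimizer
        # so that workers synchronize model state only every
        # --global-sync-iter updates rather than all-reducing gradients on
        # every step (see the gating in train_step below).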
        if self.args.use_bmuf:
            self._optimizer = optim.FairseqBMUF(self.args, params, self._optimizer)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
        self._lr_scheduler.step_update(0)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state['train_meters'] = self.meters
            checkpoint_utils.save_state(
                filename, self.args, self.get_model().state_dict(), self.criterion,
                self.optimizer, self.lr_scheduler, self.get_num_updates(),
                self._optim_history, extra_state,
            )

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        if os.path.exists(filename):
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state['model'], strict=True)
            except Exception:
                raise Exception(
                    'Cannot load model parameters from checkpoint, '
                    'please ensure that the architectures match.'
                )

            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
            last_optim_state = state.get('last_optimizer_state', None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self.set_num_updates(last_optim['num_updates'])

        if extra_state is not None:
            epoch = extra_state['train_iterator']['epoch']
            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                filename, epoch, self.get_num_updates()))

            self.lr_step(epoch)

            if 'train_meters' in extra_state and not reset_meters:
                self.meters.update(extra_state['train_meters'])
                del extra_state['train_meters']

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in self.meters.values():
                    if isinstance(meter, TimeMeter):
                        meter.reset()
        else:
            print('| no existing checkpoint found {}'.format(filename))

        return extra_state

    def get_train_iterator(self, epoch, combine=True):
        """Return an EpochBatchIterator over the training set for a given epoch."""
        print('| loading train data for epoch {}'.format(epoch))
        self.task.load_dataset(self.args.train_subset, epoch=epoch, combine=combine)
        return self.task.get_batch_iterator(
            dataset=self.task.dataset(self.args.train_subset),
            max_tokens=self.args.max_tokens,
            max_sentences=self.args.max_sentences,
            max_positions=utils.resolve_max_positions(
                self.task.max_positions(),
                self.model.max_positions(),
            ),
            ignore_invalid_inputs=True,
            required_batch_size_multiple=self.args.required_batch_size_multiple,
            seed=self.args.seed,
            num_shards=self.args.distributed_world_size,
            shard_id=self.args.distributed_rank,
            num_workers=self.args.num_workers,
            epoch=epoch,
        )

    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (
                    self.args.distributed_world_size > 1
                    and hasattr(self.model, 'no_sync')
                    and i < len(samples) - 1
                ):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad
                    )

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    msg = (
                        '| WARNING: ran out of memory with exception: '
                        + '{};'.format(e)
                        + '\n Skipping batch'
                    )
                    # TODO: this should go to a logger; we print to stderr
                    # because stdout is buffered and the message may never be
                    # flushed if another exception follows
                    print(msg, file=sys.stderr)
                    if raise_oom:
                        raise ValueError(msg)
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
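        # (after this all-gather, every worker holds the same flattened lists
        # of logging outputs, sample sizes and OOM counts, so all workers
        # compute identical aggregated stats and the same gradient
        # denominator)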
        if self.args.distributed_world_size > 1 and (
            (not self.args.use_bmuf)
            or (
                self.args.use_bmuf
                and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
            )
        ):
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
                ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

            if not self.args.use_bmuf:
                assert (
                    all(norm == prev_norms[0] for norm in prev_norms)
                    or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
                ), 'Fatal error: gradients are inconsistent between workers'

        self.meters['oom'].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.criterion
        )
        sample_size = self.task.grad_denom(sample_sizes, self.criterion)

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
            ).format(self.task.__class__.__name__))

        try:
            # normalize grads by sample size
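            # (DDP averages gradients over the workers, so multiplying by the
            # world size recovers the summed gradients before dividing by the
            # sample size aggregated across all workers)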
            if sample_size > 0:
                self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            if 'train_acc' in self.meters:
                self.meters['train_acc'].update(
                    logging_output.get('acc', 0), sample_size)

            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output

    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion
                )
            except RuntimeError as e:
                if 'out of memory' in str(e) and not raise_oom:
                    print('| WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
                            p.grad = None  # free some memory
                    if self.cuda:
                        torch.cuda.empty_cache()
                    return self.valid_step(sample, raise_oom=True)
                else:
                    raise e

            if ignore_results:
                logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(*distributed_utils.all_gather_list(
                [logging_output, sample_size],
            ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.criterion
        )
        sample_size = self.task.grad_denom(
            sample_size, self.criterion
        )

        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)
        self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'valid_acc' in self.meters:
            self.meters['valid_acc'].update(
                logging_output.get('acc', 0), sample_size)
        if 'nll_loss' in logging_output:
            self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)

        return logging_output

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, dummy_batch=True)
        self.zero_grad()

    def handle_ooms(self, number_of_ooms):
        """
        c10d accumulates/syncs gradients between GPUs during the backward
        pass. In case of OOMs, GPUs may fail to sync, so we manually run
        extra steps to make sure each GPU performs the same number of
        iterations.
        """
        for _ in range(number_of_ooms):
            self.train_step([self._oom_batch], dummy_batch=True)

    def zero_grad(self):
        self.optimizer.zero_grad()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        self.lr_scheduler.step(epoch, val_loss)
        # prefer updating the LR based on the number of steps
        return self.lr_step_update()

    def lr_step_update(self):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(self.get_num_updates())

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the (non-wrapped) model instance."""
        return self._model

    def get_meter(self, name):
        """Get a specific meter by name."""
        if name not in self.meters:
            return None
        return self.meters[name]

    def get_num_updates(self):
        """Get the number of parameter updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self._num_updates = num_updates
        self.lr_step_update()

    def _prepare_sample(self, sample):
        if sample is None or len(sample) == 0:
            return None

        if self.cuda:
            sample = utils.move_to_cuda(sample)

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.half()
            return t

        if self.args.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        return sample

    def _set_seed(self):
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        if self.cuda:
            torch.cuda.manual_seed(seed)