trainer.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Train a network across multiple GPUs.
"""

from collections import OrderedDict
from itertools import chain
import os

import torch

from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler


class Trainer(object):
    """Main class for data parallel training.

    This class supports synchronous distributed data parallel training,
    where multiple workers each have a full model replica and gradients
    are accumulated across workers before each update. We use
    :class:`~torch.nn.parallel.DistributedDataParallel` to handle
    communication of the gradients across workers.
    """

    def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
        self.args = args
        self.task = task

        # copy model and criterion to current device
        self.criterion = criterion
        self._model = model
        self.cuda = torch.cuda.is_available() and not args.cpu
        if args.fp16:
            self._model = self._model.half()
        if self.cuda:
            self.criterion = self.criterion.cuda()
            self._model = self._model.cuda()

        self._dummy_batch = dummy_batch
        self._oom_batch = oom_batch

        self._lr_scheduler = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        self._prev_grad_norm = None
        self._wrapped_model = None

        self.init_meters(args)

    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['oom'] = AverageMeter()    # out of memory
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()      # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

    @property
    def model(self):
        if self._wrapped_model is None:
            if self.args.distributed_world_size > 1:
                self._wrapped_model = models.DistributedFairseqModel(
                    self.args, self._model,
                )
            else:
                self._wrapped_model = self._model
        return self._wrapped_model

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    @property
    def lr_scheduler(self):
        if self._lr_scheduler is None:
            self._build_optimizer()  # this will initialize self._lr_scheduler
        return self._lr_scheduler

    def _build_optimizer(self):
        params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
        if self.args.fp16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                print('| WARNING: your device does NOT support faster training with --fp16, '
                      'please switch to FP32 which is likely to be faster')
            if self.args.memory_efficient_fp16:
                self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
            else:
                self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                print('| NOTICE: your device may support faster training with --fp16')
            self._optimizer = optim.build_optimizer(self.args, params)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state['train_meters'] = self.meters
            checkpoint_utils.save_state(
                filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
            )

    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        if os.path.exists(filename):
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state['model'], strict=True)
            except Exception:
                raise Exception(
                    'Cannot load model parameters from checkpoint, '
                    'please ensure that the architectures match.'
                )

            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
            last_optim_state = state['last_optimizer_state']

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self._num_updates = last_optim['num_updates']

        if extra_state is not None and 'train_meters' in extra_state:
            self.meters.update(extra_state['train_meters'])
            del extra_state['train_meters']

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in self.meters.values():
                if isinstance(meter, TimeMeter):
                    meter.reset()

        return extra_state
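    # Usage sketch (an assumption about the calling code, not taken from this file):
    # the training script typically round-trips state through the two methods above:
    #
    #     trainer.save_checkpoint('checkpoints/checkpoint_last.pt', {'epoch': epoch})
    #     extra_state = trainer.load_checkpoint('checkpoints/checkpoint_last.pt')
    #
    # The path and the keys stored in extra_state here are illustrative only.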

    def train_step(self, samples, dummy_batch=False):
        """Do forward, backward and parameter update."""
        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            try:
                if self.args.distributed_world_size > 1:
                    # Whenever *samples* contains more than one mini-batch, we
                    # want to accumulate gradients locally and only call
                    # all-reduce in the last backwards pass. Currently the
                    # *accumulate_grads* flag is only supported by
                    # LegacyDistributedDataParallel.
                    if i < len(samples) - 1:
                        self.model.accumulate_grads = True
                    else:
                        self.model.accumulate_grads = False

                # forward and backward
                loss, sample_size, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer,
                    ignore_grad
                )

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory with exception: {};\n Skipping batch'.format(str(e)))
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
                ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)
            assert all(norm == prev_norms[0] for norm in prev_norms), \
                'Fatal error: gradients are inconsistent between workers'

        self.meters['oom'].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.criterion
        )
        sample_size = self.task.grad_denom(sample_sizes, self.criterion)

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
            ).format(self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))
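            # (The factor above: DDP averages gradients over workers, i.e. divides
            # by distributed_world_size, and `sample_size` is the denominator
            # gathered from all workers, so the product leaves the summed gradient
            # normalized by the global sample size.)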

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self._num_updates += 1

            # update learning rate
            self.lr_scheduler.step_update(self._num_updates)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            if 'train_acc' in self.meters:
                self.meters['train_acc'].update(
                    logging_output.get('acc', 0), sample_size)

            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output

    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion
                )
            except RuntimeError as e:
                if 'out of memory' in str(e) and not raise_oom:
                    print('| WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
                            del p.grad  # free some memory
                    if self.cuda:
                        torch.cuda.empty_cache()
                    return self.valid_step(sample, raise_oom=True)
                else:
                    raise e

            if ignore_results:
                logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(*distributed_utils.all_gather_list(
                [logging_output, sample_size],
            ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.criterion
        )
        sample_size = self.task.grad_denom(
            sample_size, self.criterion
        )

        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)
        self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'valid_acc' in self.meters:
            self.meters['valid_acc'].update(
                logging_output.get('acc', 0), sample_size)

        if 'nll_loss' in logging_output:
            self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)

        return logging_output

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, dummy_batch=True)
        self.zero_grad()

    def handle_ooms(self, number_of_ooms):
        """
        c10d accumulates/syncs gradients across GPUs during the backward pass.
        If a GPU hits an OOM it may fail to sync, so we manually run extra
        iterations to make sure every GPU performs the same number of them.
        """
        for _ in range(number_of_ooms):
            self.train_step([self._oom_batch], True)

    def zero_grad(self):
        self.optimizer.zero_grad()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)
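    # Call-pattern sketch (illustrative; the surrounding loop and names below are
    # hypothetical, only the Trainer methods are from this file):
    #
    #     for epoch in itertools.count(1):
    #         train_one_epoch(trainer)            # drives trainer.train_step(...)
    #         valid_loss = validate(trainer)
    #         lr = trainer.lr_step(epoch, valid_loss)
    #
    # The per-update counterpart, lr_scheduler.step_update, is already invoked
    # inside train_step() after each optimizer step.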

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the (non-wrapped) model instance."""
        return self._model

    def get_meter(self, name):
        """Get a specific meter by name."""
        if name not in self.meters:
            return None
        return self.meters[name]

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
        if sample is None or len(sample) == 0:
            return None
        if self.cuda:
            sample = utils.move_to_cuda(sample)
        return sample

    def _set_seed(self):
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        if self.cuda:
            torch.cuda.manual_seed(seed)