# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Train a network across multiple GPUs.
"""

from collections import OrderedDict
from itertools import chain

import torch

from fairseq import distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler


class Trainer(object):
    """Main class for data parallel training.

    This class supports synchronous distributed data parallel training,
    where multiple workers each have a full model replica and gradients
    are accumulated across workers before each update. We use
    :class:`~torch.nn.parallel.DistributedDataParallel` to handle
    communication of the gradients across workers.
    """

    def __init__(self, args, task, model, criterion, dummy_batch):

        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args
        self.task = task

        # copy model and criterion to current device
        self.criterion = criterion.cuda()
        if args.fp16:
            self._model = model.half().cuda()
        else:
            self._model = model.cuda()

        # initialize meters
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['oom'] = AverageMeter()    # out of memory
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()      # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

        self._dummy_batch = dummy_batch
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        self._wrapped_model = None

    @property
    def model(self):
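        # lazily wrap the bare model on first access: multi-GPU runs get a
        # DistributedFairseqModel wrapper to synchronize gradients, while
        # single-GPU runs use the underlying model directly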
        if self._wrapped_model is None:
            if self.args.distributed_world_size > 1:
                self._wrapped_model = models.DistributedFairseqModel(
                    self.args, self._model,
                )
            else:
                self._wrapped_model = self._model
        return self._wrapped_model

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    def _build_optimizer(self):
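        # compute capability 7.0 (Volta) is the first generation with tensor
        # cores, which is what makes --fp16 training fast; warn when the flag
        # and the hardware disagree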
        if self.args.fp16:
            if torch.cuda.get_device_capability(0)[0] < 7:
                print('| WARNING: your device does NOT support faster training with --fp16; '
                      'please switch to FP32, which is likely to be faster')
            params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
            self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
        else:
            if torch.cuda.get_device_capability(0)[0] >= 7:
                print('| NOTICE: your device may support faster training with --fp16')
            self._optimizer = optim.build_optimizer(self.args, self.model.parameters())

        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state['train_meters'] = self.meters
            utils.save_state(
                filename, self.args, self.get_model(), self.criterion, self.optimizer,
                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
            )

    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = \
            utils.load_model_state(filename, self.get_model())
        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
                'criterion does not match; please reset the optimizer (--reset-optimizer)'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'optimizer does not match; please reset the optimizer (--reset-optimizer)'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self._num_updates = last_optim['num_updates']

        if extra_state is not None and 'train_meters' in extra_state:
            self.meters.update(extra_state['train_meters'])
            del extra_state['train_meters']

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in self.meters.values():
                if isinstance(meter, TimeMeter):
                    meter.reset()

        return extra_state

    def train_step(self, samples, dummy_batch=False):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            try:
                # forward
                loss, sample_size, logging_output = self.task.get_loss(
                    self.model, self.criterion, sample,
                )
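                # multiplying the loss by zero keeps the backward pass (and,
                # with multiple workers, the gradient all-reduce) alive while
                # contributing nothing to the parameter update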
                if ignore_grad:
                    loss *= 0

                if self.args.distributed_world_size > 1:
                    # only all-reduce gradients in the last backward pass
                    self.model.need_reduction = (i == len(samples) - 1)

                # backward
                self.optimizer.backward(loss)

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
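        # all_gather_list pickles arbitrary Python objects and exchanges them
        # across workers, so every replica sees the logging outputs, sample
        # sizes and OOM counts from all replicas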
        if self.args.distributed_world_size > 1:
            logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list(
                [logging_outputs, sample_sizes, ooms],
            ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

        if ooms == self.args.distributed_world_size:
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
        sample_size = self.criterion.__class__.grad_denom(sample_sizes)

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
            ).format(self.criterion.__class__.__name__))

        try:
            # normalize grads by sample size
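            # (DistributedDataParallel averages gradients across workers, so
            # multiplying by world_size recovers the summed gradients before
            # dividing by the aggregated sample size)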
            self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # take an optimization step
            self.optimizer.step()
            self._num_updates += 1

            # update learning rate
            self.lr_scheduler.step_update(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['oom'].update(ooms)
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
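            # the FP16 optimizer raises OverflowError when gradients overflow
            # under dynamic loss scaling; skip the update so training can
            # retry with a smaller loss scale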
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
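            # under dynamic loss scaling the scale changes over time, so reset
            # the meter and record only its most recent value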
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output

    def valid_step(self, sample):
        """Do forward pass in evaluation mode."""
        self.model.eval()

        logging_output, sample_size = {}, 0
        with torch.no_grad():
            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
            _loss, sample_size, logging_output = self.task.get_loss(
                self.model, self.criterion, sample,
            )

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(*distributed_utils.all_gather_list(
                [logging_output, sample_size],
            ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # aggregate logging outputs and sample sizes
        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_output)
        sample_size = self.criterion.__class__.grad_denom(sample_size)

        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)
        self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
        self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)

        return logging_output

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, dummy_batch=True)
        self.zero_grad()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the (non-wrapped) model instance."""
        return self._model

    def get_meter(self, name):
        """Get a specific meter by name."""
        if name not in self.meters:
            return None
        return self.meters[name]

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
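        # empty samples become None so callers can fall back to the dummy
        # batch; otherwise every tensor in the (possibly nested) sample is
        # moved to the GPU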
        if sample is None or len(sample) == 0:
            return None
        return utils.move_to_cuda(sample)