# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Train a network across multiple GPUs.
"""

from collections import defaultdict, OrderedDict
import contextlib
from itertools import chain

import torch

from fairseq import distributed_utils, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler


class Trainer(object):
    """Main class for data parallel training.

    This class supports data parallel training, where multiple workers each
    have a full model replica and gradients are accumulated synchronously via
    torch.distributed.all_reduce.

    Gradients from several calls to ``train_step(..., update_params=False)``
    are buffered and applied together by the next call with
    ``update_params=True``.
    """

    def __init__(self, args, task, model, criterion):
        """Move *model* and *criterion* to the current GPU and set up meters.

        Raises:
            NotImplementedError: if CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args
        self.task = task

        # copy model and criterion to current device
        self.model = model.cuda()
        self.criterion = criterion.cuda()

        # initialize meters
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['oom'] = AverageMeter()    # out of memory
        self.meters['wall'] = TimeMeter()      # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

        # per-step stats, buffered until the next parameter update
        self._buffered_stats = defaultdict(list)
        # flat gradient buffer, reused across updates
        self._flat_grads = None
        self._num_updates = 0
        self._optim_history = None
        # the optimizer and lr scheduler are built lazily (see _build_optimizer)
        self._optimizer = None
        self._lr_scheduler = None

    @property
    def optimizer(self):
        """The optimizer, built on first access."""
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    @property
    def lr_scheduler(self):
        """The learning rate scheduler, built on first access.

        Previously the scheduler only existed after the optimizer had been
        built, so calling lr_step() first raised AttributeError.
        """
        if self._lr_scheduler is None:
            if self._optimizer is None:
                # building the optimizer also builds the scheduler
                self._build_optimizer()
            else:
                self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)
        return self._lr_scheduler

    @lr_scheduler.setter
    def lr_scheduler(self, value):
        # kept so that external assignment to `trainer.lr_scheduler` still works
        self._lr_scheduler = value

    def _build_optimizer(self):
        """(Re)build the optimizer over the model parameters, plus its lr scheduler."""
        self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
        self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file.

        Only the master process writes. *extra_state* gains a 'train_meters' entry.
        """
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state['train_meters'] = self.meters
            utils.save_state(
                filename, self.args, self.model, self.criterion, self.optimizer,
                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
            )

    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
        """Load all training state from a checkpoint file.

        Returns the checkpoint's extra state, with 'train_meters' (if present)
        removed and installed as this trainer's meters.
        """
        extra_state, self._optim_history, last_optim_state = \
            utils.load_model_state(filename, self.model)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]

            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
                'criterion does not match; please reset the optimizer (--reset-optimizer)'

            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'optimizer does not match; please reset the optimizer (--reset-optimizer)'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])

            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self._num_updates = last_optim['num_updates']

        if extra_state is not None and 'train_meters' in extra_state:
            self.meters = extra_state['train_meters']
            del extra_state['train_meters']

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in self.meters.values():
                if isinstance(meter, TimeMeter):
                    meter.reset()

        return extra_state

    def train_step(self, sample, update_params=True, dummy_batch=False):
        """Do forward, backward and parameter update.

        Stats for *sample* are buffered. If *update_params* is True, all
        buffered gradients are applied and the aggregated logging output is
        returned. Otherwise this returns None. Dummy batches are not timed.
        """
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        sample = self._prepare_sample(sample)
        loss, sample_size, logging_output, oom_fwd = self._forward(sample)
        oom_bwd = self._backward(loss)

        # buffer stats and logging outputs
        self._buffered_stats['sample_sizes'].append(sample_size)
        self._buffered_stats['logging_outputs'].append(logging_output)
        self._buffered_stats['ooms_fwd'].append(oom_fwd)
        self._buffered_stats['ooms_bwd'].append(oom_bwd)

        # update parameters
        if update_params:
            agg_logging_output = self._update_params()
        else:
            agg_logging_output = None  # buffering updates

        if not dummy_batch:
            self.meters['train_wall'].stop()

        return agg_logging_output

    def _update_params(self):
        """Apply all buffered gradients in one optimizer step.

        Buffered stats are always cleared on return. Returns the aggregated
        logging output, or None if every buffered forward pass on every
        worker ran out of memory.
        """
        # gather logging outputs from all replicas
        sample_sizes = self._buffered_stats['sample_sizes']
        logging_outputs = self._buffered_stats['logging_outputs']
        ooms_fwd = self._buffered_stats['ooms_fwd']
        ooms_bwd = self._buffered_stats['ooms_bwd']
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                lambda per_worker: list(chain.from_iterable(per_worker)),
                zip(*distributed_utils.all_gather_list(
                    (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
                ))
            )
        # Count forward passes across all workers AND all buffered sub-batches.
        # Comparing to the world size alone is wrong when updates are buffered.
        num_forward_passes = len(ooms_fwd)
        ooms_fwd = sum(ooms_fwd)
        ooms_bwd = sum(ooms_bwd)

        if ooms_fwd > 0 and ooms_fwd == num_forward_passes:
            print('| WARNING: OOM in all workers, skipping batch')
            self.zero_grad()
            # drop the skipped batch's stats so they don't leak into the next update
            self.clear_buffered_stats()
            return None

        # aggregate stats and logging outputs
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)

        try:
            # all-reduce and rescale gradients, then take an optimization step
            grad_norm = self._all_reduce_and_rescale(grad_denom)
            self._opt()

            # update meters
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            if grad_norm is not None:
                self.meters['gnorm'].update(grad_norm)
                self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
            self.meters['oom'].update(ooms_fwd + ooms_bwd)

            # update loss meters for training
            if 'loss' in agg_logging_output:
                self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
            # criterions can optionally log the NLL loss too
            if 'nll_loss' in agg_logging_output:
                self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
        except OverflowError as e:
            self.zero_grad()
            print('| WARNING: overflow detected, ' + str(e))

        self.clear_buffered_stats()

        return agg_logging_output

    def _forward(self, sample, eval=False):
        """Compute the loss for *sample* via ``task.get_loss``.

        Returns ``(loss, sample_size, logging_output, oom)``. In training mode,
        a CUDA out-of-memory error is swallowed: loss is None and oom is 1.
        In eval mode, gradients are disabled and all errors propagate.
        """
        loss = None
        sample_size = 0
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
        }
        oom = 0
        try:
            # prepare model and optimizer
            if eval:
                self.model.eval()
            else:
                self.model.train()

            if sample is not None:
                # ExitStack acts as a no-op context manager in training mode
                with torch.no_grad() if eval else contextlib.ExitStack():
                    # calculate loss and sample size
                    loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
                    logging_output.update(logging_output_)
        except RuntimeError as e:
            if not eval and 'out of memory' in str(e):
                print('| WARNING: ran out of memory, skipping batch')
                oom = 1
                loss = None
            else:
                raise
        return loss, sample_size, logging_output, oom

    def _backward(self, loss):
        """Backpropagate *loss* (if any). Returns 1 on out-of-memory, else 0.

        On OOM, all accumulated gradients are zeroed.
        """
        oom = 0
        if loss is not None:
            try:
                # backward pass
                loss.backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom = 1
                    self.zero_grad()
                else:
                    raise
        return oom

    def _all_reduce_and_rescale(self, grad_denom):
        """Sum gradients across workers, divide by *grad_denom*, then clip.

        Returns the gradient norm from ``utils.clip_grad_norm_``.
        """
        # flatten grads into a single buffer and all-reduce
        flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
        if self.args.distributed_world_size > 1:
            torch.distributed.all_reduce(flat_grads)

        # rescale and clip gradients
        flat_grads.div_(grad_denom)
        grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)

        # copy grads back into model parameters
        self._set_flat_grads(flat_grads)

        return grad_norm

    def _get_grads(self):
        """Return the gradient tensors of all parameters that require grad.

        Raises:
            RuntimeError: if a trainable parameter has no gradient.
        """
        grads = []
        for name, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            if p.grad is None:
                raise RuntimeError('Model parameter did not receive gradient: ' + name + '. '
                                   'Use the param in the forward pass or set requires_grad=False')
            grads.append(p.grad.data)
        return grads

    def _get_flat_grads(self, out=None):
        """Copy all gradients into one 1-D buffer.

        Reuses *out* if given, otherwise allocates a new buffer. Returns a view
        covering only the copied elements.
        """
        grads = self._get_grads()
        if out is None:
            grads_size = sum(g.numel() for g in grads)
            out = grads[0].new(grads_size).zero_()
        offset = 0
        for g in grads:
            numel = g.numel()
            out[offset:offset+numel].copy_(g.view(-1))
            offset += numel
        return out[:offset]

    def _set_flat_grads(self, new_grads):
        """Copy a flat buffer from _get_flat_grads back into the parameter gradients."""
        grads = self._get_grads()
        offset = 0
        for g in grads:
            numel = g.numel()
            g.copy_(new_grads[offset:offset+numel].view_as(g))
            offset += numel

    def _opt(self):
        """Take an optimizer step, zero grads and advance the per-update lr schedule."""
        # take an optimization step
        self.optimizer.step()
        self.zero_grad()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

    def valid_step(self, sample):
        """Do forward pass in evaluation mode.

        Logging outputs are gathered from all workers and aggregated, and the
        validation loss meters are updated. Returns the aggregated output.
        """
        # forward pass
        sample = self._prepare_sample(sample)
        _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
                (sample_size, logging_output)
            ))
        else:
            sample_sizes = [sample_size]
            logging_outputs = [logging_output]

        # aggregate stats and logging outputs
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)

        # update loss meters for validation
        if 'loss' in agg_logging_output:
            self.meters['valid_loss'].update(agg_logging_output['loss'], grad_denom)
        # criterions can optionally log the NLL loss too
        if 'nll_loss' in agg_logging_output:
            self.meters['valid_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)

        return agg_logging_output

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator.

        Runs forward/backward without updating parameters, then discards the
        resulting gradients and buffered stats.
        """
        self.train_step(dummy_batch, update_params=False, dummy_batch=True)
        self.zero_grad()
        self.clear_buffered_stats()

    def zero_grad(self):
        """Zero the gradients via the optimizer."""
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        """Discard all stats buffered by train_step."""
        self._buffered_stats.clear()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the model replica."""
        return self.model

    def get_meter(self, name):
        """Get a specific meter by name, or None if no such meter exists."""
        return self.meters.get(name)

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
        """Move *sample* to the GPU. Returns None for a None or empty sample."""
        if sample is None or len(sample) == 0:
            return None
        return utils.move_to_cuda(sample)