trainer.py 18.7 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Train a network across multiple GPUs.
"""

from collections import defaultdict, OrderedDict
import contextlib
from itertools import chain

import torch
import apex_C

from fairseq import distributed_utils, optim, utils
from fairseq.meters import AverageMeter, TimeMeter
from fairseq.optim import lr_scheduler


class Trainer(object):
    """Main class for data parallel training.

    This class supports data parallel training, where multiple workers each
    have a full model replica and gradients are accumulated synchronously via
    torch.distributed.all_reduce.

    Per-update statistics (sample size, sentences, loss, nll_loss, forward and
    backward OOM counts) are summed locally into ``_all_reduce_list`` and, when
    ``args.enable_global_stats`` is set, reduced with a single ``all_reduce``
    instead of ``distributed_utils.all_gather_list``. The global token count
    used as the gradient denominator is all-reduced on a side CUDA stream so it
    can overlap with the forward pass.
    """

    def __init__(self, args, task, model, criterion, allreduce_communicators=None):
        """Move model/criterion to the GPU and initialize training state.

        Args:
            args: parsed options (``fp16``, ``distributed_world_size``,
                ``enable_parallel_backward_allred_opt``, ... are read here and
                by other methods).
            task: task object whose ``get_loss`` is used by ``_forward``.
            model: the model to train; moved to CUDA.
            criterion: the loss criterion; moved to CUDA.
            allreduce_communicators: accepted for interface compatibility; not
                referenced by this class.

        Raises:
            NotImplementedError: if CUDA is not available.
            RuntimeError: if ``--enable-parallel-backward-allred-opt`` is used
                without distributed training or without FP16.
        """
        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args

        # copy model and criterion to current device
        self.task = task
        self.model = model.cuda()
        self.criterion = criterion.cuda()

        # initialize meters
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['oom'] = AverageMeter()    # out of memory
        self.meters['wall'] = TimeMeter()      # wall time in seconds

        # per-step stats buffered until the next parameter update
        self._buffered_stats = defaultdict(list)
        self._flat_grads = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        # side stream used to all-reduce the token count concurrently with forward
        self._stats_allreduce_stream = torch.cuda.Stream()

        self._last_step = False
        if self.args.enable_parallel_backward_allred_opt and self.args.distributed_world_size <= 1:
            raise RuntimeError('--enable-parallel-backward-allred-opt is only meant for distributed training')
        if self.args.enable_parallel_backward_allred_opt and not self.args.fp16:
            raise RuntimeError('--enable-parallel-backward-allred-opt only works with FP16 training')

        # rework all_gather_list implementation to mitigate memcpy overheads
        # [sample_sizes,nsentences,loss,nll_loss,ooms_fwd,ooms_bwd]
        self._all_reduce_list = [0.0] * 6

    @property
    def optimizer(self):
        """The optimizer, built lazily (together with the LR scheduler) on first access."""
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    def _build_optimizer(self):
        """(Re)build the optimizer over the model's parameters and its LR scheduler."""
        self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file (master worker only).

        The training meters are stored into ``extra_state['train_meters']``.
        """
        # FIXME: Gather optimizer state
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state['train_meters'] = self.meters
            utils.save_state(
                filename, self.args, self.model, self.criterion, self.optimizer,
                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
            )

    def load_checkpoint(self, filename, load_optim=True):
        """Load all training state from a checkpoint file.

        Optimizer and LR-scheduler state are restored only when ``load_optim``
        is set and the saved criterion/optimizer class names match the current
        ones. Saved training meters, if present, replace ``self.meters`` and
        are removed from the returned dict.

        Returns:
            The checkpoint's ``extra_state`` (may be ``None``).
        """
        # FIXME: Scatter optimizer state
        extra_state, optim_history, last_optim_state = \
            utils.load_model_state(filename, self.model)

        if last_optim_state is not None:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            if load_optim:
                self._optim_history = optim_history
                # only reload optimizer and lr_scheduler if they match
                last_optim = self._optim_history[-1]
                if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                    self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
                    if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                        self.optimizer.load_state_dict(last_optim_state)

                self._num_updates = last_optim['num_updates']

        if extra_state is not None and 'train_meters' in extra_state:
            self.meters = extra_state['train_meters']
            del extra_state['train_meters']

        return extra_state

    def train_step(self, sample, update_params=True, last_step=False):
        """Do forward and backward, and optionally a parameter update.

        Args:
            sample: mini-batch dict (``'target'``, ``'ntokens'``, ...), or
                ``None``/empty when this worker has no data for the step.
            update_params: when False, only accumulate gradients and stats
                (gradient accumulation). The update happens on the next call
                with ``update_params=True``.
            last_step: forwarded to the optimizer when
                ``args.distributed_weight_update >= 2``.

        Returns:
            The aggregated logging output dict after a parameter update, or
            ``None`` when buffering, or when the update is skipped because
            every worker ran out of memory in the forward pass.
        """
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        # INFO: Given we don't checkpoint, turning off setting the seed.
        #seed = self.args.seed + self.get_num_updates()
        #torch.manual_seed(seed)
        #torch.cuda.manual_seed(seed)

        self._last_step = last_step
        if self.args.distributed_weight_update >= 2:
            self.optimizer.optimizer.set_last_step(self._last_step)

        # forward and backward pass
        sample = self._prepare_sample(sample)
        if sample is not None:
            my_ntokens = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        else:
            my_ntokens = 0
        # All-reduce this step's token count on the side stream so it overlaps
        # with the forward pass; the current stream waits on it right after.
        with torch.cuda.stream(self._stats_allreduce_stream):
            global_ntokens = torch.full((1,), my_ntokens, dtype=torch.float32, device='cuda')
            if self.args.distributed_world_size > 1:
                torch.distributed.all_reduce(global_ntokens)
        loss, sample_size, logging_output, oom_fwd = self._forward(sample)
        torch.cuda.current_stream().wait_stream(self._stats_allreduce_stream)
        if self.args.distributed_weight_update >= 2:
            # NOTE(review): `self.scaler` is not assigned anywhere in this
            # class; presumably an FP16 subclass provides it — confirm.
            self.optimizer.optimizer.set_global_scale(global_ntokens[0]*self.scaler.loss_scale/torch.distributed.get_world_size())
        oom_bwd = self._backward(loss)

        # buffer stats and logging outputs
        self._buffered_stats['sample_sizes'].append(sample_size)
        self._buffered_stats['logging_outputs'].append(logging_output)
        self._buffered_stats['ooms_fwd'].append(oom_fwd)
        self._buffered_stats['ooms_bwd'].append(oom_bwd)
        # Keep every buffered step's global token count so that, under
        # gradient accumulation, the gradient denominator covers all of them.
        self._buffered_stats['global_ntokens'].append(global_ntokens)

        # rework all_gather_list
        # The single all_reduce relies on sample_size == ntokens. A forward
        # OOM leaves sample_size at 0 while logging_output['ntokens'] still
        # holds the batch size, so only check after a successful forward.
        if not oom_fwd:
            assert(sample_size == logging_output.get('sample_size', 0.0))
            assert(sample_size == logging_output.get('ntokens', 0.0))
        self._all_reduce_list[0] += sample_size
        self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
        self._all_reduce_list[2] += logging_output.get('loss', 0.0)
        self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
        self._all_reduce_list[4] += oom_fwd
        self._all_reduce_list[5] += oom_bwd

        if not update_params:
            return None  # buffering updates

        # update parameters
        check_against_old_code = False
        # check_against_old_code = True
        if check_against_old_code:
            # debug path: recompute stats the old (all_gather_list) way
            # gather logging outputs from all replicas
            sample_sizes = self._buffered_stats['sample_sizes']
            logging_outputs = self._buffered_stats['logging_outputs']
            ooms_fwd = self._buffered_stats['ooms_fwd']
            ooms_bwd = self._buffered_stats['ooms_bwd']
            if self.args.distributed_world_size > 1:
                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                    lambda l: list(chain.from_iterable(l)),
                    zip(*distributed_utils.all_gather_list(
                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
                    ))
                )

            ooms_fwd = sum(ooms_fwd)
            ooms_bwd = sum(ooms_bwd)

            if ooms_fwd == self.args.distributed_world_size:
                print('| WARNING: OOM in all workers, skipping batch')
                self.zero_grad()
                self.clear_buffered_stats()
                return None

            # aggregate stats and logging outputs
            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
            nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
            agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
            grad_denom = self.criterion.__class__.grad_denom(sample_sizes)

            assert( grad_denom == sum(sample_sizes) )
            assert( grad_denom == ntokens )
            assert( grad_denom == agg_logging_output['sample_size'] )
            all_gather_list_tensor = torch.cuda.DoubleTensor([grad_denom, nsentences, agg_logging_output['loss'], agg_logging_output['nll_loss'], ooms_fwd, ooms_bwd])
            print("\n",all_gather_list_tensor)

        # rework all_gather_list
        all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
        if self.args.distributed_world_size > 1 and self.args.enable_global_stats:
            torch.distributed.all_reduce(all_reduce_list_tensor)
        # Skip `div` if distributed and not reducing stats
        if self.args.distributed_world_size == 1 or self.args.enable_global_stats:
            # normalize summed loss / nll_loss by total sample size and by ln(2)
            all_reduce_list_tensor[2:4].div_((all_reduce_list_tensor[0:1]*torch.log(torch.cuda.DoubleTensor([2]))))
        if check_against_old_code:
            print(all_reduce_list_tensor)
            assert(grad_denom == all_reduce_list_tensor[0].item())
            assert(nsentences == all_reduce_list_tensor[1].item())
            # compare loss values
            assert(torch.all(torch.lt(torch.abs(torch.add(all_gather_list_tensor[2:4], -all_reduce_list_tensor[2:4])), 1e-12)))
            assert(ooms_fwd == all_reduce_list_tensor[4].item())
            assert(ooms_bwd == all_reduce_list_tensor[5].item())

        agg_logging_output = {}
        [grad_denom, nsentences, agg_logging_output['loss'], agg_logging_output['nll_loss'], ooms_fwd, ooms_bwd] = all_reduce_list_tensor.tolist()
        # `grad_denom` should be based on pre-allreduce, in case we skipped stats allreduce.
        # Sum over every buffered step so accumulated gradients are scaled correctly.
        grad_denom = sum(t.item() for t in self._buffered_stats['global_ntokens'])
        agg_logging_output['sample_size'] = grad_denom
        ntokens = grad_denom
        if ooms_fwd == self.args.distributed_world_size:
            print('| WARNING: OOM in all workers, skipping batch')
            self.zero_grad()
            # Drop the skipped batch's stats; otherwise its OOM count carries
            # into the next update and keeps triggering this skip.
            self.clear_buffered_stats()
            return None

        try:
            # all-reduce and rescale gradients, then take an optimization step
            grad_norm = self._all_reduce_and_rescale(grad_denom, sample is not None)
            self._opt()

            # update meters
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            if grad_norm is not None:
                self.meters['gnorm'].update(grad_norm)
                self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
            self.meters['oom'].update(ooms_fwd + ooms_bwd)

            # update loss meters for training
            if 'loss' in agg_logging_output:
                self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
            # criterions can optionally log the NLL loss too
            if 'nll_loss' in agg_logging_output:
                self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
        except OverflowError as e:
            self.zero_grad()
            print('| WARNING: overflow detected, ' + str(e))

        self.clear_buffered_stats()

        return agg_logging_output

    def _forward(self, sample, eval=False):
        """Run the forward pass and compute the loss via ``task.get_loss``.

        Args:
            sample: prepared mini-batch, or ``None`` (then no loss is computed).
            eval: if True, run in eval mode under ``torch.no_grad()``; OOM
                errors are re-raised instead of being swallowed.

        Returns:
            Tuple ``(loss, sample_size, logging_output, oom)``. ``oom`` is 1
            when a training forward pass ran out of memory (then ``loss`` is
            ``None``), else 0.
        """
        loss = None
        sample_size = 0
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
        }
        oom = 0
        try:
            # prepare model and optimizer
            if eval:
                self.model.eval()
            else:
                self.model.train()

            if sample is not None:
                # ExitStack() acts as a no-op context manager for training
                with torch.no_grad() if eval else contextlib.ExitStack():
                    # calculate loss and sample size
                    loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
                    logging_output.update(logging_output_)
        except RuntimeError as e:
            if not eval and 'out of memory' in str(e):
                print('| WARNING: ran out of memory, skipping batch')
                oom = 1
                loss = None
            else:
                raise
        return loss, sample_size, logging_output, oom

    def _backward(self, loss):
        """Backpropagate ``loss`` (if any).

        Returns:
            1 if the backward pass ran out of memory (gradients are zeroed),
            else 0.
        """
        oom = 0
        if loss is not None:
            try:
                # backward pass
                loss.backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom = 1
                    self.zero_grad()
                else:
                    raise
        return oom

    def _all_reduce_and_rescale(self, grad_denom, non_empty = True):
        """Sum gradients across workers, divide by ``grad_denom`` and clip.

        Args:
            grad_denom: divisor applied to the flattened gradients.
            non_empty: False when this worker had no sample; missing
                gradients are then zero-filled instead of raising.

        Returns:
            Whatever ``utils.clip_grad_norm_`` returns (presumably the
            gradient norm before clipping — confirm).
        """
        # flatten grads into a single buffer and all-reduce
        flat_grads = self._flat_grads = self._get_flat_grads(out=self._flat_grads, has_grad = non_empty)
        if self.args.distributed_world_size > 1:
            torch.distributed.all_reduce(flat_grads)

        # rescale and clip gradients
        flat_grads.div_(grad_denom)
        grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)

        # copy grads back into model parameters
        self._set_flat_grads(flat_grads)

        return grad_norm

    def _get_grads(self, has_grad = True):
        """Return ``p.grad.data`` for every parameter that requires grad.

        Args:
            has_grad: if True, a parameter without a gradient is an error;
                if False, its gradient is initialized to zeros.

        Raises:
            RuntimeError: if ``has_grad`` and a parameter has no gradient.
        """
        grads = []
        for name, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            if p.grad is None:
                if has_grad:
                    raise RuntimeError('Model parameter did not receive gradient: ' + name + '. '
                                       'Use the param in the forward pass or set requires_grad=False')
                else:
                    p.grad = torch.zeros_like(p)
            grads.append(p.grad.data)
        return grads

    def _get_flat_grads(self, out=None, has_grad = True):
        """Concatenate all parameter gradients into one flat tensor.

        ``out`` is accepted for compatibility but ignored: ``apex_C.flatten``
        always returns a newly produced buffer.
        """
        grads = self._get_grads(has_grad)
        return apex_C.flatten(grads)

    def _set_flat_grads(self, new_grads):
        """Copy consecutive slices of ``new_grads`` back into each parameter gradient."""
        grads = self._get_grads()
        offset = 0
        for g in grads:
            numel = g.numel()
            g.copy_(new_grads[offset:offset+numel].view_as(g))
            offset += numel

    def _opt(self):
        """Take an optimizer step, zero grads, and advance the update-based LR schedule."""
        # take an optimization step
        self.optimizer.step()
        self.zero_grad()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

    def valid_step(self, sample):
        """Do a forward pass in evaluation mode and update validation meters.

        Returns:
            The logging outputs aggregated across all workers by the criterion.
        """
        # forward pass
        sample = self._prepare_sample(sample)
        _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
                (sample_size, logging_output)
            ))
        else:
            sample_sizes = [sample_size]
            logging_outputs = [logging_output]

        # aggregate stats and logging outputs
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)

        # update loss meters for validation
        if 'loss' in agg_logging_output:
            self.meters['valid_loss'].update(agg_logging_output['loss'], grad_denom)
        # criterions can optionally log the NLL loss too
        if 'nll_loss' in agg_logging_output:
            self.meters['valid_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)

        return agg_logging_output

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator; leaves no state behind."""
        self.train_step(dummy_batch, update_params=False)
        self.zero_grad()
        self.clear_buffered_stats()

    def zero_grad(self):
        """Zero all gradients via the optimizer."""
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        """Discard all stats buffered since the last parameter update."""
        self._buffered_stats.clear()
        self._all_reduce_list = [0.0] * 6

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the model replica."""
        return self.model

    def get_meter(self, name):
        """Get a specific meter by name, or ``None`` if it does not exist."""
        if name not in self.meters:
            return None
        return self.meters[name]

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
        """Return ``None`` for a missing/empty sample, else move it to CUDA."""
        if sample is None or len(sample) == 0:
            return None
        return utils.move_to_cuda(sample)