trainer.py 16.7 KB
Newer Older
Pan,Huiwen's avatar
Pan,Huiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright (c) 2017 Elad Hoffer
# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

huchen's avatar
huchen committed
22
23
24
25
26
27
28
29
30
import logging
import os
import time
from itertools import cycle

import numpy as np
import torch
import torch.optim
import torch.utils.data
Pan,Huiwen's avatar
Pan,Huiwen committed
31
from apex.parallel import DistributedDataParallel
huchen's avatar
huchen committed
32
33
from apex import amp

Pan,Huiwen's avatar
Pan,Huiwen committed
34
35
36
37
from seq2seq.train.fp_optimizers import FP16Optimizer
from seq2seq.train.fp_optimizers import FP32Optimizer
from seq2seq.train.fp_optimizers import AMPOptimizer
from seq2seq.train.lr_scheduler import WarmupMultiStepLR
huchen's avatar
huchen committed
38
39
40
41
42
43
44
45
46
47
48
49
from seq2seq.utils import AverageMeter
from seq2seq.utils import sync_workers


class Seq2SeqTrainer:
    """
    Seq2SeqTrainer
    """
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 scheduler_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 save_info=None,
                 save_dir='.',
                 train_iterations=0,
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 loss_scaling=None,
                 intra_epoch_eval=0,
                 prealloc_mode='always',
                 warmup=0,
                 iter_size=1,
                 translator=None,
                 verbose=False):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer; the
            'optimizer' entry names a class from torch.optim, all remaining
            entries are passed to that class as keyword arguments. The
            dictionary passed by the caller is not modified.
        :param scheduler_config: dictionary with options for the learning rate
            scheduler
        :param print_freq: prints short summary every 'print_freq' iterations
        :param save_freq: saves checkpoint every 'save_freq' iterations
        :param grad_clip: coefficient for gradient clipping
        :param save_info: dict with additional state stored in each checkpoint
            (None means a new empty dict)
        :param save_dir: path to the directory for checkpoints
        :param train_iterations: total number of training iterations to execute
        :param checkpoint_filename: name of files with checkpoints
        :param keep_checkpoints: max number of checkpoints to keep
        :param math: arithmetic type, one of 'fp32', 'tf32', 'fp16' or
            'manual_fp16'; for any other value no 'fp_optimizer' is created
        :param loss_scaling: options for dynamic loss scaling, the keys
            'init_scale' and 'upscale_interval' are read when math is 'fp16'
            or 'manual_fp16' (None means a new empty dict)
        :param intra_epoch_eval: number of additional eval runs within each
            training epoch
        :param prealloc_mode: controls preallocation,
            choices=['off', 'once', 'always']
        :param warmup: number of warmup iterations for performance counters
        :param iter_size: number of iterations between weight updates
        :param translator: instance of Translator, runs inference on test set
        :param verbose: enables verbose logging
        """
        super().__init__()

        # Every trainer gets its own dicts: feed_data() writes
        # save_info['iteration'], so a shared `{}` default would leak state
        # from one instance into the next.
        if save_info is None:
            save_info = {}
        if loss_scaling is None:
            loss_scaling = {}

        # Work on a copy so that removing the 'optimizer' entry below does not
        # alter the caller's dictionary (reusing one config for a second
        # trainer would otherwise fail with KeyError).
        optimizer_kwargs = dict(opt_config)
        opt_name = optimizer_kwargs.pop('optimizer')

        self.model = model
        self.criterion = criterion
        self.epoch = 0
        self.save_info = save_info
        self.save_dir = save_dir
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        # periodic checkpoints reuse identifiers 0..keep_checkpoints-1
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        # keyword arguments of the optimizer, without the 'optimizer' entry
        self.opt_config = optimizer_kwargs
        self.device = next(model.parameters()).device
        self.print_freq = print_freq
        self.verbose = verbose
        self.loss = None
        self.translator = translator
        self.intra_epoch_eval = intra_epoch_eval
        self.warmup = warmup
        self.iter_size = iter_size
        self.prealloc_mode = prealloc_mode
        self.preallocated = False

        self.distributed = torch.distributed.is_initialized()
        self.batch_first = model.batch_first

        params = self.model.parameters()

        if math == 'manual_fp16':
            self.fp_optimizer = FP16Optimizer(
                self.model, grad_clip,
                loss_scale=loss_scaling['init_scale'],
                dls_upscale_interval=loss_scaling['upscale_interval']
                )
            # the torch optimizer is built on the fp32 parameters exposed by
            # FP16Optimizer instead of on the model's own parameters
            params = self.fp_optimizer.fp32_params
        elif math in ('fp32', 'tf32'):
            self.fp_optimizer = FP32Optimizer(self.model, grad_clip)

        self.optimizer = getattr(torch.optim, opt_name)(params,
                                                        **optimizer_kwargs)
        logging.info(f'Using optimizer: {self.optimizer}')

        self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations,
                                           **scheduler_config)

        if math == 'fp16':
            # amp.initialize() returns replacements for both the model and
            # the optimizer, so it runs after they have been created
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                cast_model_outputs=torch.float16,
                keep_batchnorm_fp32=False,
                opt_level='O2')

            self.fp_optimizer = AMPOptimizer(
                self.model,
                grad_clip,
                loss_scale=loss_scaling['init_scale'],
                dls_upscale_interval=loss_scaling['upscale_interval']
                )

        if self.distributed:
            self.model = DistributedDataParallel(self.model)
huchen's avatar
huchen committed
155
156
157
158
159
160
161
162
163
164
165
166

    def iterate(self, src, tgt, update=True, training=True):
        """
        Performs one iteration of the training/validation.

        :param src: batch of examples from the source language
        :param tgt: batch of examples from the target language
        :param update: if True: optimizer does update of the weights
        :param training: if True: executes optimizer
        """
        src, src_length = src
        tgt, tgt_length = tgt
Pan,Huiwen's avatar
Pan,Huiwen committed
167
168
169
        src = src.to(self.device)
        tgt = tgt.to(self.device)
        src_length = src_length.to(self.device)
huchen's avatar
huchen committed
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186

        num_toks = {}
        num_toks['tgt'] = int(sum(tgt_length - 1))
        num_toks['src'] = int(sum(src_length))

        if self.batch_first:
            output = self.model(src, src_length, tgt[:, :-1])
            tgt_labels = tgt[:, 1:]
            T, B = output.size(1), output.size(0)
        else:
            output = self.model(src, src_length, tgt[:-1])
            tgt_labels = tgt[1:]
            T, B = output.size(0), output.size(1)

        loss = self.criterion(output.view(T * B, -1),
                              tgt_labels.contiguous().view(-1))

Pan,Huiwen's avatar
Pan,Huiwen committed
187
        loss_per_batch = loss.item()
huchen's avatar
huchen committed
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
        loss /= (B * self.iter_size)

        if training:
            self.fp_optimizer.step(loss, self.optimizer, self.scheduler,
                                   update)

        loss_per_token = loss_per_batch / num_toks['tgt']
        loss_per_sentence = loss_per_batch / B

        return loss_per_token, loss_per_sentence, num_toks

    def feed_data(self, data_loader, training=True):
        """
        Runs training or validation on batches from data_loader.

        Every batch goes through self.iterate(). A summary line is logged
        every 'print_freq' iterations. In training mode the method also runs
        the intra-epoch evaluations through self.translator and stores the
        periodic checkpoints.

        :param data_loader: data loader
        :param training: if True runs training else runs validation
        :returns: tuple (average loss per token, average total tokens per
            second), both read after the meters' reduce() calls
        """
        if training:
            assert self.optimizer is not None
            # 'intra_epoch_eval' evenly spaced points strictly inside the
            # epoch, expressed as batch indices that are multiples of
            # iter_size
            eval_fractions = np.linspace(0, 1, self.intra_epoch_eval+2)[1:-1]
            iters_with_update = len(data_loader) // self.iter_size
            eval_iters = (eval_fractions * iters_with_update).astype(int)
            eval_iters = eval_iters * self.iter_size
            eval_iters = set(eval_iters)

        batch_time = AverageMeter(self.warmup)
        data_time = AverageMeter(self.warmup)
        losses_per_token = AverageMeter()
        losses_per_sentence = AverageMeter()

        tot_tok_time = AverageMeter(self.warmup)
        src_tok_time = AverageMeter(self.warmup)
        tgt_tok_time = AverageMeter(self.warmup)

        batch_size = data_loader.batch_size

        # perf_counter() is monotonic and high resolution; time.time() is a
        # wall clock that may be adjusted while a batch is running, which can
        # yield a zero or negative interval and break the throughput
        # divisions below.
        end = time.perf_counter()
        for i, (src, tgt) in enumerate(data_loader):
            self.save_counter += 1
            # measure data loading time
            data_time.update(time.perf_counter() - end)

            # weights are updated on the last batch of every group of
            # 'iter_size' batches
            update = i % self.iter_size == self.iter_size - 1

            # do a train/evaluate iteration
            stats = self.iterate(src, tgt, update, training=training)
            loss_per_token, loss_per_sentence, num_toks = stats

            # measure accuracy and record loss
            losses_per_token.update(loss_per_token, num_toks['tgt'])
            losses_per_sentence.update(loss_per_sentence, batch_size)

            # measure elapsed time
            elapsed = time.perf_counter() - end
            batch_time.update(elapsed)
            src_tok_time.update(num_toks['src'] / elapsed)
            tgt_tok_time.update(num_toks['tgt'] / elapsed)
            tot_num_toks = num_toks['tgt'] + num_toks['src']
            tot_tok_time.update(tot_num_toks / elapsed)
            self.loss = losses_per_token.avg

            # 'eval_iters' exists only in training mode; the 'training' test
            # comes first so it is never read during validation
            if training and i in eval_iters:
                eval_fname = f'eval_epoch_{self.epoch}_iter_{i}'
                eval_path = os.path.join(self.save_dir, eval_fname)
                _, eval_stats = self.translator.run(
                    calc_bleu=True,
                    epoch=self.epoch,
                    iteration=i,
                    eval_path=eval_path,
                    )
                test_bleu = eval_stats['bleu']

                log = []
                log += [f'TRAIN [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'BLEU: {test_bleu:.2f}']
                log = '\t'.join(log)
                logging.info(log)

                # put the model back into training mode and repeat the
                # preallocation before the training loop continues
                self.model.train()
                self.preallocate(data_loader.batch_size,
                                 data_loader.dataset.max_len, training=True)

            if i % self.print_freq == 0:
                phase = 'TRAIN' if training else 'VALIDATION'
                log = []
                log += [f'{phase} [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})']
                log += [f'Data {data_time.val:.2e} ({data_time.avg:.2e})']
                log += [f'Tok/s {tot_tok_time.val:.0f} ({tot_tok_time.avg:.0f})']
                if self.verbose:
                    log += [f'Src tok/s {src_tok_time.val:.0f} ({src_tok_time.avg:.0f})']
                    log += [f'Tgt tok/s {tgt_tok_time.val:.0f} ({tgt_tok_time.avg:.0f})']
                    log += [f'Loss/sentence {losses_per_sentence.val:.1f} ({losses_per_sentence.avg:.1f})']
                log += [f'Loss/tok {losses_per_token.val:.4f} ({losses_per_token.avg:.4f})']
                if training:
                    lr = self.optimizer.param_groups[0]['lr']
                    log += [f'LR {lr:.3e}']
                log = '\t'.join(log)
                logging.info(log)

            save_chkpt = (self.save_counter % self.save_freq) == (self.save_freq - 1)
            if training and save_chkpt:
                # the counter restarts after every periodic checkpoint
                self.save_counter = 0
                self.save_info['iteration'] = i
                # -1 is returned only when the identifier cycle is empty,
                # i.e. no periodic checkpoints are kept
                identifier = next(self.checkpoint_counter, -1)
                if identifier != -1:
                    # only the worker that gets rank 0 writes the file
                    with sync_workers() as rank:
                        if rank == 0:
                            self.save(identifier=identifier)

            end = time.perf_counter()

        tot_tok_time.reduce('sum')
        losses_per_token.reduce('mean')

        return losses_per_token.avg, tot_tok_time.avg

    def preallocate(self, batch_size, max_length, training):
        """
        Generates maximum sequence length batch and runs forward and backward
        pass without updating model parameters.

        :param batch_size: batch size for preallocation
        :param max_length: max sequence length for preallocation
        :param training: if True preallocates memory for backward pass
        """
        if self.prealloc_mode == 'always' or (self.prealloc_mode == 'once' and
                                              not self.preallocated):
            logging.info('Executing preallocation')
            torch.cuda.empty_cache()

Pan,Huiwen's avatar
Pan,Huiwen committed
322
323
324
325
            src_length = torch.full((batch_size,), max_length,
                                    dtype=torch.int64)
            tgt_length = torch.full((batch_size,), max_length,
                                    dtype=torch.int64)
huchen's avatar
huchen committed
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405

            if self.batch_first:
                shape = (batch_size, max_length)
            else:
                shape = (max_length, batch_size)

            src = torch.full(shape, 4, dtype=torch.int64)
            tgt = torch.full(shape, 4, dtype=torch.int64)
            src = src, src_length
            tgt = tgt, tgt_length
            self.iterate(src, tgt, update=False, training=training)
            self.model.zero_grad()
            self.preallocated = True

    def optimize(self, data_loader):
        """
        Sets model in training mode, preallocates memory and runs training on
        data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(True)
        self.model.train()
        self.preallocate(data_loader.batch_size, data_loader.dataset.max_len,
                         training=True)

        output = self.feed_data(data_loader, training=True)

        self.model.zero_grad()
        return output

    def evaluate(self, data_loader):
        """
        Sets model in eval mode, disables gradients, preallocates memory and
        runs validation on data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(False)
        self.model.eval()
        self.preallocate(data_loader.batch_size, data_loader.dataset.max_len,
                         training=False)

        output = self.feed_data(data_loader, training=False)

        self.model.zero_grad()
        return output

    def load(self, filename):
        """
        Loads checkpoint from filename.

        :param filename: path to the checkpoint file
        """
        if os.path.isfile(filename):
            checkpoint = torch.load(filename, map_location={'cuda:0': 'cpu'})
            if self.distributed:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.fp_optimizer.initialize_model(self.model)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.epoch = checkpoint['epoch']
            self.loss = checkpoint['loss']
            logging.info(f'Loaded checkpoint {filename} (epoch {self.epoch})')
        else:
            logging.error(f'Invalid checkpoint: {filename}')

    def save(self, identifier=None, is_best=False, save_all=False):
        """
        Stores checkpoint to a file.

        :param identifier: identifier for periodic checkpoint
        :param is_best: if True stores checkpoint to 'model_best.pth'
        :param save_all: if True stores checkpoint after completed training
            epoch
        """

        def write_checkpoint(state, filename):
Pan,Huiwen's avatar
Pan,Huiwen committed
406
            filename = os.path.join(self.save_dir, filename)
huchen's avatar
huchen committed
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
            logging.info(f'Saving model to {filename}')
            torch.save(state, filename)

        if self.distributed:
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()

        state = {
            'epoch': self.epoch,
            'state_dict': model_state,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'loss': getattr(self, 'loss', None),
        }
        state = dict(list(state.items()) + list(self.save_info.items()))

        if identifier is not None:
            filename = self.checkpoint_filename % identifier
            write_checkpoint(state, filename)

        if is_best:
            filename = 'model_best.pth'
            write_checkpoint(state, filename)

        if save_all:
            filename = f'checkpoint_epoch_{self.epoch:03d}.pth'
            write_checkpoint(state, filename)