# test_checkpointing.py
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from apex import amp

from utils import common_init, FLOAT


class MyModel(torch.nn.Module):
    """Small conv net used as the model under test for checkpointing.

    The input is multiplied by a learnable scalar ``param``, then passed
    through a 3->6 channel 3x3 convolution (stride 1, padding 1), a ReLU and
    a ``BatchNorm2d`` over the 6 output channels.
    """

    def __init__(self):
        super().__init__()
        # Submodules/parameters are created in this exact order so that, for
        # a given seed, their random initialization is reproducible.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3,
                               stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=6)
        self.param = nn.Parameter(torch.randn(1))

    def forward(self, x):
        """Return ``bn1(relu(conv1(x * param)))``."""
        activated = F.relu(self.conv1(x * self.param))
        return self.bn1(activated)


class TestCheckpointing(unittest.TestCase):
    """Checkpoint round-trip tests for models/optimizers wrapped by apex ``amp``."""

    def setUp(self):
        self.initial_lr = 1e-3
        self.test_opt_levels = ("O0", "O1", "O2", "O3", "O4", "O5")

    def seed(self):
        """Seed torch and force deterministic cuDNN so runs are reproducible."""
        torch.manual_seed(2809)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    def check_state_dict_fp32(self, state_dict):
        """Assert every entry of ``state_dict`` has type ``FLOAT``.

        Entries whose key contains ``num_batches_tracked`` are skipped.
        """
        for key, param in state_dict.items():
            if 'num_batches_tracked' in key:
                continue
            self.assertEqual(param.type(), FLOAT,
                             'Parameter in state_dict not FLOAT (key: {})'.format(key))

    def train_step(self, model, optimizer, data, loss_ids):
        """Run one forward pass, one amp-scaled backward per loss id, and one step.

        Args:
            model: module returned by ``amp.initialize``.
            optimizer: optimizer returned by ``amp.initialize``.
            data: input batch for ``model``.
            loss_ids: iterable of amp loss ids; ``output.mean()`` is
                back-propagated once for each id.

        Returns:
            The output of the forward pass.
        """
        optimizer.zero_grad()

        output = model(data)

        # Call backward once per loss id; every backward goes through the same
        # forward graph, hence retain_graph=True.
        for idx in loss_ids:
            loss = output.mean()
            with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
                scaled_loss.backward(retain_graph=True)

        optimizer.step()
        return output

    def compare_models(self, modelA, modelB, test_setup=''):
        """Assert both models have state_dicts with the same keys and equal tensors."""
        state_dictA = modelA.state_dict()
        state_dictB = modelB.state_dict()
        self.assertEqual(len(state_dictA), len(state_dictB),
                         'state_dicts have different lengths\n' + test_setup)
        for key, paramA in state_dictA.items():
            # Report a missing key as a test failure rather than a KeyError.
            self.assertIn(key, state_dictB,
                          'key {} missing from restored state_dict\n{}'.format(
                              key, test_setup))
            paramB = state_dictB[key]
            self.assertTrue((paramA == paramB).all(),
                msg='Parameters in state_dicts not equal.\n' +
                    'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format(
                        key, paramA, paramB, paramA - paramB, test_setup))

    def _save_and_restore(self, model, optimizer, res_opt_level, num_losses,
                          amp_before_load):
        """Checkpoint ``model``/``optimizer`` and load the checkpoint into fresh copies.

        ``amp.initialize`` (with ``res_opt_level`` and ``num_losses*2`` loss
        scalers) is applied to the fresh model/optimizer either before or after
        loading the state_dicts, depending on ``amp_before_load``.

        Returns:
            ``(restore_model, restore_optimizer)``
        """
        # Save model, optimizer and amp state
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'amp': amp.state_dict()
        }
        # Check state_dict for FP32 tensors
        self.check_state_dict_fp32(checkpoint['model'])

        def initialize(restore_model, restore_optimizer):
            return amp.initialize(
                restore_model,
                restore_optimizer,
                opt_level=res_opt_level,
                num_losses=num_losses*2,
                verbosity=0)

        # Restore model
        restore_model = MyModel().to('cuda')
        restore_optimizer = optim.SGD(restore_model.parameters(),
                                      lr=self.initial_lr)

        if amp_before_load:
            restore_model, restore_optimizer = initialize(
                restore_model, restore_optimizer)

        restore_model.load_state_dict(checkpoint['model'])
        restore_optimizer.load_state_dict(checkpoint['optimizer'])
        # FIXME: We cannot test the amp.state_dict in the same script
        # amp.load_state_dict(checkpoint['amp'])

        if not amp_before_load:
            restore_model, restore_optimizer = initialize(
                restore_model, restore_optimizer)

        return restore_model, restore_optimizer

    def test_restoring(self):
        """Checkpoint mid-training, restore, and check both models keep training identically.

        A reference model is trained for ``nb_epochs``. After epoch
        ``nb_epochs_restore - 1`` it is checkpointed into a fresh model. For every
        later epoch the restored model is trained on the same batch, and its
        output and state_dict must match the reference.
        """
        nb_epochs = 10
        nb_epochs_restore = nb_epochs // 2
        for opt_level in self.test_opt_levels:
            for res_opt_level in self.test_opt_levels:
                for amp_before_load in [True, False]:
                    for num_losses in range(1, 3):
                        test_setup = ('#' * 75 + '\n' +
                                      f'opt_level {opt_level}\n' +
                                      f'restore_opt_level {res_opt_level}\n' +
                                      f'amp_before_load {amp_before_load}\n' +
                                      f'num_losses {num_losses}\n')

                        self.seed()

                        # Create reference model
                        model = MyModel().to('cuda')

                        optimizer = optim.SGD(model.parameters(),
                                              lr=self.initial_lr)

                        # Initialize with num_losses*2 loss scalers: ids
                        # [0, num_losses) train the reference model and ids
                        # [num_losses, num_losses*2) the restored one.
                        model, optimizer = amp.initialize(
                            model, optimizer, opt_level=opt_level,
                            num_losses=num_losses*2, verbosity=0)

                        # Compare training behavior only for the same restore option.
                        # We cannot really generalize it, since a saved model in O0
                        # would introduce a skipped step in O1, which will raise an error
                        if opt_level != res_opt_level:
                            continue

                        # train for nb_epochs and restore after nb_epochs_restore
                        for epoch in range(nb_epochs):
                            x = torch.randn(16, 3, 24, 24, device='cuda')
                            output = self.train_step(
                                model, optimizer, x, range(num_losses))
                            # Initialize model one step before comparing.
                            # Otherwise the batchnorm layers will be updated
                            # additionally in restore_model
                            if epoch == (nb_epochs_restore - 1):
                                restore_model, restore_optimizer = self._save_and_restore(
                                    model, optimizer, res_opt_level, num_losses,
                                    amp_before_load)
                            elif epoch >= nb_epochs_restore:
                                restore_output = self.train_step(
                                    restore_model,
                                    restore_optimizer,
                                    x,
                                    range(num_losses, num_losses*2))
                                self.assertTrue(
                                    torch.allclose(output.float(), restore_output.float()),
                                    'Output of reference and restored models differ for ' + test_setup)
                                self.compare_models(model, restore_model, test_setup)

    def test_loss_scale_decrease(self):
        """Force skipped steps and check each dynamic loss scale and amp.state_dict().

        Loss scaler ``idx`` is driven through ``nb_decrease_loss_scales[idx]``
        steps on inputs scaled by 2**17. Afterwards its loss scale must equal the
        initial scale divided by ``2**nb_decrease_loss_scales[idx]``, both via the
        scaler and via ``amp.state_dict()``.
        """
        nb_decrease_loss_scales = [0, 1, 2]
        # One loss scaler per entry of nb_decrease_loss_scales.
        num_losses = len(nb_decrease_loss_scales)
        for opt_level in self.test_opt_levels:
            model = MyModel().to('cuda')

            optimizer = optim.SGD(model.parameters(),
                                  lr=self.initial_lr)

            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, num_losses=num_losses,
                verbosity=0)

            # Static loss scale set: skip this opt_level.
            if amp._amp_state.opt_properties.loss_scale != 'dynamic':
                continue

            initial_loss_scales = [
                amp._amp_state.loss_scalers[idx].loss_scale()
                for idx in range(num_losses)]

            # force to skip some updates to decrease the loss_scale
            x = torch.randn(16, 3, 24, 24, device='cuda')
            for idx, nb_decrease in enumerate(nb_decrease_loss_scales):
                for _ in range(nb_decrease):
                    optimizer.zero_grad()
                    output = model(x * 2**17)
                    loss = output.mean()

                    with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
                        scaled_loss.backward(retain_graph=True)
                    optimizer.step()

            # Check loss scales afterwards
            updated_loss_scales = [
                amp._amp_state.loss_scalers[idx].loss_scale()
                for idx in range(num_losses)]
            for factor, update_ls, init_ls in zip(nb_decrease_loss_scales,
                                                  updated_loss_scales,
                                                  initial_loss_scales):
                self.assertEqual(update_ls, init_ls / 2**factor)

            # Check state dict; zip() would silently truncate, so check the size first.
            amp_state_dict = amp.state_dict()
            self.assertEqual(len(amp_state_dict), num_losses)
            for scaler, factor, init_ls in zip(amp_state_dict.values(),
                                               nb_decrease_loss_scales,
                                               initial_loss_scales):
                self.assertEqual(scaler['loss_scale'], init_ls / 2**factor)
                unskipped_target = 0
                self.assertEqual(scaler['unskipped'], unskipped_target)

    def test_state_dict(self):
        """Check model.state_dict() has no Half/BFloat16 tensors and the model still trains."""
        for opt_level in self.test_opt_levels:
            # Skip O3
            if opt_level == 'O3':
                continue

            model = MyModel().to('cuda')
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, verbosity=0)

            # Export state_dict and check for Half / BFloat16
            state_dict = model.state_dict()
            for key, tensor in state_dict.items():
                self.assertNotIn('Half', tensor.type(), key)
                self.assertNotIn('BFloat16', tensor.type(), key)

            # Check, if model is still trainable
            # Create dummy data
            data = torch.randn(10, 3, 4, 4, device='cuda')
            target = torch.randn(10, 6, 4, 4, device='cuda')

            def mse_step():
                """Run one amp-scaled MSE training step and return the loss value."""
                optimizer.zero_grad()
                output = model(data)
                loss = F.mse_loss(output, target)
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
                return loss.item()

            # Get initial loss
            last_loss = mse_step()

            # train for some epochs; the loss must strictly decrease each step
            for _ in range(10):
                loss_value = mse_step()
                self.assertLess(loss_value, last_loss)
                last_loss = loss_value

# Run the test suite when executed as a script.
if __name__ == '__main__':
    unittest.main()