test_lr_schedulers.py 17.2 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
aiss's avatar
aiss committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

import torch
import deepspeed
import pytest
from unit.common import DistributedTest
from unit.simple_model import SimpleModel, random_dataloader
from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, LR_RANGE_TEST_MIN_LR, LR_RANGE_TEST_STEP_RATE, LR_RANGE_TEST_STEP_SIZE, LR_RANGE_TEST_STAIRCASE
from deepspeed.runtime.lr_schedules import WARMUP_LR, WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS, WARMUP_TYPE, WARMUP_LOG_RATE, WARMUP_LINEAR_RATE
from deepspeed.runtime.lr_schedules import ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE, DECAY_LR_RATE, DECAY_STEP_SIZE
from deepspeed.runtime.lr_schedules import CYCLE_MIN_MOM, CYCLE_MAX_MOM, DECAY_MOM_RATE
from deepspeed.runtime.lr_schedules import WARMUP_DECAY_LR, TOTAL_NUM_STEPS


def _verify_continuous_decrease(values):
    for i in range(len(values) - 1):
        assert values[i] > values[i + 1]


def _verify_continuous_increase(values):
    for i in range(len(values) - 1):
        assert values[i] < values[i + 1]


def _verify_staircase_increase(values, step_size):
    num_values = len(values)
    for i in range(0, num_values, step_size):
        j = min(i + step_size, num_values)
        assert all([values[i] == v for v in values[i:j]])


aiss's avatar
aiss committed
35
36
37
38
39
40
41
42
@pytest.mark.parametrize("scheduler_type,params", [
    (WARMUP_LR, {}),
    (WARMUP_DECAY_LR, {
        WARMUP_NUM_STEPS: 10,
        TOTAL_NUM_STEPS: 20
    }),
    (ONE_CYCLE, {
        CYCLE_MIN_LR: 0,
        CYCLE_MAX_LR: 0.1
    }),
    (LR_RANGE_TEST, {}),
])
class TestGetLrBeforeTrain(DistributedTest):
    """Smoke test: ``get_lr()`` can be called ahead of every training step."""
    # NOTE(review): presumably the number of ranks DistributedTest launches -- confirm.
    world_size = 1

    def test(self, scheduler_type, params):
        """Query the scheduler's lr before each forward/backward/step cycle.

        The returned lr is deliberately discarded; the test only requires that
        the call and the subsequent training step complete without raising.
        """
        optimizer_section = {"type": "Adam", "params": {"lr": 0.00015}}
        scheduler_section = {"type": scheduler_type, "params": params}
        ds_config = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": optimizer_section,
            "scheduler": scheduler_section,
            "gradient_clipping": 1.0,
        }
        hidden_size = 10

        net = SimpleModel(hidden_size, empty_grad=False)
        engine, _, _, scheduler = deepspeed.initialize(config=ds_config,
                                                       model=net,
                                                       model_parameters=net.parameters())
        batches = random_dataloader(model=engine,
                                    total_samples=50,
                                    hidden_dim=hidden_size,
                                    device=engine.device,
                                    dtype=torch.float)
        for batch in batches:
            # Ask for the lr before this step's training work begins.
            scheduler.get_lr()
            loss = engine(batch[0], batch[1])
            engine.backward(loss)
            engine.step()


@pytest.mark.parametrize("warmup_num_steps", [10, 15, 19, 33])
@pytest.mark.parametrize("warmup_type", [WARMUP_LOG_RATE, WARMUP_LINEAR_RATE])
class TestLrSchedule(DistributedTest):
    """Checks the lr trajectory reported by the warmup and warmup-decay schedulers."""
    # NOTE(review): presumably the number of ranks DistributedTest launches -- confirm.
    world_size = 1

    def test_lr_warmup_schedule(self, warmup_num_steps, warmup_type):
        """WARMUP_LR: lr starts at the configured minimum, reaches the maximum at
        index ``warmup_num_steps`` and stays there for the rest of the run."""
        warmup_params = {
            WARMUP_MIN_LR: 0.1,
            WARMUP_MAX_LR: 0.2,
            WARMUP_NUM_STEPS: warmup_num_steps,
            WARMUP_TYPE: warmup_type,
        }
        ds_config = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": WARMUP_LR,
                "params": warmup_params
            },
            "gradient_clipping": 1.0,
        }
        # Train for twice the warmup length so the post-warmup plateau is observed.
        num_steps = 2 * warmup_num_steps
        hidden_size = 10

        net = SimpleModel(hidden_size, empty_grad=False)
        engine, _, _, scheduler = deepspeed.initialize(config=ds_config,
                                                       model=net,
                                                       model_parameters=net.parameters())

        batches = random_dataloader(model=engine,
                                    total_samples=num_steps * 2,
                                    hidden_dim=hidden_size,
                                    device=engine.device,
                                    dtype=torch.float)
        # One entry per step: the list returned by get_lr() right after engine.step().
        observed = []
        for batch in batches:
            loss = engine(batch[0], batch[1])
            engine.backward(loss)
            engine.step()
            observed.append(scheduler.get_lr())

        # Initial lr.
        assert observed[0] == [warmup_params[WARMUP_MIN_LR]]

        # lr when warmup completes.
        warmup_end = warmup_params[WARMUP_NUM_STEPS]
        peak_lr = [warmup_params[WARMUP_MAX_LR]]
        assert observed[warmup_end] == peak_lr

        # lr stays at the peak once warmup is over.
        assert all(peak_lr == lr for lr in observed[warmup_end:])

    def test_lr_warmup_decay_schedule(self, warmup_num_steps, warmup_type):
        """WARMUP_DECAY_LR: lr starts at the configured minimum, reaches the maximum
        at index ``warmup_num_steps`` and then strictly decreases."""
        decay_params = {
            WARMUP_MIN_LR: 0.1,
            WARMUP_MAX_LR: 0.2,
            WARMUP_NUM_STEPS: warmup_num_steps,
            TOTAL_NUM_STEPS: warmup_num_steps * 2,
            WARMUP_TYPE: warmup_type,
        }
        ds_config = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": WARMUP_DECAY_LR,
                "params": decay_params
            },
            "gradient_clipping": 1.0,
        }
        num_steps = decay_params[TOTAL_NUM_STEPS]
        hidden_size = 10

        net = SimpleModel(hidden_size, empty_grad=False)
        engine, _, _, scheduler = deepspeed.initialize(config=ds_config,
                                                       model=net,
                                                       model_parameters=net.parameters())

        batches = random_dataloader(model=engine,
                                    total_samples=num_steps * 2,
                                    hidden_dim=hidden_size,
                                    device=engine.device,
                                    dtype=torch.float)
        # One entry per step: the list returned by get_lr() right after engine.step().
        observed = []
        for batch in batches:
            loss = engine(batch[0], batch[1])
            engine.backward(loss)
            engine.step()
            observed.append(scheduler.get_lr())

        # Initial lr.
        assert observed[0] == [decay_params[WARMUP_MIN_LR]]

        # lr when warmup completes.
        warmup_end = decay_params[WARMUP_NUM_STEPS]
        peak_lr = [decay_params[WARMUP_MAX_LR]]
        assert observed[warmup_end] == peak_lr

        # Decay phase: starting from the peak, every later lr is below its predecessor.
        decay_trace = [peak_lr] + observed[warmup_end + 1:]
        for earlier, later in zip(decay_trace, decay_trace[1:]):
            assert later < earlier


aiss's avatar
aiss committed
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
@pytest.mark.parametrize("scheduler_type,params", [
    (WARMUP_LR, {}),
    (WARMUP_DECAY_LR, {
        WARMUP_NUM_STEPS: 5,
        TOTAL_NUM_STEPS: 10
    }),
    (ONE_CYCLE, {
        CYCLE_MIN_LR: 0,
        CYCLE_MAX_LR: 0.1,
        CYCLE_FIRST_STEP_SIZE: 5,
        DECAY_STEP_SIZE: 5
    }),
    (LR_RANGE_TEST, {
        LR_RANGE_TEST_MIN_LR: 1e-4,
        LR_RANGE_TEST_STEP_SIZE: 1
    }),
])
class TestSchedulerOptimizerParity(DistributedTest):
    """Checks that the scheduler and the engine report the same lr after each step."""
    # NOTE(review): presumably the number of ranks DistributedTest launches -- confirm.
    world_size = 1

    def test(self, scheduler_type, params):
        """After every ``step()``, ``lr_scheduler.get_lr()`` must equal ``model.get_lr()``."""
        optimizer_section = {"type": "Adam", "params": {"lr": 0.00015}}
        scheduler_section = {"type": scheduler_type, "params": params}
        ds_config = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": optimizer_section,
            "scheduler": scheduler_section,
            "gradient_clipping": 1.0,
        }
        hidden_size = 10

        net = SimpleModel(hidden_size, empty_grad=False)
        engine, _, _, scheduler = deepspeed.initialize(config=ds_config,
                                                       model=net,
                                                       model_parameters=net.parameters())
        batches = random_dataloader(model=engine,
                                    total_samples=50,
                                    hidden_dim=hidden_size,
                                    device=engine.device,
                                    dtype=torch.float)
        for batch in batches:
            loss = engine(batch[0], batch[1])
            engine.backward(loss)
            engine.step()
            # The scheduler's view and the engine's view of the lr must agree.
            assert scheduler.get_lr() == engine.get_lr()


@pytest.mark.parametrize("min_lr, step_rate, step_size, staircase",
                         [(1e-4, 1e-5, 1, True),
                          (1e-5, 1e-5, 1, False),
                          (1e-4, 1e-3, 10, True),
                          (1e-3, 1e-3, 10, False),
                          (1e-2, 1e-2, 19, True),
                          (1e-2, 1e-2, 19, False)
                          ])  # yapf: disable
class TestLrRange(DistributedTest):
    """Checks the lr trajectory reported by the LR range test scheduler."""
    # NOTE(review): presumably the number of ranks DistributedTest launches -- confirm.
    world_size = 1

    def test(self, min_lr, step_rate, step_size, staircase):
        """lr starts at ``min_lr`` and then follows either a staircase
        (``staircase=True``) or a strictly increasing curve."""
        range_test_params = {
            LR_RANGE_TEST_MIN_LR: min_lr,
            LR_RANGE_TEST_STEP_RATE: step_rate,
            LR_RANGE_TEST_STEP_SIZE: step_size,
            LR_RANGE_TEST_STAIRCASE: staircase,
        }
        ds_config = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": LR_RANGE_TEST,
                "params": range_test_params
            },
            "gradient_clipping": 1.0,
        }
        hidden_size = 10

        net = SimpleModel(hidden_size, empty_grad=False)
        engine, _, _, scheduler = deepspeed.initialize(config=ds_config,
                                                       model=net,
                                                       model_parameters=net.parameters())
        batches = random_dataloader(model=engine,
                                    total_samples=max(50, step_size * 2),
                                    hidden_dim=hidden_size,
                                    device=engine.device,
                                    dtype=torch.float)

        # get_lr() returns a list; extend() flattens it into one trace.
        # NOTE(review): indexing by step assumes a single lr per call -- confirm.
        lr_trace = []
        for batch in batches:
            # Sample the lr before this step's training work.
            lr_trace.extend(scheduler.get_lr())
            loss = engine(batch[0], batch[1])
            engine.backward(loss)
            engine.step()

        # Starting lr.
        assert lr_trace[0] == min_lr

        if not staircase:
            # Continuous mode: every step raises the lr.
            _verify_continuous_increase(lr_trace)
        else:
            # Staircase mode: lr is held for step_size steps at a time.
            _verify_staircase_increase(lr_trace, step_size)


class TestOneCycle(DistributedTest):
    """Checks the lr and momentum trajectories reported by the ONE_CYCLE scheduler."""
    # NOTE(review): presumably the number of ranks DistributedTest launches -- confirm.
    world_size = 1

    # Each parameter tuple is listed once; repeating a tuple only re-runs an
    # identical case.
    @pytest.mark.parametrize("min_lr, max_lr, decay_rate, cycle_step_size, decay_step_size",
                             [
                                 (1e-5, 1e-2, 1e-3, 10, 10),
                                 (1e-3, 1e-1, 0, 21, 21),
                                 (1e-3, 1e-1, 1e-1, 21, 21),
                                 (1e-5, 1e-1, 0, 10, 0),
                             ])  # yapf: disable
    def test_lr(self, min_lr, max_lr, decay_rate, cycle_step_size, decay_step_size):
        """lr rises from ``min_lr`` to ``max_lr`` over ``cycle_step_size`` steps,
        falls over the next ``cycle_step_size`` steps and, when ``decay_rate`` is
        positive, keeps falling afterwards."""
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": ONE_CYCLE,
                "params": {
                    CYCLE_MIN_LR: min_lr,
                    CYCLE_MAX_LR: max_lr,
                    DECAY_LR_RATE: decay_rate,
                    CYCLE_FIRST_STEP_SIZE: cycle_step_size,
                    DECAY_STEP_SIZE: decay_step_size
                }
            },
            "gradient_clipping": 1.0
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=max(50, cycle_step_size * 3),
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)

        # get_lr() returns a list; extend() flattens it into one trace.
        # NOTE(review): indexing by step assumes a single lr per call -- confirm.
        step_lrs = []
        for batch in data_loader:
            # Sample the lr before this step's training work.
            step_lrs.extend(lr_scheduler.get_lr())
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        # Verify starting lr
        assert step_lrs[0] == min_lr

        # Verify peak lr
        assert step_lrs[cycle_step_size] == max_lr

        # Verify increasing phase
        _verify_continuous_increase(step_lrs[:cycle_step_size])

        # Verify decreasing phase
        _verify_continuous_decrease(step_lrs[cycle_step_size:(cycle_step_size * 2)])

        # Verify decay phase (the slice may be empty when the run is shorter
        # than two cycle phases, in which case nothing is checked)
        if decay_rate > 0:
            _verify_continuous_decrease(step_lrs[(cycle_step_size * 2):])

    # Each parameter tuple is listed once; repeating a tuple only re-runs an
    # identical case.
    @pytest.mark.parametrize("min_mom, max_mom, decay_rate, step_size",
                             [
                                 (0.08, 0.09, 1e-3, 10),
                                 (0.08, 0.09, 0, 21),
                             ])  # yapf: disable
    def test_mom(self, min_mom, max_mom, decay_rate, step_size):
        """Momentum mirrors the lr cycle: it starts at ``max_mom``, falls to
        ``min_mom`` over ``step_size`` steps, rises over the next ``step_size``
        steps and, when ``decay_rate`` is positive, keeps rising afterwards."""
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": ONE_CYCLE,
                "params": {
                    CYCLE_MIN_LR: 1e-3,
                    CYCLE_MAX_LR: 1e-2,
                    CYCLE_MIN_MOM: min_mom,
                    CYCLE_MAX_MOM: max_mom,
                    DECAY_MOM_RATE: decay_rate,
                    CYCLE_FIRST_STEP_SIZE: step_size,
                    DECAY_STEP_SIZE: step_size
                }
            },
            "gradient_clipping": 1.0
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=max(50, step_size * 3),
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)

        # One get_mom() result per step, sampled before the step's training work.
        # NOTE(review): [0][0] is presumably the first momentum coefficient of
        # the first param group -- confirm against get_mom().
        step_moms = []
        for batch in data_loader:
            step_moms.append(lr_scheduler.get_mom())
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        # Verify starting momentum
        assert step_moms[0][0][0] == max_mom

        # Verify minimum momentum at the end of the first cycle phase
        assert step_moms[step_size][0][0] == min_mom

        # Verify decreasing phase
        _verify_continuous_decrease(step_moms[:step_size])

        # Verify increasing phase
        _verify_continuous_increase(step_moms[step_size:(step_size * 2)])

        # Verify decay phase (the slice may be empty when the run is shorter
        # than two cycle phases, in which case nothing is checked)
        if decay_rate > 0:
            _verify_continuous_increase(step_moms[(step_size * 2):])