"""Unit tests for DeepSpeed learning-rate schedulers (test_lr_schedulers.py)."""

import argparse
import json
import os

import pytest
import torch

import deepspeed
from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR
from deepspeed.runtime.lr_schedules import WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS, TOTAL_NUM_STEPS
from deepspeed.runtime.lr_schedules import CYCLE_MIN_LR, CYCLE_MAX_LR

from common import distributed_test
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict


@pytest.mark.parametrize("scheduler_type,params",
                         [(WARMUP_LR,
                           {}),
                          (WARMUP_DECAY_LR,
                           {
                               WARMUP_NUM_STEPS: 10,
                               TOTAL_NUM_STEPS: 20
                           }),
                          (ONE_CYCLE,
                           {
                               CYCLE_MIN_LR: 0,
                               CYCLE_MAX_LR: 0
                           }),
                          (LR_RANGE_TEST,
                           {})])
def test_get_lr_before_train(tmpdir, scheduler_type, params):
    """Check that ``lr_scheduler.get_lr()`` can be called before each training step.

    For each parametrized scheduler type, build a DeepSpeed engine configured
    with an Adam optimizer and that scheduler, then run a short training loop
    that queries the scheduler's learning rate *before* every
    forward/backward/step. The returned values are not asserted; the test only
    checks that querying the scheduler ahead of training does not raise.

    Args:
        tmpdir: pytest-provided temporary directory, passed to
            ``args_from_dict`` together with the config dict.
        scheduler_type: Scheduler name placed in ``config["scheduler"]["type"]``.
        params: Scheduler parameters placed in ``config["scheduler"]["params"]``.
    """
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            },
        },
        "scheduler": {
            "type": scheduler_type,
            "params": params
        },
        "gradient_clipping": 1.0
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_get_lr_before_train(args, model, hidden_dim):
        """Run the training loop under the distributed test harness (world size 1)."""
        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        for batch in data_loader:
            # Query lr before this step's update; must not raise even on the first batch.
            lr_scheduler.get_lr()
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_get_lr_before_train(args=args, model=model, hidden_dim=hidden_dim)


@pytest.mark.parametrize("warmup_num_steps", [10, 15, 19, 33])
def test_lr_warmup_schedule(tmpdir, warmup_num_steps):
    """Verify the WARMUP_LR schedule ramps from min to max lr and then holds.

    The learning rate is recorded after every ``model.step()``. The test asserts:

    * the first recorded lr equals ``WARMUP_MIN_LR``;
    * the lr recorded at index ``warmup_num_steps`` equals ``WARMUP_MAX_LR``;
    * every lr from that index onward stays at ``WARMUP_MAX_LR``.

    Args:
        tmpdir: pytest-provided temporary directory, passed to ``args_from_dict``.
        warmup_num_steps: Value for the scheduler's ``WARMUP_NUM_STEPS`` parameter.
    """
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            },
        },
        "scheduler": {
            "type": WARMUP_LR,
            "params": {
                WARMUP_MIN_LR: 0.1,
                WARMUP_MAX_LR: 0.2,
                WARMUP_NUM_STEPS: warmup_num_steps
            }
        },
        "gradient_clipping": 1.0
    }

    # Train for twice the warmup length so the post-warmup plateau is exercised.
    total_num_steps = 2 * warmup_num_steps

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_lr_warmup_schedule(args, model, hidden_dim, schedule_params, num_steps):
        """Train under the distributed harness (world size 1) and check recorded lrs."""
        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
                                                         model=model,
                                                         model_parameters=model.parameters())

        # NOTE(review): presumably yields num_steps batches at train_batch_size 2,
        # so step_lrs has more than warmup_num_steps entries — confirm against
        # random_dataloader's batching.
        data_loader = random_dataloader(model=model,
                                        total_samples=num_steps * 2,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_lrs = []
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            step_lrs.append(lr_scheduler.get_lr())

        # Verify initial lr (get_lr() returns a list, hence the one-element list).
        assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]

        # Verify warmup completion
        warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
        warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
        assert step_lrs[warmup_num_steps] == warmup_max_lr

        # Verify the lr stays at its maximum once warmup has completed.
        assert all(warmup_max_lr == lr for lr in step_lrs[warmup_num_steps:])

    _test_lr_warmup_schedule(args=args,
                             model=model,
                             hidden_dim=hidden_dim,
                             schedule_params=config_dict["scheduler"]["params"],
                             num_steps=total_num_steps)


@pytest.mark.parametrize("warmup_num_steps", [10, 15, 19, 33])
def test_lr_warmup_decay_schedule(tmpdir, warmup_num_steps):
    """Verify the WARMUP_DECAY_LR schedule warms up to max lr and then strictly decreases.

    The learning rate is recorded after every ``model.step()``. The test asserts:

    * the first recorded lr equals ``WARMUP_MIN_LR``;
    * the lr recorded at index ``warmup_num_steps`` equals ``WARMUP_MAX_LR``;
    * every later recorded lr is strictly smaller than the one before it.

    Args:
        tmpdir: pytest-provided temporary directory, passed to ``args_from_dict``.
        warmup_num_steps: Value for ``WARMUP_NUM_STEPS``; ``TOTAL_NUM_STEPS``
            is set to twice this value.
    """
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            },
        },
        "scheduler": {
            "type": WARMUP_DECAY_LR,
            "params": {
                WARMUP_MIN_LR: 0.1,
                WARMUP_MAX_LR: 0.2,
                WARMUP_NUM_STEPS: warmup_num_steps,
                TOTAL_NUM_STEPS: warmup_num_steps * 2
            }
        },
        "gradient_clipping": 1.0
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_lr_warmup_decay_schedule(args,
                                       model,
                                       hidden_dim,
                                       schedule_params,
                                       num_steps):
        """Train under the distributed harness (world size 1) and check recorded lrs."""
        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
                                                         model=model,
                                                         model_parameters=model.parameters())

        # NOTE(review): presumably yields num_steps batches at train_batch_size 2 —
        # confirm against random_dataloader's batching.
        data_loader = random_dataloader(model=model,
                                        total_samples=num_steps * 2,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_lrs = []
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            step_lrs.append(lr_scheduler.get_lr())

        # Verify initial lr (get_lr() returns a list, hence the one-element list).
        assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]

        # Verify lr at warmup completion
        warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
        warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
        assert step_lrs[warmup_num_steps] == warmup_max_lr

        # Verify decay phase. These are list comparisons (lexicographic);
        # with a single param group this is a plain numeric comparison.
        previous_lr = warmup_max_lr
        for lr in step_lrs[warmup_num_steps + 1:]:
            assert lr < previous_lr
            previous_lr = lr

    schedule_params = config_dict["scheduler"]["params"]

    total_num_steps = schedule_params[TOTAL_NUM_STEPS]

    _test_lr_warmup_decay_schedule(args=args,
                                   model=model,
                                   hidden_dim=hidden_dim,
                                   schedule_params=schedule_params,
                                   num_steps=total_num_steps)