# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch


if is_torch_available():
    import torch
    from torch import nn

    from transformers import (
        Adafactor,
        AdamW,
        get_constant_schedule,
        get_constant_schedule_with_warmup,
        get_cosine_schedule_with_warmup,
        get_cosine_with_hard_restarts_schedule_with_warmup,
        get_inverse_sqrt_schedule,
        get_linear_schedule_with_warmup,
        get_polynomial_decay_schedule_with_warmup,
        get_scheduler,
        get_wsd_schedule,
    )


def unwrap_schedule(scheduler, num_steps=10):
    """Step ``scheduler`` ``num_steps`` times, recording the learning rate in effect before each step.

    Only the learning rate of the first optimizer param group is recorded.

    Args:
        scheduler: A torch learning-rate scheduler.
        num_steps: Number of ``scheduler.step()`` calls to make.

    Returns:
        A list of ``num_steps`` floats; element ``i`` is the learning rate before the ``i``-th step.
    """
    lrs = []
    for _ in range(num_steps):
        # `get_last_lr()` returns the value the scheduler last applied; torch warns when `get_lr()`
        # is called outside of `step()`.
        lrs.append(scheduler.get_last_lr()[0])
        scheduler.step()
    return lrs


def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
    """Like ``unwrap_schedule``, but round-trips the scheduler state through disk halfway through.

    Right after step ``num_steps // 2`` the scheduler's ``state_dict()`` is written with ``torch.save``
    to a temporary file, read back with ``torch.load`` and restored with ``load_state_dict``. The
    returned learning rates can therefore be compared against an uninterrupted run.

    Args:
        scheduler: A torch learning-rate scheduler.
        num_steps: Number of ``scheduler.step()`` calls to make.

    Returns:
        A list of ``num_steps`` floats; element ``i`` is the learning rate (first param group)
        before the ``i``-th step.
    """
    lrs = []
    for step in range(num_steps):
        lrs.append(scheduler.get_last_lr()[0])
        scheduler.step()
        if step == num_steps // 2:
            with tempfile.TemporaryDirectory() as tmpdirname:
                file_name = os.path.join(tmpdirname, "schedule.bin")
                torch.save(scheduler.state_dict(), file_name)

                # The state dict may hold pickled callables (e.g. the functions kept by wrapped lr
                # lambdas), which torch>=2.6's default `weights_only=True` refuses to load. The file was
                # written just above by this process, so full unpickling is safe here.
                state_dict = torch.load(file_name, weights_only=False)
                scheduler.load_state_dict(state_dict)
    return lrs


@require_torch
class OptimizationTest(unittest.TestCase):
    """Checks that the optimizers drive a 3-element tensor onto a fixed target under MSE loss."""

    def assertListAlmostEqual(self, list1, list2, tol, msg=None):
        """Assert both lists have the same length and are element-wise equal within ``tol``."""
        self.assertEqual(len(list1), len(list2), msg=msg)
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol, msg=msg)

    @staticmethod
    def _minimize_mse(optimizer, w, target, num_steps):
        """Run ``num_steps`` optimizer steps on ``MSELoss(w, target)``, resetting ``w.grad`` after each."""
        criterion = nn.MSELoss()
        for _ in range(num_steps):
            loss = criterion(w, target)
            loss.backward()
            optimizer.step()
            w.grad.detach_()  # No zero_grad() function on plain tensors, so we reset the gradient ourselves.
            w.grad.zero_()

    def test_adam_w(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        # No warmup, constant schedule, no gradient clipping
        optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
        self._minimize_mse(optimizer, w, target, num_steps=100)
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

    def test_adafactor(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        # No warmup, constant schedule, no gradient clipping
        optimizer = Adafactor(
            params=[w],
            lr=1e-2,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=None,
            weight_decay=0.0,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )
        self._minimize_mse(optimizer, w, target, num_steps=1000)
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)


@require_torch
class ScheduleInitTest(unittest.TestCase):
    """Checks the learning rates produced by each schedule factory, and that they survive a save/reload."""

    # One model/optimizer (base lr 10.0) shared by every scheduler built below. The ternaries keep the
    # class body evaluable when torch is not installed.
    m = nn.Linear(50, 50) if is_torch_available() else None
    optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
    num_steps = 10

    def assertListAlmostEqual(self, list1, list2, tol, msg=None):
        """Assert both lists have the same length and are element-wise equal within ``tol``."""
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol, msg=msg)

    def test_schedulers(self):
        common_kwargs = {"num_warmup_steps": 2, "num_training_steps": 10}
        # schedulers dict format
        # function: (sched_args_dict, expected_learning_rates)
        # where expected_learning_rates[i] is the lr before the i-th of `self.num_steps` steps.
        scheds = {
            get_constant_schedule: ({}, [10.0] * self.num_steps),
            get_constant_schedule_with_warmup: (
                {"num_warmup_steps": 4},
                [0.0, 2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0],
            ),
            get_linear_schedule_with_warmup: (
                {**common_kwargs},
                [0.0, 5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25],
            ),
            get_cosine_schedule_with_warmup: (
                {**common_kwargs},
                [0.0, 5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38],
            ),
            get_cosine_with_hard_restarts_schedule_with_warmup: (
                {**common_kwargs, "num_cycles": 2},
                [0.0, 5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46],
            ),
            get_polynomial_decay_schedule_with_warmup: (
                {**common_kwargs, "power": 2.0, "lr_end": 1e-7},
                [0.0, 5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156],
            ),
            get_inverse_sqrt_schedule: (
                {"num_warmup_steps": 2},
                [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
            ),
            get_wsd_schedule: (
                {"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1},
                [0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0],
            ),
        }

        for scheduler_func, data in scheds.items():
            kwargs, expected_learning_rates = data

            scheduler = scheduler_func(self.optimizer, **kwargs)
            # The shared optimizer has a single param group, so exactly one learning rate is reported.
            # (The previous check, len([lr]) == 1, could never fail.)
            self.assertEqual(len(scheduler.get_last_lr()), 1)
            lrs_1 = unwrap_schedule(scheduler, self.num_steps)
            self.assertListAlmostEqual(
                lrs_1,
                expected_learning_rates,
                tol=1e-2,
                msg=f"failed for {scheduler_func} in normal scheduler",
            )

            scheduler = scheduler_func(self.optimizer, **kwargs)
            if scheduler_func.__name__ != "get_constant_schedule":
                LambdaScheduleWrapper.wrap_scheduler(scheduler)  # wrap to test picklability of the schedule
            lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
            self.assertListEqual(lrs_1, lrs_2, msg=f"failed for {scheduler_func} in save and reload")

    def test_get_scheduler(self):
        test_params = [
            {
                "name": "warmup_stable_decay",
                "optimizer": self.optimizer,
                "num_warmup_steps": 2,
                "scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3},
            },
            {
                "name": "warmup_stable_decay",
                "optimizer": self.optimizer,
                "num_warmup_steps": 2,
                "num_training_steps": 10,
                "scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3},
            },
            {"name": "cosine", "optimizer": self.optimizer, "num_warmup_steps": 2, "num_training_steps": 10},
        ]

        for param in test_params:
            # Smoke test: the call must not raise and must return a scheduler object.
            self.assertIsNotNone(get_scheduler(**param), msg=f"failed for {param['name']} in get_scheduler")


class LambdaScheduleWrapper:
    """Callable wrapper around a scheduler's lr lambda, used to test that schedules are picklable.

    See https://github.com/huggingface/transformers/issues/21689.

    Torch's ``LambdaLR`` only serializes lr lambdas that are callable objects, not plain functions.
    Replacing each lambda with an instance of this class therefore puts the wrapped function into the
    scheduler's ``state_dict()``, so saving that state dict exercises the function's picklability.
    """

    def __init__(self, fn):
        # Kept as an instance attribute so it is part of this object's ``__dict__``.
        self.fn = fn

    def __call__(self, *args, **kwargs):
        """Forward the call unchanged to the wrapped function."""
        return self.fn(*args, **kwargs)

    @classmethod
    def wrap_scheduler(cls, scheduler):
        """Replace ``scheduler.lr_lambdas`` with a new list holding each lambda wrapped in ``cls``."""
        scheduler.lr_lambdas = [cls(fn) for fn in scheduler.lr_lambdas]