test_optimization.py 6.92 KB
Newer Older
1
# coding=utf-8
Sylvain Gugger's avatar
Sylvain Gugger committed
2
# Copyright 2020 The HuggingFace Team. All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16

17
import os
18
import tempfile
Aymeric Augustin's avatar
Aymeric Augustin committed
19
import unittest
20

21
from transformers import is_torch_available
22
from transformers.testing_utils import require_torch
Aymeric Augustin's avatar
Aymeric Augustin committed
23
24


25
if is_torch_available():
thomwolf's avatar
thomwolf committed
26
    import torch
27
    from torch import nn
thomwolf's avatar
thomwolf committed
28

29
    from transformers import (
30
        Adafactor,
31
32
33
34
35
        AdamW,
        get_constant_schedule,
        get_constant_schedule_with_warmup,
        get_cosine_schedule_with_warmup,
        get_cosine_with_hard_restarts_schedule_with_warmup,
36
        get_inverse_sqrt_schedule,
37
        get_linear_schedule_with_warmup,
38
        get_polynomial_decay_schedule_with_warmup,
39
        get_wsd_schedule,
40
    )
thomwolf's avatar
thomwolf committed
41

lukovnikov's avatar
lukovnikov committed
42

thomwolf's avatar
thomwolf committed
43
44
45
def unwrap_schedule(scheduler, num_steps=10):
    """Step ``scheduler`` ``num_steps`` times and record the learning rate before each step.

    Only the learning rate of the first parameter group is recorded.

    Args:
        scheduler: a ``torch.optim.lr_scheduler`` instance.
        num_steps: how many ``scheduler.step()`` calls to make.

    Returns:
        A list of ``num_steps`` learning rates; entry ``i`` is the rate in effect
        before the ``i``-th call to ``scheduler.step()``.
    """
    lrs = []
    for _ in range(num_steps):
        # get_last_lr() is the public accessor for the rate currently applied;
        # get_lr() is the scheduler's internal computation hook and PyTorch
        # warns when it is called outside of step().
        lrs.append(scheduler.get_last_lr()[0])
        scheduler.step()
    return lrs

50

51
52
53
def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
    """Record learning rates like ``unwrap_schedule``, round-tripping state through disk midway.

    Right after the step taken at loop index ``num_steps // 2``, the scheduler's
    ``state_dict()`` is written with ``torch.save``, read back with ``torch.load``
    and loaded into the same scheduler. If save/reload is lossless, the returned
    trace equals that of an uninterrupted run.

    Args:
        scheduler: a ``torch.optim.lr_scheduler`` instance.
        num_steps: how many ``scheduler.step()`` calls to make.

    Returns:
        A list of ``num_steps`` learning rates (first parameter group only).
    """
    lrs = []
    for step in range(num_steps):
        lrs.append(scheduler.get_last_lr()[0])
        scheduler.step()
        if step == num_steps // 2:
            with tempfile.TemporaryDirectory() as tmpdirname:
                file_name = os.path.join(tmpdirname, "schedule.bin")
                torch.save(scheduler.state_dict(), file_name)

                # The state dict may contain arbitrary pickled callables (e.g. the
                # function held by a LambdaScheduleWrapper). torch>=2.6 defaults to
                # weights_only=True, which refuses to unpickle them. The file was
                # written just above in a private temp dir, so a full load is safe.
                state_dict = torch.load(file_name, weights_only=False)
                scheduler.load_state_dict(state_dict)
    return lrs

65

66
@require_torch
67
68
69
70
71
72
class OptimizationTest(unittest.TestCase):
    """Convergence checks for the ``AdamW`` and ``Adafactor`` optimizers.

    Each test fits a three-element tensor to a fixed target under an MSE loss
    and asserts the optimized values land within ``1e-2`` of the target.
    """

    def assertListAlmostEqual(self, list1, list2, tol, msg=None):
        """Assert equal length and that corresponding elements differ by at most ``tol``."""
        self.assertEqual(len(list1), len(list2), msg=msg)
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol, msg=msg)

    @staticmethod
    def _fit(optimizer, w, target, num_steps):
        """Take ``num_steps`` optimizer steps minimising ``MSELoss(w, target)``.

        ``w`` must be the tensor the optimizer was built on; it is updated in place.
        """
        criterion = nn.MSELoss()
        for _ in range(num_steps):
            loss = criterion(w, target)
            loss.backward()
            optimizer.step()
            # A bare tensor has no zero_grad(); clear the accumulated gradient by hand.
            w.grad.detach_()
            w.grad.zero_()

    def test_adam_w(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        # No warmup, constant schedule, no gradient clipping
        optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
        self._fit(optimizer, w, target, num_steps=100)
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

    def test_adafactor(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        # Fixed external lr (relative_step=False), no parameter scaling, no warmup init.
        optimizer = Adafactor(
            params=[w],
            lr=1e-2,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=None,
            weight_decay=0.0,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )
        self._fit(optimizer, w, target, num_steps=1000)
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

112

113
@require_torch
lukovnikov's avatar
lukovnikov committed
114
class ScheduleInitTest(unittest.TestCase):
    """Checks every learning-rate schedule factory against a hard-coded trace.

    All schedules are built on the same class-level optimizer (base lr ``10.0``),
    so the expected values below are absolute learning rates. Each schedule is
    also replayed with a save/reload of its state midway, and the two traces
    must match exactly.
    """

    # The class body runs at import time even without torch, so these
    # attributes fall back to None rather than raising.
    m = nn.Linear(50, 50) if is_torch_available() else None
    optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
    num_steps = 10

    def assertListAlmostEqual(self, list1, list2, tol, msg=None):
        """Assert equal length and that corresponding elements differ by at most ``tol``."""
        self.assertEqual(len(list1), len(list2), msg=msg)
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol, msg=msg)

    def test_schedulers(self):
        common_kwargs = {"num_warmup_steps": 2, "num_training_steps": 10}
        # schedulers dict format
        # function: (sched_args_dict, expected_learning_rates)
        scheds = {
            get_constant_schedule: ({}, [10.0] * self.num_steps),
            get_constant_schedule_with_warmup: (
                {"num_warmup_steps": 4},
                [0.0, 2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0],
            ),
            get_linear_schedule_with_warmup: (
                {**common_kwargs},
                [0.0, 5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25],
            ),
            get_cosine_schedule_with_warmup: (
                {**common_kwargs},
                [0.0, 5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38],
            ),
            get_cosine_with_hard_restarts_schedule_with_warmup: (
                {**common_kwargs, "num_cycles": 2},
                [0.0, 5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46],
            ),
            get_polynomial_decay_schedule_with_warmup: (
                {**common_kwargs, "power": 2.0, "lr_end": 1e-7},
                [0.0, 5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156],
            ),
            get_inverse_sqrt_schedule: (
                {"num_warmup_steps": 2},
                [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
            ),
            get_wsd_schedule: (
                {"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1},
                [0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0],
            ),
        }

        for scheduler_func, (kwargs, expected_learning_rates) in scheds.items():
            scheduler = scheduler_func(self.optimizer, **kwargs)
            # self.optimizer has a single parameter group, so exactly one rate is reported.
            self.assertEqual(len(scheduler.get_last_lr()), 1)

            lrs_1 = unwrap_schedule(scheduler, self.num_steps)
            self.assertListAlmostEqual(
                lrs_1,
                expected_learning_rates,
                tol=1e-2,
                msg=f"failed for {scheduler_func} in normal scheduler",
            )

            scheduler = scheduler_func(self.optimizer, **kwargs)
            # NOTE(review): get_constant_schedule is deliberately left unwrapped;
            # presumably its lambda is not picklable — confirm before changing.
            if scheduler_func.__name__ != "get_constant_schedule":
                LambdaScheduleWrapper.wrap_scheduler(scheduler)  # wrap to test picklability of the schedule
            lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
            self.assertListEqual(lrs_1, lrs_2, msg=f"failed for {scheduler_func} in save and reload")
178
179
180
181
182
183
184
185
186
187
188
189
190
191


class LambdaScheduleWrapper:
    """Callable-object wrapper around a scheduler's learning-rate lambda.

    ``torch.optim.lr_scheduler.LambdaLR.state_dict`` only serialises lr lambdas
    that are callable objects (by saving their ``__dict__``); plain functions are
    skipped. Wrapping therefore forces the wrapped ``fn`` into the saved state,
    so a save/reload round trip checks that the schedule function is picklable.

    See https://github.com/huggingface/transformers/issues/21689
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        # Transparent pass-through: the wrapper must compute exactly what fn does.
        return self.fn(*args, **kwargs)

    @classmethod
    def wrap_scheduler(cls, scheduler):
        """Rebind ``scheduler.lr_lambdas`` to a new list with every entry wrapped in ``cls``.

        Returns ``None``; the scheduler is modified in place.
        """
        scheduler.lr_lambdas = [cls(fn) for fn in scheduler.lr_lambdas]