"docs/source/en/tasks/audio_classification.md" did not exist on "a6d8a149a8defaf02941c61ff2b419e60f4855ab"
test_optimization.py 6.67 KB
Newer Older
1
# coding=utf-8
Sylvain Gugger's avatar
Sylvain Gugger committed
2
# Copyright 2020 The HuggingFace Team. All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16

17
import os
18
import tempfile
Aymeric Augustin's avatar
Aymeric Augustin committed
19
import unittest
20

21
from transformers import is_torch_available
22
from transformers.testing_utils import require_torch
Aymeric Augustin's avatar
Aymeric Augustin committed
23
24


25
if is_torch_available():
thomwolf's avatar
thomwolf committed
26
    import torch
27
    from torch import nn
thomwolf's avatar
thomwolf committed
28

29
    from transformers import (
30
        Adafactor,
31
32
33
34
35
        AdamW,
        get_constant_schedule,
        get_constant_schedule_with_warmup,
        get_cosine_schedule_with_warmup,
        get_cosine_with_hard_restarts_schedule_with_warmup,
36
        get_inverse_sqrt_schedule,
37
        get_linear_schedule_with_warmup,
38
        get_polynomial_decay_schedule_with_warmup,
39
    )
thomwolf's avatar
thomwolf committed
40

lukovnikov's avatar
lukovnikov committed
41

thomwolf's avatar
thomwolf committed
42
43
44
def unwrap_schedule(scheduler, num_steps=10):
    """Record the learning rate a scheduler reports over a number of steps.

    Before every call to `scheduler.step()`, the first element of `scheduler.get_lr()` is stored.

    Args:
        scheduler: object exposing `get_lr()` (returning an indexable sequence) and `step()`.
        num_steps (`int`, *optional*, defaults to 10): how many steps to run.

    Returns:
        `List[float]`: one recorded learning rate per step, in order.
    """
    history = []
    for _ in range(num_steps):
        current_lr = scheduler.get_lr()[0]
        history.append(current_lr)
        scheduler.step()
    return history

49

50
51
52
def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
    """Record a scheduler's learning rates while round-tripping its state through disk halfway through.

    Before every call to `scheduler.step()`, the first element of `scheduler.get_lr()` is stored. Right after the
    step with index `num_steps // 2`, `scheduler.state_dict()` is written to a temporary file with `torch.save`,
    read back with `torch.load` and restored with `scheduler.load_state_dict`, so callers can compare the result
    against an uninterrupted run.

    Args:
        scheduler: object exposing `get_lr()`, `step()`, `state_dict()` and `load_state_dict()`.
        num_steps (`int`, *optional*, defaults to 10): how many steps to run.

    Returns:
        `List[float]`: one recorded learning rate per step, in order.
    """
    lrs = []
    for step in range(num_steps):
        lrs.append(scheduler.get_lr()[0])
        scheduler.step()
        if step == num_steps // 2:
            with tempfile.TemporaryDirectory() as tmpdirname:
                file_name = os.path.join(tmpdirname, "schedule.bin")
                torch.save(scheduler.state_dict(), file_name)

                # The saved state may contain arbitrary pickled callables (e.g. lr lambdas wrapped in callable
                # objects), which `torch.load` refuses to unpickle with `weights_only=True` -- the default since
                # torch 2.6. The file was written by this process just above, so full unpickling is safe here.
                state_dict = torch.load(file_name, weights_only=False)
                scheduler.load_state_dict(state_dict)
    return lrs

64

65
@require_torch
class OptimizationTest(unittest.TestCase):
    """Check that the `transformers` optimizers drive a small parameter vector onto a fixed regression target."""

    def assertListAlmostEqual(self, list1, list2, tol, msg=None):
        """Assert that `list1` and `list2` have equal length and agree element-wise within an absolute `tol`.

        `msg` (optional) is forwarded to the element-wise `assertAlmostEqual` calls.
        """
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol, msg=msg)

    def _fit_to_target(self, optimizer, w, target, num_iterations):
        """Take `num_iterations` steps of `optimizer`, minimizing the MSE loss between `w` and `target` in place."""
        criterion = nn.MSELoss()
        for _ in range(num_iterations):
            loss = criterion(w, target)
            loss.backward()
            optimizer.step()
            # Plain tensors have no zero_grad() function, so the accumulated gradient is reset by hand.
            w.grad.detach_()
            w.grad.zero_()

    def test_adam_w(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        # No warmup, constant schedule, no gradient clipping
        optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
        self._fit_to_target(optimizer, w, target, num_iterations=100)
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

    def test_adafactor(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        # Fixed external learning rate (relative_step=False, warmup_init=False), no first moment (beta1=None),
        # no parameter scaling and no weight decay.
        optimizer = Adafactor(
            params=[w],
            lr=1e-2,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=None,
            weight_decay=0.0,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )
        self._fit_to_target(optimizer, w, target, num_iterations=1000)
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

111

112
@require_torch
class ScheduleInitTest(unittest.TestCase):
    """Check the learning rates produced by the `transformers` schedule factories, including across a save/reload."""

    # Shared toy model and optimizer (base lr 10.0, a single parameter group). They are only built when torch is
    # installed so the module can still be imported without it.
    m = nn.Linear(50, 50) if is_torch_available() else None
    optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
    num_steps = 10

    def assertListAlmostEqual(self, list1, list2, tol, msg=None):
        """Assert that `list1` and `list2` have equal length and agree element-wise within an absolute `tol`."""
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol, msg=msg)

    def test_schedulers(self):
        common_kwargs = {"num_warmup_steps": 2, "num_training_steps": 10}
        # schedulers dict format
        # function: (sched_args_dict, expected_learning_rates)
        scheds = {
            get_constant_schedule: ({}, [10.0] * self.num_steps),
            get_constant_schedule_with_warmup: (
                {"num_warmup_steps": 4},
                [0.0, 2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0],
            ),
            get_linear_schedule_with_warmup: (
                {**common_kwargs},
                [0.0, 5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25],
            ),
            get_cosine_schedule_with_warmup: (
                {**common_kwargs},
                [0.0, 5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38],
            ),
            get_cosine_with_hard_restarts_schedule_with_warmup: (
                {**common_kwargs, "num_cycles": 2},
                [0.0, 5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46],
            ),
            get_polynomial_decay_schedule_with_warmup: (
                {**common_kwargs, "power": 2.0, "lr_end": 1e-7},
                [0.0, 5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156],
            ),
            get_inverse_sqrt_schedule: (
                {"num_warmup_steps": 2},
                [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
            ),
        }

        for scheduler_func, (kwargs, expected_learning_rates) in scheds.items():
            scheduler = scheduler_func(self.optimizer, **kwargs)
            # One learning rate per parameter group; the shared optimizer has exactly one group.
            self.assertEqual(len(scheduler.get_lr()), 1)
            lrs_1 = unwrap_schedule(scheduler, self.num_steps)
            self.assertListAlmostEqual(
                lrs_1,
                expected_learning_rates,
                tol=1e-2,
                msg=f"failed for {scheduler_func} in normal scheduler",
            )

            # Fresh scheduler: replay the same schedule with a save/reload halfway through.
            scheduler = scheduler_func(self.optimizer, **kwargs)
            if scheduler_func is not get_constant_schedule:
                LambdaScheduleWrapper.wrap_scheduler(scheduler)  # wrap to test picklability of the schedule
            lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
            self.assertListEqual(lrs_1, lrs_2, msg=f"failed for {scheduler_func} in save and reload")
186


class LambdaScheduleWrapper:
    """Callable object that forwards every call to a wrapped learning-rate lambda.

    See https://github.com/huggingface/transformers/issues/21689

    Per the torch `LambdaLR.state_dict` documentation, lr lambdas are only included in the saved state when they
    are callable objects rather than plain functions, so wrapping them makes a save/reload exercise their pickling.
    """

    def __init__(self, fn):
        # The wrapped callable; every call to the wrapper is forwarded to it unchanged.
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    @classmethod
    def wrap_scheduler(cls, scheduler):
        """Replace, in place, each entry of `scheduler.lr_lambdas` with a wrapper instance around it.

        Args:
            scheduler: object with an `lr_lambdas` iterable of callables; the attribute is rebound to a new list.
        """
        scheduler.lr_lambdas = [cls(fn) for fn in scheduler.lr_lambdas]