# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

""" Test FlattenParamsWrapper on CPU and GPU (FP32 & FP16 on GPU). """

from collections import OrderedDict
import unittest

import torch

from fair_dev.testing.testing import objects_are_equal
from fairscale.nn import FlattenParamsWrapper


class TestFlattenParams(unittest.TestCase):
    """Base test class and used for CPU case."""

    def _get_module_init_fns(self):
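        """Return the module factory functions exercised by the state dict and unflatten tests below."""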
        return [
            self._get_basic_linear_module,
            self._get_shared_params_transformer,
            self._get_2_flatten_group_linear_module,
            self._get_2_flatten_group_linear_module_with_names,
        ]

    def _get_empty_module(self, seed=0):
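        """Return a parameter-free module (forward just adds 1) plus its get_input helper."""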
        torch.manual_seed(seed)  # keep everything deterministic

        class Test(torch.nn.Module):
            def forward(self, x):
                return x + 1

        module = Test()

        def get_input(device, dtype):
            torch.manual_seed(1)  # keep everything deterministic
            return torch.rand(1).to(device=device, dtype=dtype)

        module.get_input = get_input
        module.param_list = None  # No param_list to FPW.
        module.flat_param_names = None  # No flat_param_names to FPW.
        return module

    def _get_transformer(self, seed=0):
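        """Return a small Transformer with a dummy buffer and a matching get_input helper."""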
        torch.manual_seed(seed)  # keep everything deterministic
        module = torch.nn.Transformer(
            d_model=32,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=128,
            dropout=0.1,
        )
        module.register_buffer("dummy_buffer", torch.tensor(1.0))

        def get_input(device, dtype):
            torch.manual_seed(1)  # keep everything deterministic
            src = torch.rand(20, 8, 32).to(device=device, dtype=dtype)  # T x B x C
            tgt = torch.rand(10, 8, 32).to(device=device, dtype=dtype)  # T x B x C
            return (src, tgt)

        module.get_input = get_input
        module.param_list = None  # No param_list to FPW.
        module.flat_param_names = None  # No flat_param_names to FPW.
        return module

    def _get_shared_params_transformer(self, seed=0):
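        """Return a transformer whose encoder and decoder FFN weights are shared."""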
        module = self._get_transformer(seed=seed)
        # share the FFNs
        for enc_layer, dec_layer in zip(module.encoder.layers, module.decoder.layers):
            dec_layer.linear1.weight = enc_layer.linear1.weight
            dec_layer.linear2.weight = enc_layer.linear2.weight
        return module

    def _get_basic_linear_module(self, seed=0):
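        """Return a nested Sequential of Linear layers with no custom param_list (single flatten group)."""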
        module = torch.nn.Sequential(
            torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 8)),
            torch.nn.Sequential(torch.nn.Linear(8, 16)),
            torch.nn.Linear(16, 4),
        )

        def get_input(device, dtype):
            torch.manual_seed(1)  # keep everything deterministic
            return (torch.rand(8, 4).to(device=device, dtype=dtype),)

        module.get_input = get_input
        module.param_list = None  # No param_list to FPW.
        module.flat_param_names = None  # No flat_param_names to FPW.
        return module

    def _get_2_flatten_group_linear_module(self, seed=0):
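        """Return a Sequential whose parameters are split into two flatten groups via param_list."""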
        module = torch.nn.Sequential(
            torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
            torch.nn.Linear(16, 4),
        )

        def get_input(device, dtype):
            torch.manual_seed(1)  # keep everything deterministic
            return (torch.rand(8, 4).to(device=device, dtype=dtype),)

        module.get_input = get_input
        assert len(module) == 2, "next line assumes a len==2 sequential module"
        module.param_list = [list(module[0].parameters()), list(module[1].parameters())]
        module.flat_param_names = None  # No flat_param_names to FPW.
        return module

    def _get_2_flatten_group_linear_module_with_names(self, seed=0):
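        """Same as _get_2_flatten_group_linear_module, but also supplying flat_param_names."""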
        module = torch.nn.Sequential(
            torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
            torch.nn.Linear(16, 4),
        )

        def get_input(device, dtype):
            torch.manual_seed(1)  # keep everything deterministic
            return (torch.rand(8, 4).to(device=device, dtype=dtype),)

        module.get_input = get_input
        assert len(module) == 2, "next line assumes a len==2 sequential module"
        module.param_list = [list(module[0].parameters()), list(module[1].parameters())]
        module.flat_param_names = ["layer1", "layer2"]
        return module

    def _compute_output(self, module):
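        """Run a forward pass using the module's own get_input helper, matching its device and dtype."""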
        device = next(module.parameters()).device
        dtype = next(module.parameters()).dtype
        input = module.get_input(device, dtype)
        return module(*input)

    def _get_pnorm_after_step(self, module):
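        """Take a single SGD step and return the norm over all parameter norms afterwards."""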
        optim = torch.optim.SGD(module.parameters(), lr=0.01)
        loss = self._compute_output(module).sum()
        loss.backward()
        optim.step()
        return torch.norm(torch.stack([p.detach().norm() for p in module.parameters()]))

    def _test_num_params(self, module):
        """Make sure the total numel of params is the same after flattening."""
        ref_num_params = sum(p.numel() for p in module.parameters())

        flat_module = FlattenParamsWrapper(module)
        flat_num_params = sum(p.numel() for p in flat_module.parameters())

        assert ref_num_params == flat_num_params
        assert flat_num_params == flat_module.flat_param.numel()

    def _test_output(self, module):
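        """Make sure the forward output is unchanged after flattening."""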
        ref_output = self._compute_output(module)

        flat_module = FlattenParamsWrapper(module)
        flat_output = self._compute_output(flat_module)
        assert objects_are_equal(ref_output, flat_output)

    def test_partial_flattening(self):
        """Test that some parameters are flattened while others are left unflattened."""
        module = self._get_transformer()
        num_params = sum(p.numel() for p in module.parameters())

        params_to_flatten = list(module.encoder.layers[1].parameters()) + list(module.decoder.layers[0].parameters())
        num_params_to_flatten = sum(p.numel() for p in params_to_flatten)

        module = FlattenParamsWrapper(module, param_list=params_to_flatten)
        assert module.flat_param.numel() == num_params_to_flatten
        assert sum(p.numel() for p in module.parameters()) == num_params

        # flattened parameters are removed
        assert len(list(module.encoder.layers[1].parameters())) == 0
        assert len(list(module.decoder.layers[0].parameters())) == 0

        # non-flattened parameters remain
        assert len(list(module.encoder.layers[0].parameters())) > 0
        assert len(list(module.decoder.layers[1].parameters())) > 0

        # test that changing the module dtype works properly
        orig_dtype = params_to_flatten[0].dtype
        new_dtype = torch.float32 if orig_dtype == torch.float16 else torch.float16
        assert module.flat_param.dtype == orig_dtype
        assert all(p.dtype == orig_dtype for p in module.encoder.layers[0].parameters())
        module = module.to(dtype=new_dtype)
        assert module.flat_param.dtype == new_dtype
        assert all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())

    def test_two_flattening_group(self):
        """Test two flattening groups."""
        module = self._get_transformer()
        num_params = sum(p.numel() for p in module.parameters())

        params_to_flatten1 = list(module.encoder.layers[1].parameters()) + list(module.decoder.layers[0].parameters())
        params_to_flatten2 = list(module.encoder.layers[0].parameters()) + list(module.decoder.layers[1].parameters())
        num_params_to_flatten1 = sum(p.numel() for p in params_to_flatten1)
        num_params_to_flatten2 = sum(p.numel() for p in params_to_flatten2)

        module = FlattenParamsWrapper(module, param_list=[params_to_flatten1, params_to_flatten2])
        assert module.flat_params[0].numel() == num_params_to_flatten1
        assert module.flat_params[1].numel() == num_params_to_flatten2
        assert sum(p.numel() for p in module.parameters()) == num_params

    def test_flatten_nothing(self):
        """Test the case where nothing is flattened."""
        module = self._get_transformer()
        ref_out = self._compute_output(module)
        ref_state_dict = module.state_dict()
        for k, v in ref_state_dict.items():
            ref_state_dict[k] = v.clone()
        module = FlattenParamsWrapper(module, param_list=[[]])
        fpw_state_dict = module.state_dict()
        assert ref_state_dict.keys() == fpw_state_dict.keys()
        for k, v in ref_state_dict.items():
            torch.testing.assert_allclose(v, fpw_state_dict[k])
        fpw_out = self._compute_output(module)
        torch.testing.assert_allclose(ref_out, fpw_out)

    def test_empty_module(self):
        """Test module without any param."""
        module = self._get_empty_module()
        in_data = torch.rand(1)
        ref_out = module(in_data)
        module = FlattenParamsWrapper(module)
        assert len(list(module.parameters())) == 0
        assert len(module.state_dict()) == 0
        fpw_out = module(in_data)
        torch.testing.assert_allclose(ref_out, fpw_out)

    def test_num_params(self):
        module = self._get_transformer()
        self._test_num_params(module)

    def test_shared_params_num_params(self):
        module = self._get_shared_params_transformer()
        self._test_num_params(module)

    def test_output(self):
        module = self._get_transformer()
        self._test_output(module)

    def test_shared_params_output(self):
        module = self._get_shared_params_transformer()
        self._test_output(module)

    def test_shared_params_pnorm_after_step(self):
        # incorrect parameter sharing is likely to cause problems after an
        # optimization step
        module = self._get_shared_params_transformer()
        ref_pnorm_after_step = self._get_pnorm_after_step(module)

        module = self._get_shared_params_transformer()  # recreate
        flat_module = FlattenParamsWrapper(module)
        flat_pnorm_after_step = self._get_pnorm_after_step(flat_module)

        torch.testing.assert_allclose(ref_pnorm_after_step, flat_pnorm_after_step)

    def test_state_dict_equality(self):
        """Test that unflattened state dict matches original (unwrapped) one."""
        modules_to_test = [init_fn() for init_fn in self._get_module_init_fns()]
        for module in modules_to_test:
            ref_state_dict = module.state_dict()

            flat_module = FlattenParamsWrapper(module)
            flat_state_dict = flat_module.state_dict()

            assert (
                ref_state_dict.keys() == flat_state_dict.keys()
            ), f"{ref_state_dict.keys()} != {flat_state_dict.keys()}"
            assert objects_are_equal(ref_state_dict, flat_state_dict), f"{ref_state_dict} != {flat_state_dict}"

    def test_load_state_dict(self):
        """Test that original (unwrapped) state_dict can be loaded in wrapped module."""
        for module_init_fn in self._get_module_init_fns():
            module = module_init_fn()
            ref_state_dict = module.state_dict()
            ref_output = self._compute_output(module)

            module = module_init_fn(seed=1234)
            flat_module = FlattenParamsWrapper(
                module, param_list=module.param_list, flat_param_names=module.flat_param_names
            )

            # This should work without the unflatten_params context manager
            flat_module.load_state_dict(ref_state_dict)
            flat_output = self._compute_output(flat_module)
            assert objects_are_equal(ref_output, flat_output)

            # And it should work with the context manager too
            with flat_module.unflatten_params():
                flat_module.load_state_dict(ref_state_dict)
            flat_output = self._compute_output(flat_module)
            assert objects_are_equal(ref_output, flat_output)

    def test_flat_state_dict(self):
        """Test that flat state dict can be reloaded and produces the same results."""
        for module_init_fn in self._get_module_init_fns():
            orig_module = module_init_fn()
            flat_module = FlattenParamsWrapper(
                orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
            )
            ref_output = self._compute_output(flat_module)

            flat_state_dict = flat_module.flat_state_dict()

            orig_module = module_init_fn(seed=1234)
            new_module = FlattenParamsWrapper(
                orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
            )
            new_module.load_state_dict(flat_state_dict)
            new_output = self._compute_output(new_module)

            assert objects_are_equal(ref_output, new_output)

    def test_unflatten_params(self):
        """Testing using external flatten params tensors as module's params' backing data."""
        for module_init_fn in self._get_module_init_fns():
            orig_module = module_init_fn()
            module = FlattenParamsWrapper(
                orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
            )

            # keep a set of buffer keys (without the wrapper prefix) for verification below.
            buffers = {k.replace("_fpw_module.", "") for k, _ in module.named_buffers()}

            def clone_state_dict():
                """Return a copy of the module's current state via state_dict() API."""
                return OrderedDict((k, v.clone()) for k, v in module.state_dict().items())

            ref_flat_params = [fp.clone() for fp in module.flat_params]
            # Get the current state as a reference.
            with module.unflatten_params():
                ref_state_dict = clone_state_dict()
            for ref_fp in ref_flat_params:
                assert not torch.all(ref_fp == 0.0)  # Should not all be 0s.

            # get new_state_dict using the supplied new_flat_params as backing data.
            new_flat_params = [torch.full_like(fp, fill_value=42.0) for fp in module.flat_params]
            with module.unflatten_params(flat_params=new_flat_params):
                new_state_dict = clone_state_dict()

            # confirm that unflatten_params reflects the values from new_flat_params
            assert new_state_dict.keys() == ref_state_dict.keys()
            for k, v in new_state_dict.items():
                if k in buffers:  # buffers are not changed
                    torch.testing.assert_allclose(v, ref_state_dict[k])
                else:  # params reflect new_flat_param value
                    torch.testing.assert_allclose(v, torch.ones_like(v) * 42.0)

            # after the context manager exits, we go back to the previous (reference) state
            assert len(module.flat_params) == len(ref_flat_params)
            for i in range(len(module.flat_params)):
                torch.testing.assert_allclose(module.flat_params[i], ref_flat_params[i])

            # get another copy of state from the module (without external backing data)
            with module.unflatten_params():
                ref_state_dict2 = clone_state_dict()

            # Verify it is still the same.
            assert objects_are_equal(ref_state_dict, ref_state_dict2)

            # if we load new_state_dict, then the flat params should match new_flat_params
            module.load_state_dict(new_state_dict)
            assert len(module.flat_params) == len(new_flat_params)
            for i in range(len(module.flat_params)):
                torch.testing.assert_allclose(module.flat_params[i], new_flat_params[i])


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestFlattenParamsCUDA(TestFlattenParams):
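    """Rerun the tests with the transformer modules moved to GPU (FP32)."""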
    def _get_transformer(self, seed=0):
        module = super()._get_transformer(seed=seed)
        return module.cuda()


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestFlattenParamsCUDAHalf(TestFlattenParams):
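    """Rerun the tests with the transformer modules moved to GPU and cast to FP16."""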
    def _get_transformer(self, seed=0):
        module = super()._get_transformer(seed=seed)
        return module.cuda().half()


if __name__ == "__main__":
    unittest.main()