test_fsdp_apply.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
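
"""Test applying a custom weight initialization to FSDP-wrapped models via ``nn.Module.apply``."""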

import functools
import unittest

from parameterized import parameterized
import torch.nn as nn

from .test_fsdp import (
    CONFIG_OPTIONS,
    DistributedTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
    rename_test,
    spawn_and_init,
)


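# Applying a custom weight init through ``Module.apply`` should not change results:
# each test builds a model for the given ``config``, re-initializes its weights with
# ``init_bert_params_``, and uses ``_test_identical_outputs`` to check that the
# FSDP-wrapped model matches the reference model.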
class TestApply(DistributedTest):
    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_transformer_weight_init(self, config):
        model_init_fn = functools.partial(model_init_and_apply_custom_weight_init, TransformerWithSharedParams)
        test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01)
        spawn_and_init(test_fn)

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_nested_wrapped_weight_init(self, config):
        model_init_fn = functools.partial(model_init_and_apply_custom_weight_init, NestedWrappedModule)
        test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01)
        spawn_and_init(test_fn)


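# Helper passed as the ``model_init_fn``: construct the model, then run a
# BERT-style weight init over every submodule via ``Module.apply``.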
def model_init_and_apply_custom_weight_init(model_init_fn, *args, **kwargs):
    model = model_init_fn(*args, **kwargs)
    model.apply(init_bert_params_)
    return model


def init_bert_params_(module):
    """
    Initialize the weights specific to the BERT Model.
    """

    def normal_(data):
        # Under FSDP, module params live on CUDA; draw the random values on a CPU
        # copy so the RNG stream matches the non-FSDP case, then copy them back.
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, nn.MultiheadAttention):
        normal_(module.in_proj_weight.data)


if __name__ == "__main__":
    unittest.main()