"vscode:/vscode.git/clone" did not exist on "2a101207d44b903c1cc9b4d968a4b24150413942"
test_fsdp_apply.py 2.28 KB
Newer Older
1
2
3
4
5
6
7
8
9
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import unittest

from parameterized import parameterized
import pytest
import torch.nn as nn

from fairscale.utils import torch_version

from .test_fsdp import (
    CONFIG_OPTIONS,
    DistributedTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
    rename_test,
    spawn_and_init,
)


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestApply(DistributedTest):
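    """Test custom weight initialization applied with nn.Module.apply() on FSDP-wrapped models."""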
    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_transformer_weight_init(self, config):
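        """Apply a BERT-style weight init to TransformerWithSharedParams and run the identical-outputs check."""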
        model_init_fn = functools.partial(model_init_and_apply_custom_weight_init, TransformerWithSharedParams)
        test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01)
        spawn_and_init(test_fn)

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_nested_wrapped_weight_init(self, config):
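        """Apply a BERT-style weight init to NestedWrappedModule and run the identical-outputs check."""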
        model_init_fn = functools.partial(model_init_and_apply_custom_weight_init, NestedWrappedModule)
        test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01)
        spawn_and_init(test_fn)


def model_init_and_apply_custom_weight_init(model_init_fn, *args, **kwargs):
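    """Instantiate a model with ``model_init_fn`` and re-initialize its weights in place with a BERT-style init."""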
    model = model_init_fn(*args, **kwargs)
    model.apply(init_bert_params_)
    return model


def init_bert_params_(module):
    """
    Initialize the weights specific to the BERT Model.
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, nn.MultiheadAttention):
        normal_(module.in_proj_weight.data)


if __name__ == "__main__":
    unittest.main()