# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
import unittest
from unittest import mock

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
from fairscale.utils.testing import DummyProcessGroup


class TestAutoWrap(unittest.TestCase):
    def setUp(self) -> None:
        version = torch.__version__.split(".")[:2]
        major, minor = int(version[0]), int(version[1])
        if major < 1 or (major == 1 and minor < 6):
            raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast")
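        # A dummy single-rank process group lets FSDP be constructed in these tests
        # without initializing a real distributed backend.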
        self.process_group = DummyProcessGroup(rank=0, size=1)

    def test_wrap(self):
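        """
        Test to ensure ``wrap`` applies the wrapper class and kwargs from the surrounding
        ``enable_wrap`` context (here FSDP with ``flatten_parameters=False``).
        """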
        with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters is False

    def test_wrap_disabled_outside_context(self):
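        """
        Test to ensure ``wrap`` is a no-op outside of an ``enable_wrap`` context and
        returns the module unchanged.
        """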
        layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, nn.Linear)

    def test_wrap_override_defaults(self):
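        """
        Test to ensure kwargs passed directly to ``wrap`` override the defaults set by
        ``enable_wrap``.
        """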
        with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters

    def test_auto_wrap(self):
        """
        Test to ensure with auto wrap, we wrap child modules correctly based on min_num_params.
        A single ``nn.Linear(5, 5)`` does not exceed min_num_params on its own, but the modules do when combined.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
            )
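            # Each nn.Linear(5, 5) has 30 parameters (25 weights + 5 biases), below the
            # min_num_params=40 threshold, but the nested Sequential totals 60 and should be wrapped.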
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.Linear)
        assert isinstance(model.module[2], FSDP)
        assert isinstance(model.module[2].module[0], nn.Linear)
        assert isinstance(model.module[2].module[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size exceeds
        min_num_params.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
            sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
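            # The ModuleList is excluded from wrapping by the preset policy even though its
            # two Linears total 60 parameters, which exceeds min_num_params.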
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], nn.Linear)
        assert isinstance(model[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are wrapped if their param size
        exceeds min_num_params.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
            sequential = nn.ModuleList([nn.Linear(10, 10)])
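            # The ModuleList itself is excluded from wrapping, but its nn.Linear(10, 10)
            # child has 110 parameters (> min_num_params) and should be wrapped individually.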
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], FSDP)

    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped, and neither are their children.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
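            # nn.MultiheadAttention is a preset force-leaf module, so neither it nor its
            # children (e.g. out_proj) should be wrapped below.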
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model.module[0], FSDP)
        # Assert children of multihead attention are not wrapped
        assert isinstance(model.module[1], nn.MultiheadAttention)
        assert isinstance(model.module[1].out_proj, nn.Linear)

    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure modules added to the force-leaf set by the user are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union({nn.Linear}),
        )
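        # Treating nn.Linear as a force-leaf module means no Linear gets auto-wrapped on
        # its own, so only the root module should end up wrapped below.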
        with enable_wrap(
            auto_wrap_policy=my_auto_wrap_policy,
            wrapper_cls=FSDP,
            process_group=self.process_group,
            flatten_parameters=False,
        ):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
            model = auto_wrap(sequential)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.ModuleList)

    # todo: currently complains that address is in use, not sure why since I clear the proc group.
    # def test_auto_wrap_smoke(self):
    #     self._auto_wrap_smoke_test(enable_mixed_precision=False)

    def test_auto_wrap_smoke_autocast(self):
        """
        Ensure we can do a forward/backward through an auto-wrapped model.
        """
        self._auto_wrap_smoke_test(enable_mixed_precision=True)

    @mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "12345"}, clear=True)
    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    def _auto_wrap_smoke_test(self, enable_mixed_precision):
        from torch.cuda.amp import autocast

        device = torch.device("cuda")
        torch.cuda.set_device(0)
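        # The smoke test needs a real (single-rank) process group, since FSDP's
        # forward/backward issues actual collective ops.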
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
            )
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        model.to(device)
        input = torch.rand((1, 5), dtype=torch.float).to(device)

        with autocast(enabled=enable_mixed_precision):
            output = model(input)
            loss = F.mse_loss(input, output)
        loss.backward()
        torch.distributed.destroy_process_group()