test_wrap.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
import random
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
from fairscale.utils.testing import DummyProcessGroup

try:
    from torch.cuda.amp import autocast
except ImportError:
    autocast = None  # type: ignore


class TestAutoWrap(unittest.TestCase):
    def setUp(self) -> None:
        # For all the tests here, we use a fake process group and flatten_parameters=False, since
        # neither should affect how wrapping works.
        self.process_group = DummyProcessGroup(rank=0, size=1)

    def test_wrap(self):
        with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
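            # Inside the enable_wrap context, wrap() should return the module wrapped in FSDP,
            # picking up the context's kwargs (here flatten_parameters=False).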
            layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters is False

    def test_wrap_disabled_outside_context(self):
        layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, nn.Linear)

    def test_wrap_override_defaults(self):
        with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
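            # Kwargs passed directly to wrap() should override the defaults set by enable_wrap.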
            layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters

    def test_auto_wrap(self):
        """
        Test to ensure that with auto wrap we wrap child modules correctly based on min_num_params.
        A single ``nn.Linear(5, 5)`` does not reach the bucket size, but two of them combined do.
        The root is not wrapped because there are not enough unwrapped params left and
        skip_params_check_for_root is not set.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
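            # Each nn.Linear(5, 5) has 30 params (25 weights + 5 biases); only the inner
            # nn.Sequential reaches the 60-param threshold below.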
            sequential = nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=60)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model, nn.Sequential)
        assert isinstance(model[0], nn.Linear)
        assert isinstance(model[1], FSDP)
        assert isinstance(model[1].module[0], nn.Linear)
        assert isinstance(model[1].module[1], nn.Linear)

    def test_auto_wrap_skip_root_checks(self):
        """
        Similar to the test above, but this time we set skip_params_check_for_root=True in the wrap
        policy, so the root is wrapped even though not enough unwrapped params remain.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
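            # Same model and 60-param threshold as test_auto_wrap, but skip_params_check_for_root=True
            # means the root gets wrapped even though only the first nn.Linear (30 params) is left
            # unwrapped.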
            sequential = nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=60, skip_params_check_for_root=True
            )
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], FSDP)
        assert isinstance(model.module[1].module[0], nn.Linear)
        assert isinstance(model.module[1].module[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether their total param
        count exceeds min_num_params.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
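            # The two nn.Linear(5, 5) layers total 60 params, above the 40-param threshold, but the
            # preset policy excludes nn.ModuleList containers from being wrapped.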
            sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], nn.Linear)
        assert isinstance(model[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are wrapped if their
        param count exceeds min_num_params.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
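            # nn.Linear(10, 10) has 110 params, so the child is wrapped even though the
            # nn.ModuleList container itself is excluded from wrapping.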
            sequential = nn.ModuleList([nn.Linear(10, 10)])
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], FSDP)

    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped, and neither are their children.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
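            # nn.MultiheadAttention is treated as a force-leaf module by the preset policy: neither
            # it nor its children get wrapped, while the nn.Linear(10, 10) (110 params) does.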
            sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model.module[0], FSDP)
        # Assert children of multihead attention are not wrapped
        assert isinstance(model.module[1], nn.MultiheadAttention)
        assert isinstance(model.module[1].out_proj, nn.Linear)

    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union({nn.Linear}),
        )
        with enable_wrap(
            auto_wrap_policy=my_auto_wrap_policy,
            wrapper_cls=FSDP,
            process_group=self.process_group,
            flatten_parameters=False,
        ):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
            model = auto_wrap(sequential)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.ModuleList)

    # todo: currently complains that address is in use, not sure why since I clear the proc group.
    # def test_auto_wrap_smoke(self):
    #     self._auto_wrap_smoke_test(enable_mixed_precision=False)

    def test_auto_wrap_smoke_autocast(self):
        """
        Ensure we can do a forward/backward through an auto-wrapped model.
        """
        self._auto_wrap_smoke_test(enable_mixed_precision=True)

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    @unittest.skipIf(autocast is None, "Test Requires autocast")
    def _auto_wrap_smoke_test(self, enable_mixed_precision):
        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port: if the next test runs soon after this one, reusing the same port
        # would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        try:
            with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
                sequential = nn.Sequential(
                    nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
                )
                my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
                model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
            model.to(device)
            input = torch.rand((1, 5), dtype=torch.float).to(device)

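            # Run the forward pass under autocast when mixed precision is enabled; the loss is just
            # an arbitrary MSE of the output against the input.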
            with autocast(enabled=enable_mixed_precision):
                output = model(input)
                loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]