# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
import random
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
from fairscale.utils.testing import DummyProcessGroup

try:
    from torch.cuda.amp import autocast
except ImportError:
    autocast = None  # type: ignore


class TestAutoWrap(unittest.TestCase):
    def setUp(self) -> None:
        self.process_group = DummyProcessGroup(rank=0, size=1)

    def test_wrap(self):
        with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
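            # Inside the enable_wrap context, wrap() applies the configured wrapper (FSDP here)
            # using the kwargs given to enable_wrap.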
            layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters is False

    def test_wrap_disabled_outside_context(self):
        layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, nn.Linear)

    def test_wrap_override_defaults(self):
        with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
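            # kwargs passed directly to wrap() override the defaults set by enable_wrap().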
            layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters

    def test_auto_wrap(self):
        """
        Test to ensure that, with auto_wrap, child modules are wrapped correctly based on min_num_params.
        A single ``nn.Linear(5, 5)`` does not exceed the threshold, but the layers combined do.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
            )
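            # Each nn.Linear(5, 5) has 30 parameters, below the 40-parameter threshold; the inner
            # nn.Sequential totals 60 and should therefore be wrapped on its own.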
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.Linear)
        assert isinstance(model.module[2], FSDP)
        assert isinstance(model.module[2].module[0], nn.Linear)
        assert isinstance(model.module[2].module[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size exceeds
        min_num_params.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
            sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
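            # nn.ModuleList is excluded from wrapping by the default policy, so it stays unwrapped
            # even though the two Linear layers together exceed min_num_params.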
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], nn.Linear)
        assert isinstance(model[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are wrapped if their param size
        exceeds min_num_params.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
            sequential = nn.ModuleList([nn.Linear(10, 10)])
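            # The ModuleList itself is excluded, but its nn.Linear(10, 10) child (110 params)
            # exceeds min_num_params and should be wrapped individually.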
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], FSDP)

    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped, and neither are their children.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
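            # nn.MultiheadAttention is a force-leaf module, so neither it nor its children
            # (e.g. out_proj) should be wrapped.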
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        assert isinstance(model.module[0], FSDP)
        # Assert children of multihead attention are not wrapped
        assert isinstance(model.module[1], nn.MultiheadAttention)
        assert isinstance(model.module[1].out_proj, nn.Linear)

    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union({nn.Linear}),
        )
        with enable_wrap(
            auto_wrap_policy=my_auto_wrap_policy,
            wrapper_cls=FSDP,
            process_group=self.process_group,
            flatten_parameters=False,
        ):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
            model = auto_wrap(sequential)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.ModuleList)

    # TODO: currently complains that the address is already in use; not sure why, since the process group is cleared.
    # def test_auto_wrap_smoke(self):
    #     self._auto_wrap_smoke_test(enable_mixed_precision=False)

    def test_auto_wrap_smoke_autocast(self):
        """
        Ensure we can do a forward/backward through an auto-wrapped model.
        """
        self._auto_wrap_smoke_test(enable_mixed_precision=True)

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    @unittest.skipIf(autocast is None, "Test Requires autocast")
    def _auto_wrap_smoke_test(self, enable_mixed_precision):
        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        try:
            with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
                sequential = nn.Sequential(
                    nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
                )
                my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
                model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
            model.to(device)
            input = torch.rand((1, 5), dtype=torch.float).to(device)

            with autocast(enabled=enable_mixed_precision):
                output = model(input)
                loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]