# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
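
"""Tests for FullyShardedDataParallel.no_sync(): gradients accumulated across
backward passes inside the context should match a single pass over the combined
batch, and gradient reduction should be deferred until a backward pass outside
the context."""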

import functools
import itertools
import unittest
from unittest.mock import patch

from parameterized import parameterized
import torch

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import DummyProcessGroup, objects_are_equal

from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init


class TestNoSync(DistributedTest):
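    """Gradient accumulation under no_sync() should produce the same gradients
    as a single training step on the concatenated batches."""
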
    def test_transformer(self):
        fn = functools.partial(self._test_transformer, config={})
        spawn_and_init(fn)

    def test_transformer_no_flat_params(self):
        config = {"flatten_parameters": False}
        fn = functools.partial(self._test_transformer, config=config)
        spawn_and_init(fn)

    def test_nested_wrapper(self):
        fn = functools.partial(self._test_nested_wrapper, config={})
        spawn_and_init(fn)

    def test_no_sync_before_first_forward(self):
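        """Entering no_sync() before the model's first forward pass should not
        break that pass or the synchronized one that follows."""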
        group = DummyProcessGroup(rank=0, size=1)
        model = self.get_wrapped_model(group, config={}, add_bn=False)
        batch = model.module.get_input(torch.device("cuda"))
        with model.no_sync():
            output = model(*batch)
            loss = model.module.get_loss(batch, output)
            loss.backward()
        output = model(*batch)
        loss = model.module.get_loss(batch, output)
        loss.backward()

    @classmethod
    def _test_transformer(self, rank, group, config):
        model = self.get_wrapped_model(group, config=config, add_bn=False)
        model.eval()  # turn off dropout for the test
        self._test_no_sync(model, batch_dim=1)

    @classmethod
    def _test_nested_wrapper(self, rank, group, config):
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        self._test_no_sync(model, batch_dim=0)

    @classmethod
    def _test_no_sync(self, model, batch_dim):
        # Generate two input batches. We'll test that we get the same grads if
        # we train on them sequentially while accumulating grads (with no_sync)
        # vs. concatenating the batches and training in one go.
        batch1 = model.module.get_input(torch.device("cuda"))
        assert isinstance(batch1, tuple)
        batch2 = tuple(
            # This randomly permutes the values in a multi-dim tensor.
            x.view(-1)[torch.randperm(x.numel())].view_as(x)
            for x in batch1
        )
        for x, y in zip(batch1, batch2):
            assert not torch.all(x == y)

        # Concat the batches along batch dimension.
        concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))

        # Establish reference behavior on the concat batch.
        model.zero_grad()
        output = model(*concat_batch)
        ref_loss = model.module.get_loss(concat_batch, output)
        ref_loss.backward()
        ref_grads = [p.grad.detach().clone() for p in model.parameters()]

        # Test that we get the same results by accumulating grads.
        model.zero_grad()
        with model.no_sync():  # accumulate gradients from the first batch
            output = model(*batch1)
            loss1 = model.module.get_loss(batch1, output)
            loss1.backward()
        output = model(*batch2)
        loss2 = model.module.get_loss(batch2, output)
        loss2.backward()
        accumulated_loss = loss1 + loss2
        accumulated_grads = [p.grad.detach().clone() for p in model.parameters()]

        torch.testing.assert_allclose(ref_loss, accumulated_loss)
        assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)


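# All combinations of reshard_after_forward and mixed_precision, used by the
# communication-count tests below.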
keys = ["reshard_after_forward", "mixed_precision"]
COMM_CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]


class TestNoSyncCommunication(DistributedTest):
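    """no_sync() should defer all reduce_scatter calls to the first backward
    pass outside the context, with all_gather counts depending on whether
    parameters are resharded after the forward pass."""
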
    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
    def test_communication(self, config):
        fn = functools.partial(self._test_communication, config=config)
        spawn_and_init(fn)

    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
    def test_communication_nested(self, config):
        fn = functools.partial(self._test_communication, config=config, nested_model=True)
        spawn_and_init(fn)

    @classmethod
    def _test_communication(self, rank, group, config, nested_model=False):
        if group.size() == 1:
            return

        # Turn off bucketing to accurately count number of reduce_scatters.
        config["bucket_cap_mb"] = 0

        if nested_model:
            model = NestedWrappedModule(group, config)
            model = FullyShardedDataParallel(model, group, **config).cuda()
        else:
            model = self.get_wrapped_model(group, config=config)

        num_fsdp = 0
        for child in model.modules():  # includes self
            if isinstance(child, FullyShardedDataParallel) and len(child.params) > 0:
                num_fsdp += 1

        if config.get("reshard_after_forward", True):
            # inside no_sync:
            #   num_fsdp all-gathers in the forward
            #   num_fsdp-1 all-gathers in the backward (except root)
            # outside no_sync:
            #   num_fsdp-1 all-gathers in the forward (except root)
            #   num_fsdp-1 all-gathers in the backward (except root)
            expected_all_gather1 = 2 * num_fsdp - 1
            expected_all_gather2 = expected_all_gather1 + (2 * num_fsdp - 2)
        else:
            # inside no_sync:
            #   num_fsdp all-gathers in the forward
            # outside no_sync:
            #   none
            expected_all_gather1 = num_fsdp
            expected_all_gather2 = num_fsdp

        expected_reduce_scatter = num_fsdp
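        # Worked example (hypothetical num_fsdp = 3): with resharding, the pass
        # under no_sync() issues 3 + 2 = 5 all-gathers and the synchronized pass
        # adds 2 + 2 = 4 more (9 total); without resharding, neither pass needs
        # all-gathers beyond the initial 3. Either way, gradients are
        # reduce-scattered exactly once per FSDP instance (3 calls).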

        batch = model.module.get_input(torch.device("cuda"))

        # Depending on the PyTorch version, the private _base collectives may not
        # exist. When one is missing, patch a harmless stand-in attribute instead
        # (math.nan) so the corresponding mock simply records zero calls.
        method_string_reduce_scatter_base = "torch.distributed._reduce_scatter_base"
        if not hasattr(torch.distributed, "_reduce_scatter_base"):
            method_string_reduce_scatter_base = "math.nan"

        method_string_all_gather_base = "torch.distributed._all_gather_base"
        if not hasattr(torch.distributed, "_all_gather_base"):
            method_string_all_gather_base = "math.nan"

        with patch("torch.distributed.all_gather") as mock_all_gather:
            with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
                with patch(method_string_all_gather_base) as mock_all_gather_base:
                    with patch(method_string_reduce_scatter_base) as mock_reduce_scatter_base:
                        with model.no_sync():
                            output = model(*batch)
                            loss = model.module.get_loss(batch, output)
                            loss.backward()

                        # The _base collectives are used when available, so count
                        # them together with the public methods; the sum should
                        # match the expected number of calls.
                        assert (
                            mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather1
                        ), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather1}"
                        assert (
                            mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count == 0
                        ), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != 0"

                        output = model(*batch)
                        loss = model.module.get_loss(batch, output)
                        loss.backward()

                        assert (
                            mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather2
                        ), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather2}"
                        assert (
                            mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count
                            == expected_reduce_scatter
                        ), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != {expected_reduce_scatter}"


if __name__ == "__main__":
    unittest.main()