"vscode:/vscode.git/clone" did not exist on "33f0de337d978b37c63b98575b4962c6e6479e8c"
test_fsdp_grad_acc.py 9.31 KB
Newer Older
1
2
3
4
5
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import functools
import itertools
import unittest
from unittest.mock import patch

from parameterized import parameterized
import torch

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import DummyProcessGroup, objects_are_equal

from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init


class TestGradAcc(DistributedTest):
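    """Check that accumulating gradients over several batches, with or without
    no_sync(), matches training on the concatenated batch in a single step."""
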
    def test_transformer(self):
        fn = functools.partial(self._test_transformer, config={})
        spawn_and_init(fn)

    def test_transformer_grad_acc_without_no_sync(self):
        fn = functools.partial(self._test_transformer, config={}, use_no_sync_context=False)
        spawn_and_init(fn)

    def test_transformer_no_flat_params(self):
        config = {"flatten_parameters": False}
        fn = functools.partial(self._test_transformer, config=config)
        spawn_and_init(fn)

    def test_nested_wrapper(self):
        fn = functools.partial(self._test_nested_wrapper, config={})
        spawn_and_init(fn)

    def test_no_sync_before_first_forward(self):
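        # Entering no_sync() before the model's first forward pass should work
        # without raising.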
        group = DummyProcessGroup(rank=0, size=1)
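        # FSDP can take a separate process group just for reduce-scatter; use a
        # dummy single-rank group here.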
        dummy_group_reduce_scatter = DummyProcessGroup(rank=group.rank(), size=group.size())
        config = {"process_group_reduce_scatter", dummy_group_reduce_scatter}
        model = self.get_wrapped_model(group, config, add_bn=False)
        batch = model.module.get_input(torch.device("cuda"))
        with model.no_sync():
            output = model(*batch)
            loss = model.module.get_loss(batch, output)
            loss.backward()
        output = model(*batch)
        loss = model.module.get_loss(batch, output)
        loss.backward()

    @classmethod
    def _test_transformer(self, rank, group, config, use_no_sync_context=True):
        model = self.get_wrapped_model(group, config=config, add_bn=False)
        model.eval()  # turn off dropout for the test
        self._test_grad_acc(model, batch_dim=1, use_no_sync_context=use_no_sync_context)

    @classmethod
    def _test_nested_wrapper(self, rank, group, config):
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        self._test_grad_acc(model, batch_dim=0)

    @classmethod
    def _test_grad_acc(self, model, batch_dim, use_no_sync_context=True):
        # Generate two input batches. We'll test that we get the same grads if
        # we train on them sequentially while accumulating grads (with or without
        # no_sync) vs. concatenating the batches and training in one go.
        #
        # The trade-off between accumulating with no_sync and without it is GPU
        # memory vs. network bandwidth: no_sync skips gradient reduction between
        # backwards at the cost of holding unreduced gradients in memory.
        batch1 = model.module.get_input(torch.device("cuda"))
        assert isinstance(batch1, tuple)
        batch2 = tuple(
            # This randomly permutes the values in a multi-dim tensor.
            x.view(-1)[torch.randperm(x.numel())].view_as(x)
            for x in batch1
        )
        for x, y in zip(batch1, batch2):
            assert not torch.all(x == y)

        # Concat the batches along batch dimension.
        concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))

        # Establish reference behavior on the concat batch.
        model.zero_grad()
        output = model(*concat_batch)
        ref_loss = model.module.get_loss(concat_batch, output)
        ref_loss.backward()
        ref_grads = [p.grad.detach().clone() for p in model.parameters()]

        # Test that we get the same results by accumulating grads.
        model.zero_grad()
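        # contextlib.suppress() with no arguments is a no-op context manager; it is
        # used when accumulating without no_sync(), in which case FSDP still reduces
        # grads on every backward.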
        context = contextlib.suppress()
        if use_no_sync_context:
            context = model.no_sync()
        with context:  # accumulate gradients from the first batch
            output = model(*batch1)
            loss1 = model.module.get_loss(batch1, output)
            loss1.backward()
        output = model(*batch2)
        loss2 = model.module.get_loss(batch2, output)
        loss2.backward()
        accumulated_loss = loss1 + loss2
        accumulated_grads = [p.grad.detach().clone() for p in model.parameters()]

        torch.testing.assert_allclose(ref_loss, accumulated_loss)
        assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)


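# Every combination of these two FSDP flags is tested; each config is wrapped in a
# list as expected by parameterized.expand.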
keys = ["reshard_after_forward", "mixed_precision"]
COMM_CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]


class TestGradAccCommunication(DistributedTest):
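    """Count the collective calls (all_gather / reduce_scatter) issued while
    accumulating grads inside and outside no_sync()."""
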
    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
    def test_communication(self, config):
        fn = functools.partial(self._test_communication, config=config)
        spawn_and_init(fn)

    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
    def test_communication_nested(self, config):
        fn = functools.partial(self._test_communication, config=config, nested_model=True)
        spawn_and_init(fn)

    @classmethod
    def _test_communication(self, rank, group, config, nested_model=False):
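        # Collective call counts are only meaningful with more than one rank.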
        if group.size() == 1:
            return

        # Turn off bucketing to accurately count number of reduce_scatters.
        config["bucket_cap_mb"] = 0

        if nested_model:
            model = NestedWrappedModule(group, config)
            model = FullyShardedDataParallel(model, group, **config).cuda()
        else:
            model = self.get_wrapped_model(group, config=config)

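        # Count the FSDP instances that actually own parameters; each one issues
        # its own all-gather / reduce-scatter calls.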
        num_fsdp = 0
        for child in model.modules():  # includes self
            if isinstance(child, FullyShardedDataParallel) and len(child.params) > 0:
                num_fsdp += 1

        if config.get("reshard_after_forward", True):
            # inside no_sync:
            #   num_fsdp all-gathers in the forward
            #   num_fsdp-1 all-gathers in the backward (except root)
            # outside no_sync:
            #   num_fsdp-1 all-gathers in the forward (except root)
            #   num_fsdp-1 all-gathers in the backward (except root)
            expected_all_gather1 = 2 * num_fsdp - 1
            expected_all_gather2 = expected_all_gather1 + (2 * num_fsdp - 2)
        else:
            # inside no_sync:
            #   num_fsdp all-gathers in the forward
            # outside no_sync:
            #   none
            expected_all_gather1 = num_fsdp
            expected_all_gather2 = num_fsdp

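        # Gradients are reduce-scattered only in the final backward (outside
        # no_sync), once per FSDP instance that owns parameters.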
        expected_reduce_scatter = num_fsdp

        batch = model.module.get_input(torch.device("cuda"))

        # Depending on the PyTorch version, the _base collective methods may not exist.
        # If a method is missing, patch an unrelated attribute ("math.nan") so the
        # corresponding mock simply records zero invocations.
        method_string_reduce_scatter_base = "torch.distributed._reduce_scatter_base"
        if not hasattr(torch.distributed, "_reduce_scatter_base"):
            method_string_reduce_scatter_base = "math.nan"  # arbitrary attribute that will never be called

        method_string_all_gather_base = "torch.distributed._all_gather_base"
        if not hasattr(torch.distributed, "_all_gather_base"):
            method_string_all_gather_base = "math.nan"  # arbitrary attribute that will never be called

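        # Patch the collectives so we can count how many times FSDP invokes them
        # during the two forward/backward passes below.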
        with patch("torch.distributed.all_gather") as mock_all_gather:
            with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
                with patch(method_string_all_gather_base) as mock_all_gather_base:
                    with patch(method_string_reduce_scatter_base) as mock_reduce_scatter_base:
                        with model.no_sync():
                            output = model(*batch)
                            loss = model.module.get_loss(batch, output)
                            loss.backward()

                        # the _base methods are activated when they are available.
                        # the sum of the _base and public methods should stay the same.
                        assert (
                            mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather1
                        ), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather1}"
                        assert (
                            mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count == 0
                        ), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != 0"

                        output = model(*batch)
                        loss = model.module.get_loss(batch, output)
                        loss.backward()

                        assert (
                            mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather2
                        ), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather2}"
                        assert (
                            mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count
                            == expected_reduce_scatter
                        ), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != {expected_reduce_scatter}"


if __name__ == "__main__":
    unittest.main()