# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
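
"""Tests for FullyShardedDataParallel.no_sync(): gradients accumulate locally inside the
context manager and are only reduced across ranks once a pass runs outside of it."""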

import functools
import unittest
from unittest.mock import patch

import torch

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import DummyProcessGroup, objects_are_equal

from .test_fsdp import DistributedTest, NestedWrappedModule, spawn_and_init


class TestNoSync(DistributedTest):
    def test_transformer(self):
        fn = functools.partial(self._test_transformer, config={})
        spawn_and_init(fn)

    def test_transformer_no_flat_params(self):
        config = {"flatten_parameters": False}
        fn = functools.partial(self._test_transformer, config=config)
        spawn_and_init(fn)

    def test_nested_wrapper(self):
        fn = functools.partial(self._test_nested_wrapper, config={})
        spawn_and_init(fn)

    def test_no_sync_before_first_forward(self):
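        # Enter no_sync() before the model has run any forward pass (i.e. before FSDP's
        # lazy initialization); this should work, and a normal, synchronized
        # forward/backward afterwards should still succeed.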
        group = DummyProcessGroup(rank=0, size=1)
        model = self.get_wrapped_model(group, config={}, add_bn=False)
        batch = model.module.get_input(torch.device("cuda"))
        with model.no_sync():
            output = model(*batch)
            loss = model.module.get_loss(batch, output)
            loss.backward()
        output = model(*batch)
        loss = model.module.get_loss(batch, output)
        loss.backward()

    @classmethod
    def _test_transformer(cls, rank, group, config):
        model = cls.get_wrapped_model(group, config=config, add_bn=False)
        model.eval()  # turn off dropout for the test
        cls._test_no_sync(model, batch_dim=1)

    @classmethod
    def _test_nested_wrapper(cls, rank, group, config):
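        # NestedWrappedModule contains FSDP-wrapped submodules, so this exercises
        # no_sync() on an outer FSDP instance with nested FSDP instances inside it.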
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        cls._test_no_sync(model, batch_dim=0)

    @classmethod
    def _test_no_sync(cls, model, batch_dim):
        # Generate two input batches. We'll test that we get the same grads if
        # we train on them sequentially while accumulating grads (with no_sync)
        # vs. concatenating the batches and training in one go.
        batch1 = model.module.get_input(torch.device("cuda"))
        assert isinstance(batch1, tuple)
        batch2 = tuple(
            # This randomly permutes the values in a multi-dim tensor.
            x.view(-1)[torch.randperm(x.numel())].view_as(x)
            for x in batch1
        )
        for x, y in zip(batch1, batch2):
            assert not torch.all(x == y)

        # Concat the batches along batch dimension.
        concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))

        # Establish reference behavior on the concat batch.
        model.zero_grad()
        output = model(*concat_batch)
        ref_loss = model.module.get_loss(concat_batch, output)
        ref_loss.backward()
        ref_grads = [p.grad.detach().clone() for p in model.parameters()]

        # Test that we get the same results by accumulating grads.
        model.zero_grad()
        with model.no_sync():  # accumulate gradients from the first batch
            output = model(*batch1)
            loss1 = model.module.get_loss(batch1, output)
            loss1.backward()
        output = model(*batch2)
        loss2 = model.module.get_loss(batch2, output)
        loss2.backward()
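        # Comparing ref_loss against loss1 + loss2 below relies on get_loss reducing by
        # summation over the batch, so per-batch losses add up to the concat-batch loss.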
        accumulated_loss = loss1 + loss2
        accumulated_grads = [p.grad.detach().clone() for p in model.parameters()]

        torch.testing.assert_allclose(ref_loss, accumulated_loss)
        assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)


class TestNoSyncCommunication(DistributedTest):
    def test_communication(self):
        config = {"mixed_precision": True}
        fn = functools.partial(self._test_communication, config=config)
        spawn_and_init(fn)

    @classmethod
    def _test_communication(cls, rank, group, config):
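        # With a single rank there is no cross-rank communication to count, so skip.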
        if group.size() == 1:
            return

        model = cls.get_wrapped_model(group, config=config)

        batch = model.module.get_input(torch.device("cuda"))

        with patch("torch.distributed.all_gather") as mock_all_gather:
            with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
                with model.no_sync():
                    output = model(*batch)
                    loss = model.module.get_loss(batch, output)
                    loss.backward()

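                # Inside no_sync(): the forward still needs one all_gather to
                # materialize the full parameters, but the backward must not
                # reduce-scatter the locally accumulated gradients.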
                assert mock_all_gather.call_count == 1
                assert mock_reduce_scatter.call_count == 0

                output = model(*batch)
                loss = model.module.get_loss(batch, output)
                loss.backward()

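                # The first backward outside no_sync() reduce-scatters exactly once.
                # The all_gather count is expected to stay at 1, presumably because the
                # full parameters remained materialized throughout the no_sync() pass.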
                assert mock_all_gather.call_count == 1
                assert mock_reduce_scatter.call_count == 1


if __name__ == "__main__":
    unittest.main()