# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
from time import time
import unittest

from parameterized import parameterized
import torch
from torch.optim import SGD, Adadelta, Adam  # type: ignore

from fairscale.nn import FullyShardedDataParallel
from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
from fairscale.utils.params import recursive_copy_to_device
from fairscale.utils.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes

from .test_fsdp import (
    DistributedTest,
    DummyProcessGroup,
    MixtureOfExperts,
    TransformerWithSharedParams,
    rename_test,
    spawn_and_init,
)


def all_tensors_numel_except_for_step(dct):
    """Sum the numel of every tensor value in `dct`, skipping the `step` key."""
    ret = 0
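    # e.g. for {"step": torch.tensor(10), "exp_avg": torch.zeros(5)} this returns 5 ("step" is skipped)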
    for k, v in dct.items():
        if k != "step" and torch.is_tensor(v):
            ret += v.numel()
    return ret


def assert_equal(a, b):
    assert a == b, f"{a} != {b}"


def spawn_and_init_multiple_groups(fn, args=None, **spawn_kwargs):
    """Like spawn_and_init, but `fn` additionally receives a replicated MoE expert process group."""
    if args is None:
        args = ()

    run_fn = functools.partial(init_and_run, fn, args)
    spawn_for_all_world_sizes(run_fn, **spawn_kwargs)


def _find_my_group_index(grouped_ranks):
    """Return the index corresponding to the MoE group of the current process."""
    my_rank = torch.distributed.get_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError(f"Unable to find process rank {my_rank} in the set of grouped ranks {grouped_ranks}.")


def get_moe_group(moe_expert_count=2):
    """Return a process group for initializing a MoE layer."""
    if torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()

        # If there are at least as many experts as ranks, give each rank its own group.
        if world_size <= moe_expert_count:
            assert moe_expert_count % world_size == 0
            moe_groups = [[i] for i in range(world_size)]
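            # e.g. world_size=2, moe_expert_count=2 -> [[0], [1]]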

        # If there are more ranks than experts, replicate each expert group across multiple ranks.
        else:
            assert world_size % moe_expert_count == 0
            ranks_per_group = world_size // moe_expert_count
            moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)] for i in range(moe_expert_count)]
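            # e.g. world_size=4, moe_expert_count=2 -> [[0, 2], [1, 3]]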

        moe_pgs = [torch.distributed.new_group(g) for g in moe_groups]

        # Find the index in the set of moe_groups which contains the current rank.
        my_group_idx = _find_my_group_index(moe_groups)
        return moe_pgs[my_group_idx]
    else:
        return torch.distributed.new_group([torch.distributed.get_rank()])


def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
    """Initialize the process group and run `fn` with a replicated MoE expert group."""
    dist_init(rank, world_size, filename, filename_rpc)
    torch.cuda.set_device(rank)
    group = torch.distributed.new_group()
    # Specify the moe_group used to initialize the MoE layers with.
    fn(rank, group, *args, expert_group=get_moe_group())


class TestOptimizerUtils(DistributedTest):
    @parameterized.expand(
        [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
        name_func=rename_test,
    )
    def test_consolidate_optimizer(self, optim_fn, transformer):
        config = {"mixed_precision": True, "flatten_parameters": True}
        config["compute_dtype"] = torch.float32
        test_fn = functools.partial(
            self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
        )

        spawn_and_init(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])

    @parameterized.expand(
        [[SGD, False], [Adam, False]],
        name_func=rename_test,
    )
    def test_consolidate_optimizer_diff_world_size(self, optim_fn, transformer):
        if torch.cuda.device_count() < 4:
            raise unittest.SkipTest("This test requires at least 4 GPUs.")
        config = {"mixed_precision": True, "flatten_parameters": True}
        config["compute_dtype"] = torch.float32
        test_fn = functools.partial(
            self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
        )

        spawn_and_init_multiple_groups(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])

    @classmethod
    def _test_consolidated_optimizer(
        self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False, expert_group=None
    ):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.
        if transformer:
            unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
            fsdp = self.get_wrapped_model(group, config=config).cuda()
        else:
            unwrapped_model = MixtureOfExperts(group, wrapper_config=None, expert_group=expert_group).cuda()
            fsdp = FullyShardedDataParallel(
                MixtureOfExperts(group, wrapper_config=config, expert_group=expert_group)
            ).cuda()

        try:
            fsdp_optim = optim_fn(
                fsdp.parameters(),
                lr=0.01,
            )
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()
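
        # Run one forward/backward/optimizer step on both models so that both optimizers have populated state.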
        with torch.cuda.amp.autocast(enabled=True):
            x = fsdp.module.get_input(torch.device("cuda"))
            output = fsdp(*x)
            loss = fsdp.module.get_loss(x, output).to("cuda")
            fsdp.module.run_backward(loss)
            fsdp_optim.step()

            output = unwrapped_model(*x)
            loss = unwrapped_model.get_loss(x, output)
            unwrapped_model.run_backward(loss)
            optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        if not transformer and not expert_group:
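            # Exactly one FSDP child, the last one (the expert), should have no_broadcast_optim_state set.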
            no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
            assert len(no_broadcast_children) == 1, f"Expected 1 non-broadcast child, got {len(no_broadcast_children)}"
            assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
        torch.cuda.empty_cache()
        cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
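
        # Consolidating the full optimizer state onto rank 0 should be fast and should not leak CUDA memory.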
        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
        mem_usg_gb = cuda_gb_after - cuda_gb_before
        assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
        assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

        if fsdp.rank > 0:
            assert sd is None
            return

        # Assert that the whole state dict has been moved to CPU.
        for k, v in sd["state"].items():
            for buffer_name, t in v.items():
                if torch.is_tensor(t):
                    msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                    assert t.device == torch.device("cpu"), msg

        if expert_group:
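            # With an MoE expert group, compare only aggregate sizes: the number of state entries
            # and the total number of elements should match the unwrapped baseline.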
            sd_state = recursive_copy_to_device(sd["state"], non_blocking=False, device="cpu")
            orig_state = recursive_copy_to_device(unwrapped_sd["state"], non_blocking=False, device="cpu")

            assert_equal(len(sd_state.keys()), len(orig_state.keys()))

            assert_equal(
                sum([all_tensors_numel_except_for_step(v) for k, v in sd_state.items()]),
                sum([all_tensors_numel_except_for_step(v) for k, v in orig_state.items()]),
            )
            return

        unflat_state = sd["state"]
        assert "uncollected_local_ids" in sd
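        # Extracting this rank's shard must not mutate the consolidated state dict.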
        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
        shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
        state_after_get_shard = sd["state"]
        assert objects_are_equal(unflat_state, state_after_get_shard)  # no side effects.

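        # The consolidated state dict should cover the same parameters, and the same total number
        # of elements, as the baseline state dict from the unwrapped model.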
        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([all_tensors_numel_except_for_step(v) for k, v in sd["state"].items()]),
            sum([all_tensors_numel_except_for_step(v) for k, v in unwrapped_sd["state"].items()]),
        )

        original_shard_sd = fsdp_optim.state_dict()
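        # Round-trip check: the shard extracted from the consolidated state dict should match
        # this rank's original optimizer state dict.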
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")
        # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
        assert_equal(
            [all_tensors_numel_except_for_step(v) for k, v in shard_sd["state"].items()],
            [all_tensors_numel_except_for_step(v) for k, v in original_shard_sd["state"].items()],
        )
        assert_equal(
            [v for k, v in shard_sd["param_groups"][0].items()],
            [v for k, v in original_shard_sd["param_groups"][0].items()],
        )
        assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
        assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)

    def test_named_params_ordering(self):
        """parameters() and named_parameters() must agree in order; optim state consolidation relies on this."""
        group = DummyProcessGroup(0, 1)
        model = TransformerWithSharedParams(group)
        named_pars = [p for n, p in model.named_parameters()]
        for i, p in enumerate(model.parameters()):
            assert objects_are_equal(p, named_pars[i])

    def test_is_singleton_tensor(self):
        assert is_singleton_tensor(torch.tensor(4.0))
        assert not is_singleton_tensor(torch.tensor([4.0]))
        assert not is_singleton_tensor(torch.tensor([4.0, 5.0]))
        assert not is_singleton_tensor([4.0])
        assert not is_singleton_tensor(4.0)