# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
from time import time
import unittest

from parameterized import parameterized
import torch
from torch import nn
from torch.optim import SGD, Adadelta, Adam  # type: ignore

from fairscale.nn import FullyShardedDataParallel
from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
from fairscale.utils.params import recursive_copy_to_device
from fairscale.utils.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes

from .test_fsdp import (
    DistributedTest,
    DummyProcessGroup,
    MixtureOfExperts,
    TransformerWithSharedParams,
    rename_test,
    spawn_and_init,
)


def all_tensors_numel_except_for_step(dct):
    """Compute the sum of numel from all tensors from a dict, except when the key is `step`."""
    ret = 0
    for k, v in dct.items():
        if k != "step" and torch.is_tensor(v):
            ret += v.numel()
    return ret


def assert_equal(a, b):
    assert a == b, f"{a} != {b}"


def spawn_and_init_multiple_groups(fn, args=None, **spawn_kwargs):
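    """Like spawn_and_init, but give each spawned process an extra MoE expert process group (see init_and_run)."""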
    if args is None:
        args = ()

    run_fn = functools.partial(init_and_run, fn, args)
    spawn_for_all_world_sizes(run_fn, **spawn_kwargs)


def _find_my_group_index(grouped_ranks):
    """Return the index corresponding to the MoE group of the current process."""
    my_rank = torch.distributed.get_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError(f"Unable to find process rank {my_rank} in the set of grouped ranks {grouped_ranks}.")


def get_moe_group(moe_expert_count=2):
    """Return a process group for initializing a MoE layer."""
    if torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()

        # At least as many experts as ranks: each rank forms its own group.
        if world_size <= moe_expert_count:
            assert moe_expert_count % world_size == 0
            moe_groups = [[i] for i in range(world_size)]
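            # e.g. world_size=2, moe_expert_count=4 -> moe_groups == [[0], [1]]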

        # More ranks than experts: each expert's group spans world_size // moe_expert_count ranks.
        else:
            assert world_size % moe_expert_count == 0
            ranks_per_group = world_size // moe_expert_count
            moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)] for i in range(moe_expert_count)]
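            # e.g. world_size=4, moe_expert_count=2 -> moe_groups == [[0, 2], [1, 3]]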

        moe_pgs = [torch.distributed.new_group(g) for g in moe_groups]

        # Find the index in the set of moe_groups which contains the current rank.
        my_group_idx = _find_my_group_index(moe_groups)
        return moe_pgs[my_group_idx]
    else:
        return torch.distributed.new_group([torch.distributed.get_rank()])


def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
    """Initialize and run the unit test for testing replicated MoE groups."""
    dist_init(rank, world_size, filename, filename_rpc)
    torch.cuda.set_device(rank)
    group = torch.distributed.new_group()
    # Specify the moe_group used to initialize the MoE layers with.
    fn(rank, group, *args, expert_group=get_moe_group())


class TestOptimizerUtils(DistributedTest):
    @parameterized.expand(
        [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
        name_func=rename_test,
    )
    def test_consolidate_optimizer(self, optim_fn, transformer):
        config = {"mixed_precision": True, "flatten_parameters": True}
        config["compute_dtype"] = torch.float32
        test_fn = functools.partial(
            self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
        )

        spawn_and_init(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])

    @parameterized.expand(
        [[SGD, False], [Adam, False]],
        name_func=rename_test,
    )
    def test_consolidate_optimizer_diff_world_size(self, optim_fn, transformer):
        if torch.cuda.device_count() < 4:
            raise unittest.SkipTest("This test requires at least 4 GPUs.")
        config = {"mixed_precision": True, "flatten_parameters": True}
        config["compute_dtype"] = torch.float32
        test_fn = functools.partial(
            self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
        )

        spawn_and_init_multiple_groups(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])

    @classmethod
    def _test_consolidated_optimizer(
        self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False, expert_group=None
    ):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.
        if transformer:
            unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
            fsdp = self.get_wrapped_model(group, config=config).cuda()
        else:
            unwrapped_model = MixtureOfExperts(group, wrapper_config=None, expert_group=expert_group).cuda()
            fsdp = FullyShardedDataParallel(
                MixtureOfExperts(group, wrapper_config=config, expert_group=expert_group)
            ).cuda()

        try:
            fsdp_optim = optim_fn(
                fsdp.parameters(),
                lr=0.01,
            )
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()
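        # Run one forward/backward/step on both the wrapped and unwrapped
        # models so both optimizers accumulate state that can be compared.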
        with torch.cuda.amp.autocast(enabled=True):
            x = fsdp.module.get_input(torch.device("cuda"))
            output = fsdp(*x)
            loss = fsdp.module.get_loss(x, output).to("cuda")
            fsdp.module.run_backward(loss)
            fsdp_optim.step()

            output = unwrapped_model(*x)
            loss = unwrapped_model.get_loss(x, output)
            unwrapped_model.run_backward(loss)
            optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        if not transformer and not expert_group:
            no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
            assert len(no_broadcast_children) == 1, f"expected 1 non-broadcast child, got {len(no_broadcast_children)}"
            assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
        torch.cuda.empty_cache()
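        # memory_stats() reports allocated bytes; divide by 1024 ** 3 to track usage in GiB.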
        cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
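        # gather_full_optim_state_dict() consolidates the sharded optimizer state on
        # recipient_rank (0 here); every other rank receives None.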
        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
        mem_usg_gb = cuda_gb_after - cuda_gb_before
        assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
        assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

        if fsdp.rank > 0:
            assert sd is None
            return

        # Assert that every tensor in the consolidated state dict lives on CPU.
        for k, v in sd["state"].items():
            for buffer_name, t in v.items():
                if torch.is_tensor(t):
                    msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                    assert t.device == torch.device("cpu"), msg

        if expert_group:
            sd_state = recursive_copy_to_device(sd["state"], non_blocking=False, device="cpu")
            orig_state = recursive_copy_to_device(unwrapped_sd["state"], non_blocking=False, device="cpu")

            assert_equal(len(sd_state.keys()), len(orig_state.keys()))

            assert_equal(
                sum([all_tensors_numel_except_for_step(v) for k, v in sd_state.items()]),
                sum([all_tensors_numel_except_for_step(v) for k, v in orig_state.items()]),
            )
            return

        unflat_state = sd["state"]
        assert "uncollected_local_ids" in sd
        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
        shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
        state_after_get_shard = sd["state"]
        assert objects_are_equal(unflat_state, state_after_get_shard)  # no side effects.

        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([all_tensors_numel_except_for_step(v) for k, v in sd["state"].items()]),
            sum([all_tensors_numel_except_for_step(v) for k, v in unwrapped_sd["state"].items()]),
        )

        original_shard_sd = fsdp_optim.state_dict()
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")
        # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
        assert_equal(
            [all_tensors_numel_except_for_step(v) for k, v in shard_sd["state"].items()],
            [all_tensors_numel_except_for_step(v) for k, v in original_shard_sd["state"].items()],
        )
        assert_equal(
            [v for k, v in shard_sd["param_groups"][0].items()],
            [v for k, v in original_shard_sd["param_groups"][0].items()],
        )
        assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
        assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)

    @parameterized.expand(
        [(True,), (False,)],
        name_func=rename_test,
    )
    def test_model_with_unused_params(self, wrap_l2):
        """Test handling of model with unused params by gather_full_optim_state_dict()"""
        test_fn = functools.partial(self._test_model_with_unused_params, wrap_l2=wrap_l2)
        spawn_and_init(test_fn, world_sizes=[2])

    @classmethod
    def _test_model_with_unused_params(self, rank, pg, wrap_l2):
        model = ModelWithUnusedParams(wrap_l2).cuda()
        data = torch.rand(4).cuda().requires_grad_(True)
        model = FullyShardedDataParallel(model)
        optim = SGD(model.parameters(), momentum=0.9, lr=0.1)
        out = model(data).sum()
        out.backward()
        optim.step()
        model.zero_grad(set_to_none=True)
        sd = model.gather_full_optim_state_dict(optim)
        if rank == 0:
            shard_sd = model.get_shard_from_optim_state_dict(sd)
            orig_sd = optim.state_dict()
            orig_sd = recursive_copy_to_device(orig_sd, non_blocking=False, device="cpu")
            objects_are_equal(shard_sd, orig_sd, raise_exception=True)
        else:
            assert sd is None, sd

    def test_named_params_ordering(self):
        """Test assumption of consolidate_optimizer_state_dict"""
        group = DummyProcessGroup(0, 1)
        model = TransformerWithSharedParams(group)
        named_pars = [p for n, p in model.named_parameters()]
        for i, p in enumerate(model.parameters()):
            assert objects_are_equal(p, named_pars[i])

    def test_is_singleton_tensor(self):
        """Test is_singleton_tensor function"""
        assert is_singleton_tensor(torch.tensor(4.0))
        assert not is_singleton_tensor(torch.tensor([4.0]))
        assert not is_singleton_tensor(torch.tensor([4.0, 5.0]))
        assert not is_singleton_tensor([4.0])
        assert not is_singleton_tensor(4.0)


class ModelWithUnusedParams(nn.Module):
    def __init__(self, wrap_l2):
        super().__init__()
        self.l = nn.Linear(4, 4)
        # The unused param must be wrapped on its own; otherwise parameter
        # flattening folds it into the flat param and it always counts as used.
        self.not_trained = nn.Linear(4, 4).requires_grad_(False)
        self.not_trained = FullyShardedDataParallel(self.not_trained)
        # Optionally test a used param that comes after the unused one by
        # wrapping it.
        self.l2 = nn.Linear(4, 4)
        if wrap_l2:
            # When wrapping happens, the unused param will be in the middle
            # of the param list (for optimizer state dict), not at the
            # end. This way, we can test the handling code in more corner
            # cases.
            self.l2 = FullyShardedDataParallel(self.l2)

    def forward(self, x):
        with torch.no_grad():
            y = self.not_trained(x)
        return self.l2(self.l(x)) - y