# test_fsdp_offload.py
1
2
3
4
5
6
7
8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import itertools
import sys
9
import tempfile
10
11
12
13
14
15
16
17
18
19
import time
import unittest

from parameterized import parameterized
import pytest
import torch
from torch import nn
import torch.distributed

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
20
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
21
from fairscale.utils import torch_version
22
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
23
24

# Note: SSD offload needs a PyTorch nightly build, so every test in this module
# is skipped unless the installed version is at least the next PyTorch release.
print(f"torch version {torch_version()}")
pytestmark = pytest.mark.skipif(
    torch_version() < (1, 11, 0),
    reason="requires torch version >= 1.11.0",
)


# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either a @classmethod or a @staticmethod.


class DistributedTest(unittest.TestCase):
    """Base class for the multi-GPU FSDP offload tests.

    ``setUp`` skips a test unless CUDA is available, the platform is not
    Windows and at least two GPUs are visible. The helpers are static or
    class methods so that they can be passed to ``spawn_and_init``.
    """

    def setUp(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")

    @staticmethod
    def _eval_with_config(model, autocast):
        """Run a single eval-mode forward pass and return the detached loss.

        Args:
            model: wrapper exposing ``module.get_input`` and ``module.get_loss``.
            autocast: whether to run the forward pass under CUDA autocast.

        Returns:
            The loss tensor, moved to the CUDA device and detached.
        """
        model.eval()
        model_device = torch.device("cuda")
        with torch.cuda.amp.autocast(enabled=autocast):
            # Inputs always cuda regardless of move_grads_cpu, or model.device
            inputs = model.module.get_input(torch.device("cuda"))
            output = model(*inputs)
            loss = model.module.get_loss(inputs, output).to(model_device)
        assert loss.dtype == torch.float32
        if isinstance(model, FullyShardedDataParallel):
            model.assert_state(TrainingState.IDLE)
        return loss.detach()

    @staticmethod
    def _eval_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None):
        """Run ``num_steps`` eval-mode forward passes on one fixed input.

        The outputs are discarded. ``lr`` and ``norm_type`` are accepted but
        not used.
        """
        model.eval()
        # Inputs always cuda regardless of move_grads_cpu, or model.device
        inputs = model.module.get_input(torch.device("cuda"))

        for _ in range(num_steps):
            with torch.cuda.amp.autocast(enabled=autocast):
                model(*inputs)

    @classmethod
    def _test_identical_outputs_eval(
        cls,
        model_init_fn,
        config,
        rank,
        group,
        num_steps=2,
        use_cuda=True,
        lr=0.01,
        ref_ddp_fn=None,
    ):
        """Check that an FSDP-wrapped model gives the same eval loss as the reference wrapper.

        Args:
            model_init_fn: callable invoked as
                ``model_init_fn(group=..., wrapper_config=...)`` to build the model.
            config: keyword arguments for ``FullyShardedDataParallel``; may also
                carry the ``"ssd_offload"`` switch, which is consumed here.
            rank: device index used for the reference DDP wrapper.
            group: process group handed to both wrappers.
            num_steps: accepted but not used.
            use_cuda: when the model is neither SSD-offloaded nor kept on CPU,
                move it to CUDA if true, otherwise assert it lives on the CPU.
            lr: accepted but not used.
            ref_ddp_fn: optional ``(model, group) -> wrapped_model`` factory that
                replaces the default ``DistributedDataParallel`` reference.

        Raises:
            Exception: if the two losses are not close; the original comparison
                error is chained as the cause.
        """
        # Work on a copy so the caller's dict is not modified by the edits below.
        config = dict(config)

        if config.get("mixed_precision", False):
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # identical results as PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, process_group=group
            )
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._eval_with_config(model, autocast)

        # Confirm we get the same behavior using FullyShardedDataParallel.
        if config.get("ssd_offload", False):
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
            # ssd offload only supports flatten_params ATM
            config["flatten_parameters"] = True

        # Drop the switch before the dict is expanded into keyword arguments;
        # tolerate configs that never contained it.
        config.pop("ssd_offload", None)
        model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
        if not model.ssd_offload and not model.move_params_to_cpu:
            if use_cuda:
                model = model.cuda()
            else:
                assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._eval_with_config(model, autocast)

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
        except (AssertionError, RuntimeError) as e:
            raise Exception(
                f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}"
            ) from e
        if config.get("flatten_parameters", True):
            metadata = model.local_metadata_dict()
            assert isinstance(metadata, dict)


126
keys = ["reshard_after_forward", "mixed_precision", "nested_wrapping"]
# One single-element argument list per True/False assignment of ``keys``.
CONFIG_OPTIONS = [
    [dict(zip(keys, flags))]
    for flags in itertools.product((True, False), repeat=len(keys))
]


def rename_test(testcase_func, param_num, param):
    """Build the name of one parameterized test case.

    The name is the test function's ``__name__`` joined by an underscore to
    ``str(param.args)`` passed through ``parameterized.to_safe_name``.
    ``param_num`` is accepted but not used.
    """
    base_name = testcase_func.__name__
    suffix = parameterized.to_safe_name(str(param.args))
    return "%s_%s" % (base_name, suffix)
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155


class TestSsdMemory(DistributedTest):
    """Prints timing checkpoints while building and evaluating an SSD-offloaded model."""

    def test_memory_benchmark(self):
        test_fn = functools.partial(self._test_memory_benchmark, config={})
        spawn_and_init(test_fn)

    @classmethod
    def _test_memory_benchmark(cls, rank, group, config):
        """Wrap a small ``SimpleLinear`` in FSDP with SSD offload and run one eval step.

        A ``TimeKeeper`` prints the elapsed time after each stage. ``rank`` is
        accepted but not used; ``config`` receives the ``"offload_config"``
        entry and is then expanded into ``FullyShardedDataParallel``.
        """
        time_keeper = TimeKeeper()

        SIZE = 8 * 8
        time_keeper.print_time("START", 1.0)
        # Move a throwaway tensor to the GPU (presumably to trigger CUDA
        # initialization before the timed stages).
        torch.empty(1).cuda()
        # wait for cuda to fully load
        time.sleep(1)
        time_keeper.print_time("INIT_CUDA", 1.0)
        model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
        time_keeper.print_time("CPU_MODEL", 1.0)

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)

            model = FullyShardedDataParallel(model, **config)
            time_keeper.print_time("FSDP_MODEL", 1.0)

            cls._eval_for_several_steps(model, 1, autocast=False)
            time_keeper.print_time("EVAL")
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225


class SimpleLinear(nn.Module):
    """A stack of bias-free ``nn.Linear`` layers used as a small test model.

    Args:
        group: object providing ``rank()`` and ``size()``.
        input_size: input feature count of every layer.
        output_size: output feature count of every layer.
        layers: number of linear layers to stack.
        **unused_kwargs: accepted and ignored.
    """

    def __init__(self, group, input_size, output_size, layers=1, **unused_kwargs):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        self.input_size = input_size
        self.output_size = output_size
        torch.manual_seed(0)  # keep everything deterministic
        self.module = nn.Sequential(
            *(nn.Linear(input_size, output_size, bias=False) for _ in range(layers))
        )
        self.bs = 2

    def get_input(self, device):
        """Return a ``(src, tgt)`` pair of random ``(bs, input_size)`` float32 tensors."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.rand((self.bs, self.input_size), device=device, dtype=torch.float32)
        tgt = torch.rand((self.bs, self.input_size), device=device, dtype=torch.float32)
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        """Apply the layer stack to ``src_ids``; ``tgt_ids`` is not used."""
        return self.module(src_ids)

    def get_loss(self, input, output):
        """Return binary cross-entropy-with-logits between ``output`` and the target in ``input``."""
        _, tgt = input

        return nn.functional.binary_cross_entropy_with_logits(output, tgt)

    def run_backward(self, loss):
        loss.backward()


KEYS = ["ssd_offload", "flatten_parameters", "mixed_precision", "move_params_to_cpu"]
# One single-element argument list per True/False assignment of KEYS.
CONFIG = [
    [dict(zip(KEYS, flags))]
    for flags in itertools.product((True, False), repeat=len(KEYS))
]


class TimeKeeper:
    """Prints how much wall-clock time has passed since the object was created."""

    def __init__(self):
        # Reference point for every later print_time() call.
        self.start_time = time.time()

    def print_time(self, s: str, wait_time: float = 1.0):
        """Print the elapsed seconds followed by label ``s``, then sleep ``wait_time`` seconds."""
        elapsed = time.time() - self.start_time
        print(f"@time: {elapsed:0.2f} {s}")
        time.sleep(wait_time)


class TestModuleProperties(DistributedTest):
    """Checks the parameter names reported by FSDP-wrapped modules across ``CONFIG``."""

    @parameterized.expand(CONFIG, name_func=rename_test)
    def test_named_parameters(self, config):
        test_fn = functools.partial(self._test_named_params, config=config)
        spawn_and_init(test_fn)

    @classmethod
    def _test_named_params(cls, rank, group, config):
        """Compare ``named_parameters()`` before and after FSDP wrapping.

        ``config`` must contain ``"ssd_offload"`` and ``"mixed_precision"``.
        The ``"ssd_offload"`` entry is consumed here and the remaining entries
        are expanded into ``FullyShardedDataParallel``. ``rank`` is accepted
        but not used.
        """
        # Get the named parameters before wrapping.
        before_wrap_model = TransformerWithSharedParams(group)
        before_wrap_params = before_wrap_model.named_parameters()

        with tempfile.TemporaryDirectory() as current_tempdir:
            if config["ssd_offload"]:
                config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
                # ssd offload only supports flatten_params ATM
                config["flatten_parameters"] = True
            del config["ssd_offload"]

            model = FullyShardedDataParallel(before_wrap_model, **config)
            print(f"model.ssd_offload {model.ssd_offload}")
            if not model.ssd_offload and not model.move_params_to_cpu:
                model = model.cuda()

            cls._eval_with_config(model, autocast=config["mixed_precision"])

            # Get the named parameters after wrapping to compare.
            after_wrap_params = model.named_parameters()

            if not config.get("flatten_parameters", False):
                for before_nm, after_nm in zip(before_wrap_params, after_wrap_params):
                    assert before_nm[0] == after_nm[0]
            else:
                named_params_flat = list(after_wrap_params)[0][0]
                assert "flat_param_0" in named_params_flat

            after_wrap_params = model.named_parameters()

            # NOTE(review): ``before_wrap_params`` is the iterator created before
            # wrapping. It has already been consumed by the loop above when
            # flatten_parameters is off, so this loop may compare fewer pairs
            # than intended -- confirm the intent before materializing it.
            for before_nm, after_nm_original in zip(before_wrap_params, after_wrap_params):
                assert before_nm[0] == after_nm_original[0]
                torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].shape)
255
256
257
258
259
260
261
262


class TestSsdLoading(DistributedTest):
    """Eval and training tests for FSDP models that use SSD offload."""

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_ssd_offloading_eval(self, config):
        test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
        spawn_and_init(test_fn)

    @parameterized.expand(CONFIG, name_func=rename_test)
    def test_transformer_parameterized(self, config):
        spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_ssd_offloading_train_flatten_params_wrapper(self, config):
        test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config)
        spawn_and_init(test_fn)

    @classmethod
    def _test_ssd_offloading_train_flatten_params_wrapper(cls, rank, group, config):
        """Run ten SGD steps on an SSD-offloaded FSDP model and check it ends up idle.

        ``config`` must contain ``"nested_wrapping"``. When it is true the model
        is an FSDP-wrapped ``NestedWrappedModule`` whose inner layers receive
        ``config``; otherwise it is a ``SimpleLinear`` wrapped with ``config``.
        ``rank`` is accepted but not used.
        """
        SIZE = 16 * 16

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
            config["flatten_parameters"] = True

            # Consumed here so it is not expanded into the FSDP keyword arguments.
            nested_wrapping = config.pop("nested_wrapping")

            if nested_wrapping:
                model = FullyShardedDataParallel(
                    NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
                )
            else:
                base_model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
                model = FullyShardedDataParallel(base_model, **config)
            model_device = torch.device("cuda")
            model.train()
            optim = torch.optim.SGD(model.parameters(), lr=4, momentum=0.9)
            # Inputs always cuda regardless of move_grads_cpu, or model.device
            with torch.cuda.amp.autocast(enabled=config.get("mixed_precision", False)):
                for _ in range(10):
                    optim.zero_grad()
                    inputs = model.module.get_input(torch.device("cuda"))
                    output = model(*inputs)
                    loss = model.module.get_loss(inputs, output).to(model_device)
                    assert loss.dtype == torch.float32

                    model.module.run_backward(loss)
                    optim.step()

            if isinstance(model, FullyShardedDataParallel):
                model.assert_state(TrainingState.IDLE)

    @classmethod
    def _test_ssd_offload_eval(cls, rank, group, config):
        """Evaluate an SSD-offloaded FSDP model before and after reloading its local state dict.

        ``config`` must contain ``"nested_wrapping"`` and ``"mixed_precision"``.
        ``rank`` is accepted but not used.
        """
        # Consumed here so it is not expanded into the FSDP keyword arguments.
        nested_wrapping = config.pop("nested_wrapping")
        config["flatten_parameters"] = True

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
            if nested_wrapping:
                model = FullyShardedDataParallel(
                    NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
                )
            else:
                model = FullyShardedDataParallel(TransformerWithSharedParams(group), **config)

            cls._eval_with_config(model, autocast=config["mixed_precision"])

            # With SSD offload only local_state_dict will work. We can support global
            # state dict if we think it is necessary.
            state_dict = model.local_state_dict()
            model.load_local_state_dict(state_dict)

            cls._eval_with_config(model, config["mixed_precision"])
334
335
336
337
338
339
340
341
342
343
344


class TransformerWithSharedParams(nn.Module):
    """Small ``nn.Transformer`` test model whose embedding and output projection share one weight.

    Args:
        group: object providing ``rank()`` and ``size()``.
        d_vocab: vocabulary size; must be at least 12.
        d_model: embedding / model width.
        add_bn: apply ``BatchNorm1d`` to the target embeddings when true,
            otherwise an ``Identity``.
        *unused_args, **unused_kwargs: accepted and ignored.
    """

    def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        torch.manual_seed(0)  # keep everything deterministic
        assert d_vocab >= 12  # we use torch.arange(12) as input
        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,)))
        self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()

    def get_input(self, device):
        """Return fixed ``(src, tgt)`` token-id tensors of shape ``(6, bs)`` and ``(4, bs)``."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        """Embed both sequences, run the transformer and project back to the vocabulary."""
        src = self.embed_tokens(src_ids)
        src = src + self.vocab_bias + self.long_buffer.type_as(src)
        tgt = self.embed_tokens(tgt_ids)
        tgt = self.bn(tgt)
        x = self.transformer(src, tgt)
        return self.output_proj(x)

    def get_loss(self, input, output):
        """Return the summed cross-entropy of ``output`` against the target ids in ``input``."""
        _, tgt = input
        return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum")

    def run_backward(self, loss):
        loss.backward()


class NestedWrappedModule(nn.Module):
    """Stack of linear layers in which selected layers are individually FSDP-wrapped.

    Args:
        group: object providing ``rank()`` and ``size()``; also passed to each
            inner ``FullyShardedDataParallel``.
        wrapper_config: keyword arguments for the inner FSDP wrappers, or
            ``None`` to leave every layer unwrapped.
        wrap_everything: replace the default stack with one in which every
            layer goes through the wrapper.
        checkpoint: with ``wrap_everything``, additionally pass each layer
            through ``checkpoint_wrapper`` before wrapping.
    """

    def __init__(self, group, wrapper_config, wrap_everything=False, checkpoint=False):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        self.wrapper_config = wrapper_config

        def _maybe_wrap(layer):
            if wrapper_config is not None:
                return FullyShardedDataParallel(layer, group, **wrapper_config)
            return layer

        torch.manual_seed(0)  # keep everything deterministic
        self.module = nn.Sequential(
            nn.Linear(8, 4),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(nn.Linear(4, 16)),
                    nn.Linear(16, 16),
                )
            ),
            _maybe_wrap(nn.Linear(16, 4)),
            nn.Linear(4, 8),
        )

        # Wrap all modules triggers a corner case where root FSDP doesn't have any params.
        # Test it with checkpoint_wrapper as well to validate final backward callback
        # is queued correctly when root FSDP does not have any params and every layer is
        # wrapped as FSDP(checkpoint(module)).
        # Note: the default stack above is still built first and then replaced here.
        if wrap_everything:
            if checkpoint:
                self.module = nn.Sequential(
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(8, 4))),
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 16))),
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(16, 4))),
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 8))),
                )
            else:
                self.module = nn.Sequential(
                    _maybe_wrap(nn.Linear(8, 4)),
                    _maybe_wrap(nn.Linear(4, 16)),
                    _maybe_wrap(nn.Linear(16, 4)),
                    _maybe_wrap(nn.Linear(4, 8)),
                )

    def get_input(self, device):
        """Return a one-element tuple holding a random ``(4, 8)`` tensor."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        return (torch.rand(4, 8, device=device),)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        """Return the sum of ``output``; ``input`` is not used."""
        loss = output.sum()
        return loss

    def run_backward(self, loss):
        loss.backward()


def spawn_and_init(fn, args=None, **spawn_kwargs):
    """Run ``fn`` through ``init_and_run`` via ``spawn_for_all_world_sizes``.

    ``args`` (default: no extra arguments) is forwarded to ``init_and_run``
    together with ``fn``; ``spawn_kwargs`` go to ``spawn_for_all_world_sizes``.
    """
    extra_args = () if args is None else args
    spawn_for_all_world_sizes(functools.partial(init_and_run, fn, extra_args), **spawn_kwargs)


def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
    """Call ``dist_init``, create a new process group, then call ``fn(rank, group, *args)``."""
    dist_init(rank, world_size, filename, filename_rpc)
    process_group = torch.distributed.new_group()
    fn(rank, process_group, *args)


# Run the tests through unittest when this file is executed directly.
if __name__ == "__main__":
    unittest.main()