test_fsdp_offload.py 20.2 KB
Newer Older
1
2
3
4
5
6
7
8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import itertools
import sys
9
import tempfile
10
11
12
13
14
15
16
17
18
import time
import unittest

from parameterized import parameterized
import pytest
import torch
from torch import nn
import torch.distributed

19
import fairscale.experimental.nn.ssd_offload as so
20
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
21
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
22
from fairscale.utils import torch_version
23
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
24
25

# NOTE: SSD offload relies on recent PyTorch features, so the detected torch version is
# printed for debugging and every test in this module is skipped on torch < 1.11.0.
print(f"torch version {torch_version()}")
pytestmark = pytest.mark.skipif(
    torch_version() < (1, 11, 0),
    reason="requires torch version >= 1.11.0",
)


# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either a @classmethod or a @staticmethod.


class DistributedTest(unittest.TestCase):
    """Base class for the multi-GPU FSDP tests in this module.

    ``setUp`` skips the test unless CUDA is available, the platform is not Windows,
    and at least two GPUs are visible.
    """

    def setUp(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")

    @staticmethod
    def _eval_with_config(model, autocast):
        """Run one eval-mode forward pass and return the detached loss.

        Args:
            model: a wrapper (DDP or FullyShardedDataParallel) whose ``.module`` provides
                ``get_input(device)`` and ``get_loss(input, output)``.
            autocast: whether to run the forward pass under ``torch.cuda.amp.autocast``.

        Returns:
            The loss tensor, detached from the autograd graph.
        """
        model.eval()
        model_device = torch.device("cuda")
        with torch.cuda.amp.autocast(enabled=autocast):
            # Inputs always cuda regardless of move_grads_cpu, or model.device
            input = model.module.get_input(torch.device("cuda"))
            output = model(*input)
            loss = model.module.get_loss(input, output).to(model_device)
        # The loss is required to be full precision even when autocast is enabled.
        assert loss.dtype == torch.float32
        if isinstance(model, FullyShardedDataParallel):
            model.assert_state(TrainingState.IDLE)
        return loss.detach()

    @staticmethod
    def _eval_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None):
        """Run ``num_steps`` eval-mode forward passes on one fixed input, discarding outputs.

        ``lr`` and ``norm_type`` are accepted for signature compatibility but are unused.
        """
        model.eval()
        # Inputs always cuda regardless of move_grads_cpu, or model.device
        input = model.module.get_input(torch.device("cuda"))

        for _ in range(num_steps):
            with torch.cuda.amp.autocast(enabled=autocast):
                model(*input)

    @classmethod
    def _test_identical_outputs_eval(
        cls,
        model_init_fn,
        config,
        rank,
        group,
        num_steps=2,
        use_cuda=True,
        lr=0.01,
        ref_ddp_fn=None,
    ):
        """Check that the FSDP eval loss matches a PyTorch DDP reference for ``config``.

        Args:
            model_init_fn: callable ``(group=..., wrapper_config=...) -> nn.Module``.
            config: FSDP keyword arguments plus the test-only ``ssd_offload`` flag.
                It is mutated in place (``compute_dtype``, ``offload_config`` and
                ``flatten_parameters`` may be set; ``ssd_offload`` is removed).
            rank: this process's rank, used as the DDP device id.
            group: process group handed to DDP and FSDP.
            num_steps: unused; kept for signature compatibility.
            use_cuda: when params are neither SSD- nor CPU-offloaded, move the FSDP model
                to CUDA if True; otherwise assert its parameters are on the CPU.
            lr: unused; kept for signature compatibility.
            ref_ddp_fn: optional ``(model, group) -> wrapped_model`` used instead of the
                default ``DistributedDataParallel`` reference wrapper.

        Raises:
            Exception: if the FSDP loss does not match the reference loss.
        """
        if config.get("mixed_precision", False):
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # identical results as PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, process_group=group
            )
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._eval_with_config(model, autocast)

        # Confirm we get the same behavior using FullyShardedDataParallel.
        if config.get("ssd_offload", False):
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
            # ssd offload only supports flatten_params ATM
            config["flatten_parameters"] = True

        # "ssd_offload" is a test-only flag that must be stripped before config is
        # forwarded to FSDP; pop() also tolerates configs that never set it.
        config.pop("ssd_offload", None)
        model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
        if not model.ssd_offload and not model.move_params_to_cpu:
            if use_cuda:
                model = model.cuda()
            else:
                assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._eval_with_config(model, autocast)

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
        except (AssertionError, RuntimeError) as e:
            raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}") from e
        if config.get("flatten_parameters", True):
            metadata = model.local_metadata_dict()
            assert isinstance(metadata, dict)

keys = ["reshard_after_forward", "mixed_precision", "nested_wrapping"]
# One single-element argument list per on/off combination of `keys`,
# in the shape expected by parameterized.expand.
CONFIG_OPTIONS = [
    [dict(zip(keys, flag_values))]
    for flag_values in itertools.product([True, False], repeat=len(keys))
]


def rename_test(testcase_func, param_num, param):
    """Build the generated test name used as ``name_func`` for ``parameterized.expand``.

    The name is the test function's ``__name__`` followed by an underscore and the
    ``parameterized.to_safe_name``-sanitized text of the case's positional args.
    ``param_num`` is part of the ``name_func`` signature but is not used.
    """
    return f"{testcase_func.__name__}_{parameterized.to_safe_name(str(param.args))}"


class TestSsdMemory(DistributedTest):
    """Timing smoke test for FSDP with SSD offload on a small SimpleLinear model."""

    def test_memory_benchmark(self):
        if torch_version() >= (1, 12, 0):
            pytest.skip("to be fixed")

        test_fn = functools.partial(self._test_memory_benchmark, config={})
        spawn_and_init(test_fn)

    @classmethod
    def _test_memory_benchmark(cls, rank, group, config):
        """Per-rank body: print timestamps around CUDA init, model build, FSDP wrap and one eval.

        Args:
            rank: this process's rank (unused).
            group: process group passed to the model.
            config: FSDP keyword arguments; ``offload_config`` is added in place.
        """
        time_keeper = TimeKeeper()

        SIZE = 8 * 8
        time_keeper.print_time("START", 1.0)
        # Touch CUDA so its initialization cost shows up under INIT_CUDA rather than later.
        torch.empty(1).cuda()
        # wait for cuda to fully load
        time.sleep(1)
        time_keeper.print_time("INIT_CUDA", 1.0)
        model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
        time_keeper.print_time("CPU_MODEL", 1.0)

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)

            model = FullyShardedDataParallel(model, **config)
            time_keeper.print_time("FSDP_MODEL", 1.0)

            cls._eval_for_several_steps(model, 1, autocast=False)
            time_keeper.print_time("EVAL")


class SimpleLinear(nn.Module):
    """A stack of bias-free ``nn.Linear`` layers trained with BCE-with-logits, for offload tests.

    Args:
        group: process group; only ``rank()`` and ``size()`` are read.
        input_size: input feature size of every layer.
        output_size: output feature size of every layer. Stacking more than one layer
            requires ``input_size == output_size``.
        layers: number of linear layers in ``self.module``.
    """

    def __init__(self, group, input_size, output_size, layers=1, **unused_kwargs):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        self.input_size = input_size
        self.output_size = output_size
        torch.manual_seed(0)  # keep everything deterministic
        self.module = nn.Sequential(*[nn.Linear(input_size, output_size, bias=False) for _ in range(layers)])
        self.bs = 2

    def get_input(self, device):
        """Return a deterministic ``(src, tgt)`` pair of uniform random float32 tensors.

        ``src`` has shape ``(bs, input_size)``. ``tgt`` has shape ``(bs, output_size)`` so that
        it matches the model output that ``get_loss`` compares it against.
        """
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.rand((self.bs, self.input_size), device=device, dtype=torch.float32)
        # Target must match the *output* shape for binary_cross_entropy_with_logits.
        tgt = torch.rand((self.bs, self.output_size), device=device, dtype=torch.float32)
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        """Apply the layer stack to ``src_ids``; ``tgt_ids`` is accepted but unused."""
        return self.module(src_ids)

    def get_loss(self, input, output):
        """BCE-with-logits loss between ``output`` and the target half of ``input``."""
        _, tgt = input

        return nn.functional.binary_cross_entropy_with_logits(output, tgt)

    def run_backward(self, loss):
        loss.backward()


KEYS = ["ssd_offload", "flatten_parameters", "mixed_precision", "move_params_to_cpu"]
# Every on/off combination of KEYS, each wrapped in a one-element list for parameterized.expand.
CONFIG = [
    [dict(zip(KEYS, flags))]
    for flags in itertools.product([True, False], repeat=len(KEYS))
]


class TimeKeeper:
    """Reports wall-clock time elapsed since construction, pausing after each report."""

    def __init__(self):
        # Reference point that every report measures from.
        self.start_time = time.time()

    def print_time(self, s: str, wait_time: float = 1.0):
        """Print ``@time: <elapsed seconds> <s>`` and then sleep for ``wait_time`` seconds."""
        cur_time = time.time()
        print(f"@time: {cur_time - self.start_time:0.2f} {s}")
        # Pause so consecutive phases are visibly separated in the timeline.
        time.sleep(wait_time)


class TestModuleProperties(DistributedTest):
    """Checks FSDP ``named_parameters()`` against the unwrapped TransformerWithSharedParams."""

    @parameterized.expand(CONFIG, name_func=rename_test)
    def test_named_parameters(self, config):
        if torch_version() >= (1, 12, 0):
            pytest.skip("to be fixed")

        test_fn = functools.partial(self._test_named_params, config=config)
        spawn_and_init(test_fn)

    @classmethod
    def _test_named_params(cls, rank, group, config):
        """Per-rank body: wrap the model in FSDP, run eval, and compare parameter names.

        Args:
            rank: this process's rank (unused).
            group: process group passed to the model.
            config: FSDP keyword arguments plus the test-only ``ssd_offload`` flag;
                mutated in place.
        """
        # Get the named parameters before wrapping.
        # NOTE(review): named_parameters() returns a lazy generator, so it is not evaluated
        # until after FSDP has wrapped (and possibly flattened) before_wrap_model. In the
        # non-flattened branch the first zip loop also consumes it, leaving little or nothing
        # for the second loop. Materializing it with list(...) here would change what is
        # compared -- confirm the intended semantics before changing it.
        before_wrap_model = TransformerWithSharedParams(group)
        before_wrap_params = before_wrap_model.named_parameters()

        with tempfile.TemporaryDirectory() as current_tempdir:
            if config["ssd_offload"]:
                config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
                # ssd offload only supports flatten_params ATM
                config["flatten_parameters"] = True
            # Strip the test-only flag before forwarding config to FSDP.
            del config["ssd_offload"]

            model = FullyShardedDataParallel(before_wrap_model, **config)
            print(f"model.ssd_offload {model.ssd_offload}")
            if not model.ssd_offload and not model.move_params_to_cpu:
                model = model.cuda()

            cls._eval_with_config(model, autocast=config["mixed_precision"])

            # Get the named parameters after wrapping to compare.
            after_wrap_params = model.named_parameters()

            if not config.get("flatten_parameters", False):
                for before_nm, after_nm in zip(before_wrap_params, after_wrap_params):
                    assert before_nm[0] == after_nm[0]
            else:
                named_params_flat = [p for p in after_wrap_params][0][0]
                assert "flat_param_0" in named_params_flat

            after_wrap_params = model.named_parameters()

            for before_nm, after_nm_original in zip(before_wrap_params, after_wrap_params):
                assert before_nm[0] == after_nm_original[0]
                torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].shape)


class TestSsdLoading(DistributedTest):
    """End-to-end eval/train/checkpoint tests for FSDP with SSD offload."""

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_ssd_offloading_eval(self, config):
        if torch_version() >= (1, 12, 0):
            pytest.skip("to be fixed")

        test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
        spawn_and_init(test_fn)

    @parameterized.expand(CONFIG, name_func=rename_test)
    def test_transformer_parameterized(self, config):
        if torch_version() >= (1, 12, 0):
            pytest.skip("to be fixed")

        spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_ssd_offloading_train_flatten_params_wrapper(self, config):
        if torch_version() >= (1, 12, 0):
            pytest.skip("to be fixed")

        test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config)
        spawn_and_init(test_fn)

    @classmethod
    def _test_ssd_offloading_train_flatten_params_wrapper(cls, rank, group, config):
        """Per-rank body: train with SSD offload, checkpoint after step 0, reload and replay.

        The model is trained for ITERATIONS steps, and its state dict is saved right after the
        first step. That checkpoint is then loaded and ITERATIONS - 1 further steps are run.
        The last output of the replay must equal the last output of the first run exactly.

        Args:
            rank: this process's rank (unused).
            group: process group passed to the model.
            config: FSDP keyword arguments plus the test-only ``nested_wrapping`` flag;
                mutated in place.
        """
        SIZE = 16 * 16
        model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
            config["flatten_parameters"] = True

            nested_wrapping = config["nested_wrapping"]
            del config["nested_wrapping"]

            if nested_wrapping:
                model = FullyShardedDataParallel(
                    NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
                )
            else:
                model = FullyShardedDataParallel(model, **config)
            model_device = torch.device("cuda")
            model.train()

            optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

            pre_checkpoint_last_output = None
            post_checkpoint_last_output = None

            ITERATIONS = 10

            # Scope the checkpoint file and directory with context managers so they are removed
            # deterministically rather than whenever the temporary objects are garbage collected.
            with tempfile.NamedTemporaryFile() as checkpoint_file, tempfile.TemporaryDirectory(
                prefix="checkpoint_dir"
            ) as checkpoint_load_dir:
                # Inputs always cuda regardless of move_grads_cpu, or model.device
                with torch.cuda.amp.autocast(enabled=config.get("mixed_precision", False)):
                    for i in range(ITERATIONS):
                        optim.zero_grad()
                        input = model.get_input(torch.device("cuda"))
                        output = model(*input)
                        pre_checkpoint_last_output = output
                        loss = model.module.get_loss(input, output).to(model_device)
                        assert loss.dtype == torch.float32

                        model.module.run_backward(loss)
                        optim.step()
                        if i == 0:
                            # Only the model state is checkpointed; optimizer momentum is reset
                            # after saving and after loading instead of being saved.
                            with so.CheckpointPathContextManager(override_path=checkpoint_load_dir):
                                torch.save({"model": model.state_dict()}, checkpoint_file.name)
                            # reset momentum just after checkpoint save
                            optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

                    checkpoint = torch.load(checkpoint_file.name)
                    model.load_state_dict(checkpoint["model"])
                    # reset momentum just after checkpoint load
                    optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
                    # do more iterations after loading checkpoint
                    for _ in range(ITERATIONS - 1):
                        optim.zero_grad()
                        input = model.get_input(torch.device("cuda"))
                        output = model(*input)
                        post_checkpoint_last_output = output
                        loss = model.module.get_loss(input, output).to(model_device)
                        assert loss.dtype == torch.float32

                        model.module.run_backward(loss)
                        optim.step()

                # Verify output of checkpoint load + run is equal to original output
                assert torch.equal(pre_checkpoint_last_output, post_checkpoint_last_output)
                if isinstance(model, FullyShardedDataParallel):
                    model.assert_state(TrainingState.IDLE)

    @classmethod
    def _test_ssd_offload_eval(cls, rank, group, config):
        """Per-rank body: eval an SSD-offloaded model, round-trip its local state dict, eval again.

        Args:
            rank: this process's rank (unused).
            group: process group passed to the model.
            config: FSDP keyword arguments plus the test-only ``nested_wrapping`` flag;
                mutated in place.
        """
        model = TransformerWithSharedParams(group)

        nested_wrapping = config["nested_wrapping"]
        del config["nested_wrapping"]
        config["flatten_parameters"] = True

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
            if nested_wrapping:
                model = FullyShardedDataParallel(
                    NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
                )
            else:
                model = FullyShardedDataParallel(model, **config)

            cls._eval_with_config(model, autocast=config["mixed_precision"])

            # With SSD offload only local_state_dict will work. We can support global
            # state dict if we think it is necessary.
            state_dict = model.local_state_dict()
            model.load_local_state_dict(state_dict)

            cls._eval_with_config(model, config["mixed_precision"])


class TransformerWithSharedParams(nn.Module):
    """Small ``nn.Transformer`` whose embedding and output projection share one weight.

    Args:
        group: process group; only ``rank()`` and ``size()`` are read.
        d_vocab: vocabulary size; must be at least 12 because ``get_input`` uses token ids 0..11.
        d_model: model width.
        add_bn: if True, apply ``BatchNorm1d(bs)`` to the target embeddings; otherwise identity.
    """

    def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        torch.manual_seed(0)  # keep everything deterministic
        assert d_vocab >= 12  # we use torch.arange(12) as input
        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,)))
        self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()

    def get_input(self, device):
        """Return fixed ``(src, tgt)`` token-id tensors of shape ``(6, bs)`` and ``(4, bs)``."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        src = self.embed_tokens(src_ids)
        # Adding the buffers makes them participate in the forward pass.
        src = src + self.vocab_bias + self.long_buffer.type_as(src)
        tgt = self.embed_tokens(tgt_ids)
        tgt = self.bn(tgt)
        x = self.transformer(src, tgt)
        return self.output_proj(x)

    def get_loss(self, input, output):
        """Summed cross-entropy between the logits ``output`` and the target ids in ``input``."""
        _, tgt = input
        return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum")

    def run_backward(self, loss):
        loss.backward()


class NestedWrappedModule(nn.Module):
    """Small MLP whose layers are optionally wrapped in nested FullyShardedDataParallel instances.

    Args:
        group: process group; ``rank()`` and ``size()`` are read and it is passed to every FSDP wrapper.
        wrapper_config: FSDP keyword arguments for the inner wrappers, or None to leave layers unwrapped.
        wrap_everything: if True, replace the default stack with one where every layer is wrapped.
        checkpoint: with ``wrap_everything``, additionally wrap each layer in ``checkpoint_wrapper``.
    """

    def __init__(self, group, wrapper_config, wrap_everything=False, checkpoint=False):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        self.wrapper_config = wrapper_config

        def _maybe_wrap(layer):
            if wrapper_config is not None:
                return FullyShardedDataParallel(layer, group, **wrapper_config)
            return layer

        torch.manual_seed(0)  # keep everything deterministic
        # NOTE: this default stack is always built (consuming RNG state) even when
        # wrap_everything replaces it below.
        self.module = nn.Sequential(
            nn.Linear(8, 4),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(nn.Linear(4, 16)),
                    nn.Linear(16, 16),
                )
            ),
            _maybe_wrap(nn.Linear(16, 4)),
            nn.Linear(4, 8),
        )

        # Wrap all modules triggers a corner case where root FSDP doesn't have any params.
        # Test it with checkpoint_wrapper as well to validate final backward callback
        # is queued correctly when root FSDP does not have any params and every layer is
        # wrapped as FSDP(checkpoint(module)).
        if wrap_everything:
            if checkpoint:
                self.module = nn.Sequential(
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(8, 4))),
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 16))),
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(16, 4))),
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 8))),
                )
            else:
                self.module = nn.Sequential(
                    _maybe_wrap(nn.Linear(8, 4)),
                    _maybe_wrap(nn.Linear(4, 16)),
                    _maybe_wrap(nn.Linear(16, 4)),
                    _maybe_wrap(nn.Linear(4, 8)),
                )

    def get_input(self, device):
        """Return a one-element tuple holding a deterministic ``(4, 8)`` uniform random tensor."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        return (torch.rand(4, 8, device=device),)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        """Return the sum of all output elements; ``input`` is unused."""
        loss = output.sum()
        return loss

    def run_backward(self, loss):
        loss.backward()


def spawn_and_init(fn, args=None, **spawn_kwargs):
    """Run ``fn`` in spawned workers for every world size ``spawn_for_all_world_sizes`` covers.

    Each worker goes through ``init_and_run``, which initializes distributed state and then
    calls ``fn(rank, group, *args)``.

    Args:
        fn: per-rank callable taking ``(rank, group, *args)``.
        args: extra positional arguments for ``fn``; defaults to none.
        **spawn_kwargs: forwarded unchanged to ``spawn_for_all_world_sizes``.
    """
    if args is None:
        args = ()

    run_fn = functools.partial(init_and_run, fn, args)

    # Below 3 lines are to easily enable single-process debugging
    # _, filename = tempfile.mkstemp()
    # _, filename_rpc = tempfile.mkstemp()
    # run_fn(0, 1, filename, filename_rpc)

    spawn_for_all_world_sizes(run_fn, **spawn_kwargs)


def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
    """Worker entry point: initialize distributed state, then call ``fn(rank, group, *args)``.

    ``group`` is a fresh process group from ``torch.distributed.new_group()``, created
    only after ``dist_init`` has run for this rank.
    """
    dist_init(rank, world_size, filename, filename_rpc)
    fn(rank, torch.distributed.new_group(), *args)


if __name__ == "__main__":
    # Allow running this file directly with the standard-library unittest runner.
    unittest.main()