# test_fsdp_offload.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import itertools
import sys
9
import tempfile
10
11
12
13
14
15
16
17
18
import time
import unittest

from parameterized import parameterized
import pytest
import torch
from torch import nn
import torch.distributed

19
20
21
22
23
24
25
# Skip the entire module when the SSD-offload experimental package cannot be
# imported (pytest picks up the module-level `pytestmark`).
try:
    import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
    # Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
    pytestmark = pytest.mark.skipif(True, reason=ie.msg)
    pass

26
from fair_dev.testing.testing import dist_init, spawn_for_all_world_sizes
27
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
28
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68

# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod


class DistributedTest(unittest.TestCase):
    """Base class for the FSDP SSD-offload tests.

    Skips when the environment cannot run NCCL-based multi-GPU training and
    provides the shared eval helpers used by the concrete test classes below.
    """

    def setUp(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")

    @staticmethod
    def _eval_with_config(model, autocast):
        """Run one forward pass in eval mode and return the detached loss.

        ``autocast`` toggles ``torch.cuda.amp`` autocasting; the loss is
        asserted to be fp32 either way.
        """
        model.eval()
        model_device = torch.device("cuda")
        with torch.cuda.amp.autocast(enabled=autocast):
            # Inputs always cuda regardless of move_grads_cpu, or model.device
            input = model.module.get_input(torch.device("cuda"))
            output = model(*input)
            loss = model.module.get_loss(input, output).to(model_device)
        assert loss.dtype == torch.float32
        if isinstance(model, FullyShardedDataParallel):
            model.assert_state(TrainingState.IDLE)
        return loss.detach()

    @staticmethod
    def _eval_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None):
        """Run ``num_steps`` forward passes in eval mode; outputs are discarded.

        ``lr`` and ``norm_type`` are unused here; they exist for signature
        parity with the training helpers elsewhere in the test suite.
        """
        model.eval()
        # Inputs always cuda regardless of move_grads_cpu, or model.device
        input = model.module.get_input(torch.device("cuda"))

        for _ in range(num_steps):
            with torch.cuda.amp.autocast(enabled=autocast):
                output = model(*input)

    @classmethod
    def _test_identical_outputs_eval(
        cls,
        model_init_fn,
        config,
        rank,
        group,
        num_steps=2,
        use_cuda=True,
        lr=0.01,
        ref_ddp_fn=None,
    ):
        """Check that the FSDP eval loss matches a PyTorch DDP reference.

        ``config`` is consumed destructively: the "ssd_offload" key (if any)
        is popped and translated into an OffloadConfig before the remaining
        entries are forwarded to FullyShardedDataParallel.
        """
        if config.get("mixed_precision", False):
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # identical results as PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, process_group=group
            )
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._eval_with_config(model, autocast)
        ref_state_dict = model.module.state_dict()
        if config.get("cpu_offload", False):
            for k in ref_state_dict.keys():
                ref_state_dict[k] = ref_state_dict[k].cpu()

        # Confirm we get the same behavior using FullyShardedDataParallel.
        # pop() instead of get()+unconditional del: configs without an
        # "ssd_offload" key would otherwise raise KeyError.
        if config.pop("ssd_offload", False):
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
            # ssd offload only supports flatten_params ATM
            config["flatten_parameters"] = True

        model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
        if not model.ssd_offload and not model.move_params_to_cpu:
            if use_cuda:
                model = model.cuda()
            else:
                assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._eval_with_config(model, autocast)

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
        except (AssertionError, RuntimeError) as e:
            raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
        if config.get("flatten_parameters", True):
            metadata = model.local_metadata_dict()
            assert isinstance(metadata, dict)


127
# Boolean test matrix for the eval/train parameterized tests: every
# combination of the three FSDP options below.
keys = ["reshard_after_forward", "mixed_precision", "nested_wrapping"]
CONFIG_OPTIONS = [
    [dict(zip(keys, values))]
    for values in itertools.product([True, False], repeat=len(keys))
]


def rename_test(testcase_func, param_num, param):
    """Build a readable parameterized-test name from the test's arguments."""
    safe_args = parameterized.to_safe_name(str(param.args))
    return "%s_%s" % (testcase_func.__name__, safe_args)
136
137
138
139


class TestSsdMemory(DistributedTest):
    """Coarse timing benchmark of the main phases of FSDP SSD offload."""

    def test_memory_benchmark(self):
        test_fn = functools.partial(self._test_memory_benchmark, config={})
        spawn_and_init(test_fn)

    @classmethod
    def _test_memory_benchmark(cls, rank, group, config):
        """Print elapsed time for cuda init, model build, FSDP wrap, eval.

        Note: the first argument is the class (fixed from the original's
        misleading ``self`` name on a ``@classmethod``).
        """
        time_keeper = TimeKeeper()

        SIZE = 8 * 8
        time_keeper.print_time("START", 1.0)
        a = torch.empty(1)
        b = a.cuda()  # force CUDA context creation before timing the model
        # wait for cuda to fully load
        time.sleep(1)
        time_keeper.print_time("INIT_CUDA", 1.0)
        model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
        time_keeper.print_time("CPU_MODEL", 1.0)

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)

            model = FullyShardedDataParallel(model, **config)
            time_keeper.print_time("FSDP_MODEL", 1.0)

            cls._eval_for_several_steps(model, 1, autocast=False)
            time_keeper.print_time("EVAL")
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218


class SimpleLinear(nn.Module):
    """Stack of bias-free Linear layers used as a small FSDP test fixture."""

    def __init__(self, group, input_size, output_size, layers=1, **unused_kwargs):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        self.input_size = input_size
        self.output_size = output_size
        torch.manual_seed(0)  # keep everything deterministic
        blocks = [nn.Linear(input_size, output_size, bias=False) for _ in range(layers)]
        self.module = nn.Sequential(*blocks)
        self.bs = 2  # batch size used by get_input()

    def get_input(self, device):
        """Return a deterministic (src, tgt) pair of shape (bs, input_size)."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        shape = (self.bs, self.input_size)
        src = torch.rand(shape, device=device, dtype=torch.float32)
        tgt = torch.rand(shape, device=device, dtype=torch.float32)
        return src, tgt

    def forward(self, src_ids, tgt_ids):
        # NOTE(review): this device list is never used; it may be a deliberate
        # probe of parameter access under offload — kept as-is, confirm intent.
        devices = [p.device for p in self.module.parameters()]

        return self.module(src_ids)

    def get_loss(self, input, output):
        tgt = input[1]

        return nn.functional.binary_cross_entropy_with_logits(output, tgt)

    def run_backward(self, loss):
        loss.backward()


# Boolean test matrix for the SSD-offload tests: every combination of the
# four FSDP options below.
KEYS = ["ssd_offload", "flatten_parameters", "mixed_precision", "move_params_to_cpu"]
CONFIG = [
    [dict(zip(KEYS, values))]
    for values in itertools.product([True, False], repeat=len(KEYS))
]


class TimeKeeper:
    """Tracks wall-clock time since construction and prints labeled deltas."""

    def __init__(self):
        # Reference point for all subsequent print_time() calls.
        self.start_time = time.time()

    def print_time(self, s: str, wait_time: float = 1.0) -> None:
        """Print seconds elapsed since construction, then sleep ``wait_time``."""
        elapsed = time.time() - self.start_time
        print(f"@time: {elapsed:0.2f} {s}")
        time.sleep(wait_time)


class TestModuleProperties(DistributedTest):
    """Checks that named_parameters() survives FSDP wrapping."""

    @parameterized.expand(CONFIG, name_func=rename_test)
    def test_named_parameters(self, config):
        test_fn = functools.partial(self._test_named_params, config=config)
        spawn_and_init(test_fn)

    @classmethod
    def _test_named_params(cls, rank, group, config):
        # Get the named parameters before wrapping. Materialize the generator:
        # it is iterated twice below, and a raw generator would be exhausted
        # after the first zip(), making the final comparison loop vacuous.
        before_wrap_model = TransformerWithSharedParams(group)
        before_wrap_params = list(before_wrap_model.named_parameters())

        with tempfile.TemporaryDirectory() as current_tempdir:
            if config["ssd_offload"]:
                config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
                # ssd offload only supports flatten_params ATM
                config["flatten_parameters"] = True
            del config["ssd_offload"]

            model = FullyShardedDataParallel(before_wrap_model, **config)
            print(f"model.ssd_offload {model.ssd_offload}")
            if not model.ssd_offload and not model.move_params_to_cpu:
                model = model.cuda()

            cls._eval_with_config(model, autocast=config["mixed_precision"])

            # Get the named parameters after wrapping to compare.
            after_wrap_params = model.named_parameters()

            if not config.get("flatten_parameters", False):
                for before_nm, after_nm in zip(before_wrap_params, after_wrap_params):
                    assert before_nm[0] == after_nm[0]
            else:
                # Flattened wrapping collapses everything into one flat param.
                named_params_flat = [p for p in after_wrap_params][0][0]
                assert "flat_param_0" in named_params_flat

            # Fresh generator: unflattened views should match the originals
            # in both name and shape.
            after_wrap_params = model.named_parameters()

            for before_nm, after_nm_original in zip(before_wrap_params, after_wrap_params):
                assert before_nm[0] == after_nm_original[0]
                torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].shape)
258
259
260
261
262


class TestSsdLoading(DistributedTest):
    """End-to-end eval, train, and checkpoint round-trip tests for FSDP SSD offload."""

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_ssd_offloading_eval(self, config):
        test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
        spawn_and_init(test_fn)

    @parameterized.expand(CONFIG, name_func=rename_test)
    def test_transformer_parameterized(self, config):
        spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_ssd_offloading_train_flatten_params_wrapper(self, config):
        test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config)
        spawn_and_init(test_fn)

    @classmethod
    def _test_ssd_offloading_train_flatten_params_wrapper(cls, rank, group, config):
        """Train, checkpoint after the first step, reload, and verify the
        remaining iterations reproduce the pre-checkpoint final output."""
        SIZE = 16 * 16
        LR = 0.01
        MOMENTUM = 0.1
        model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
            # ssd offload only supports flatten_params ATM
            config["flatten_parameters"] = True

            nested_wrapping = config["nested_wrapping"]
            del config["nested_wrapping"]

            if nested_wrapping:
                model = FullyShardedDataParallel(
                    NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
                )
            else:
                model = FullyShardedDataParallel(model, **config)
            model_device = torch.device("cuda")
            model.train()
            optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)

            checkpoint_file = tempfile.NamedTemporaryFile()
            checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")

            pre_checkpoint_last_output = None
            post_checkpoint_last_output = None

            ITERATIONS = 10

            # Inputs always cuda regardless of move_grads_cpu, or model.device
            with torch.cuda.amp.autocast(enabled=config.get("mixed_precision", False)):
                for i in range(ITERATIONS):
                    optim.zero_grad()
                    input = model.get_input(torch.device("cuda"))
                    output = model(*input)
                    pre_checkpoint_last_output = output
                    loss = model.module.get_loss(input, output).to(model_device)
                    assert loss.dtype == torch.float32
                    model.module.run_backward(loss)
                    optim.step()
                    if i == 0:
                        with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
                            torch.save({"model": model.state_dict()}, checkpoint_file.name)
                        # reset momentum just after checkpoint save
                        optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)

                checkpoint = torch.load(checkpoint_file.name)
                model.load_state_dict(checkpoint["model"])
                # reset momentum just after checkpoint load
                optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)

                # do more iterations after loading checkpoint
                for i in range(ITERATIONS - 1):
                    optim.zero_grad()
                    input = model.get_input(torch.device("cuda"))
                    output = model(*input)
                    post_checkpoint_last_output = output
                    loss = model.module.get_loss(input, output).to(model_device)
                    assert loss.dtype == torch.float32

                    model.module.run_backward(loss)
                    optim.step()

            # Verify output of checkpoint load + run is equal to original output
            assert torch.equal(pre_checkpoint_last_output, post_checkpoint_last_output)

            if isinstance(model, FullyShardedDataParallel):
                model.assert_state(TrainingState.IDLE)

    @classmethod
    def _test_ssd_offload_eval(cls, rank, group, config):
        """Eval under SSD offload, then round-trip the local state dict."""
        model = TransformerWithSharedParams(group)
        state_dict = model.state_dict()

        nested_wrapping = config["nested_wrapping"]
        del config["nested_wrapping"]
        # ssd offload only supports flatten_params ATM
        config["flatten_parameters"] = True

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
            if nested_wrapping:
                model = FullyShardedDataParallel(
                    NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
                )
            else:
                model = FullyShardedDataParallel(model, **config)

            cls._eval_with_config(model, autocast=config["mixed_precision"])

            # With SSD offload only local_state_dict will work. We can support global
            # state dict if we think it is necessary.
            state_dict = model.local_state_dict()
            model.load_local_state_dict(state_dict)

            cls._eval_with_config(model, config["mixed_precision"])
385
386
387
388
389
390
391
392
393
394
395


class TransformerWithSharedParams(nn.Module):
    """Small seq2seq Transformer with the embedding and output-projection
    weights tied; used as a model fixture for the FSDP tests."""

    def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        torch.manual_seed(0)  # keep everything deterministic
        assert d_vocab >= 12  # we use torch.arange(12) as input
        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,)))
        self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))

        self.bs = 2  # batch size used by get_input()
        if add_bn:
            self.bn = torch.nn.BatchNorm1d(self.bs)
        else:
            self.bn = torch.nn.Identity()

    def get_input(self, device):
        """Return deterministic (src, tgt) index tensors, layout T x B."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        embedded_src = self.embed_tokens(src_ids)
        embedded_src = embedded_src + self.vocab_bias + self.long_buffer.type_as(embedded_src)
        embedded_tgt = self.bn(self.embed_tokens(tgt_ids))
        hidden = self.transformer(embedded_src, embedded_tgt)
        return self.output_proj(hidden)

    def get_loss(self, input, output):
        tgt = input[1]
        logits = output.view(-1, output.size(-1))
        return nn.functional.cross_entropy(logits, tgt.view(-1), reduction="sum")

    def run_backward(self, loss):
        loss.backward()


class NestedWrappedModule(nn.Module):
    """Sequential model whose inner layers are optionally wrapped in FSDP,
    exercising nested-wrapping (and wrap-everything) code paths."""

    def __init__(self, group, wrapper_config, wrap_everything=False, checkpoint=False):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        self.wrapper_config = wrapper_config

        def _maybe_wrap(layer):
            # Only wrap when an FSDP config was supplied.
            if wrapper_config is not None:
                return FullyShardedDataParallel(layer, group, **wrapper_config)
            return layer

        torch.manual_seed(0)  # keep everything deterministic
        self.module = nn.Sequential(
            nn.Linear(8, 4),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(nn.Linear(4, 16)),
                    nn.Linear(16, 16),
                )
            ),
            _maybe_wrap(nn.Linear(16, 4)),
            nn.Linear(4, 8),
        )

        # Wrap all modules triggers a corner case where root FSDP doesn't have any params.
        # Test it with checkpoint_wrapper as well to validate final backward callback
        # is queued correctly when root FSDP does not have any params and every layer is
        # wrapped as FSDP(checkpoint(module)).
        if wrap_everything:
            dims = [(8, 4), (4, 16), (16, 4), (4, 8)]
            if checkpoint:
                self.module = nn.Sequential(
                    *[_maybe_wrap(checkpoint_wrapper(nn.Linear(i, o))) for i, o in dims]
                )
            else:
                self.module = nn.Sequential(*[_maybe_wrap(nn.Linear(i, o)) for i, o in dims])

    def get_input(self, device):
        """Return a single deterministic input batch of shape (4, 8)."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        return (torch.rand(4, 8, device=device),)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        return output.sum()

    def run_backward(self, loss):
        loss.backward()


def spawn_and_init(fn, args=None, **spawn_kwargs):
    """Spawn ``fn(rank, group, *args)`` across all tested world sizes."""
    fn_args = () if args is None else args
    run_fn = functools.partial(init_and_run, fn, fn_args)

    # Below 3 lines are to easily enable single-process debugging
    # _, filename = tempfile.mkstemp()
    # _, filename_rpc = tempfile.mkstemp()
    # run_fn(0, 1, filename, filename_rpc)

    spawn_for_all_world_sizes(run_fn, **spawn_kwargs)


def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
    """Initialize torch.distributed for this rank, then invoke ``fn``."""
    dist_init(rank, world_size, filename, filename_rpc)
    process_group = torch.distributed.new_group()
    fn(rank, process_group, *args)


# Allow running this file directly (outside of pytest).
if __name__ == "__main__":
    unittest.main()