# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with different input types. """

import os
import random

import pytest
import torch
from torch.nn import Linear, Module
from torch.optim import SGD

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
from fairscale.utils.testing import skip_if_no_cuda, torch_version


# We only test on GPU since mixed precision only really works on GPU.
@skip_if_no_cuda
@pytest.mark.parametrize(
    "fsdp_config", [{}, {"mixed_precision": True}],
)
@pytest.mark.parametrize("input_cls", [dict, list])
def test_it(fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    # Use a random port in case the next test starts quickly; reusing the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        assert isinstance(fsdp_config, dict), str(fsdp_config)

        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

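        # Wrap the model in FSDP. With world_size=1 there is no real cross-rank sharding,
        # so this mainly exercises the wrapping and (optionally) mixed-precision code paths.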
        model = FSDP(Model(), **fsdp_config).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                assert input_cls is dict
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

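        # FSDP tracks an internal training state; after the forward/backward passes
        # have completed it should be back to IDLE.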
        model.assert_state(TrainingState.IDLE)

    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]