test_fsdp_uneven.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with uneven parameter shards. """

import tempfile

import pytest
import torch
from torch import Tensor
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from torch.optim import SGD

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown


def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"
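    # This worker wraps the model in FSDP, runs forward/backward/step over the given
    # inputs, and optionally compares the output and updated weights against a locally
    # computed single-iteration reference.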

    my_lr = 0.1

    device = torch.device("cuda")
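    # For the mixed precision config, compute runs in fp16 while gradients are
    # reduce-scattered in fp32.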
    if fsdp_config.get("mixed_precision", False):
        dtype = torch.float16
        fsdp_config["fp32_reduce_scatter"] = True
    else:
        dtype = torch.float32

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            fp32_weight = model.weight.T.clone().to(device)
            weight = fp32_weight.to(dtype)
            v = torch.Tensor(test_case["inputs"][0][rank]).to(device, dtype)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).to(device, dtype)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = fp32_weight - grad.T * my_lr
            assert ref_weight_out.dtype == torch.float32
    model.to(device)  # not dtype, since FSDP will manage mixed precision internally
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
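    # Reference checking assumes a single iteration and at least one input row per rank.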
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
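    # Each rank trains on its own row of the input batch.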
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).to(device, dtype)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()
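        # Gather the full (unsharded) weight on all ranks to compare against the reference update.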
        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()


@skip_if_single_gpu
@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3)], "assert_ref_out": True}])
@pytest.mark.parametrize(
    "fsdp_config", [{}, {"flatten_parameters": False}, {"mixed_precision": True}],
)
@pytest.mark.parametrize("world_size", list(range(2, 9)))
def test_one_iteration(world_size, test_case, fsdp_config):
    """Test FSDP with uneven divide of parameter shards."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs.")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    # TODO (Min): we may want to extend this to a simple 2 layer model so that it covers
    #             more cases in FSDP. Also, assert_ref_out can be extended to multiple
    #             iterations. This could be a good bootcamp task. I should file a github
    #             issue once we merge.
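    # A single 3x3 weight (9 elements) shards unevenly for most world sizes in [2, 8].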
    model = Linear(3, 3, bias=False)
    mp.spawn(
        _test_func,
        args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
        nprocs=world_size,
        join=True,
    )


@skip_if_single_gpu
@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3), torch.rand(8, 3)], "assert_ref_out": False}])
@pytest.mark.parametrize("fsdp_config", [{}, {"flatten_parameters": False}])
@pytest.mark.parametrize("world_size", list(range(2, 9)))
def test_smaller_than_world_size(world_size, test_case, fsdp_config):
    """Test FSDP with uneven divide of parameter shards."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend")

    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs.")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    model = Sequential(
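        # The final 1x1 Linear below has a single weight element, smaller than world_size,
        # which exercises uneven (and empty) shards when flatten_parameters is False.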
        Linear(3, 3, bias=False),
        Linear(3, 4, bias=False),
        Linear(4, 5, bias=False),
        Linear(5, 4, bias=False),
        Linear(4, 3, bias=False),
        Linear(3, 1, bias=False),
        Linear(1, 1, bias=False),  # param here is smaller than world_size if unflattened.
    )
    mp.spawn(
        _test_func,
        args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
        nprocs=world_size,
        join=True,
    )