"tests/test_runner/test_dist_utils.py" did not exist on "6f674c7ea6a307a62c0ea74cc60e55570de31f6b"
test_fsdp_freezing_weights.py 4.07 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with some params frozen. """


import tempfile

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown


class Model(nn.Module):
    """Tiny image classifier made of a convolutional ``trunk`` and a linear ``head``."""

    def __init__(self):
        super().__init__()
        # The conv layer is created before the linear layer; keep this order so
        # that seeded parameter initialization stays reproducible.
        trunk_layers = [
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        ]
        self.trunk = nn.Sequential(*trunk_layers)
        self.head = nn.Linear(64, 10)

    def forward(self, x):
        """Run ``x`` through the trunk, then feed the flattened features to the head."""
        features = self.trunk(x)
        return self.head(features)


def _create_model(with_fsdp):
    """Build a ``Model``; when ``with_fsdp`` is true, wrap its trunk and head in FSDP individually."""
    net = Model()
    if not with_fsdp:
        return net

    # Wrap the trunk first, then the head, replacing each submodule in place.
    for child_name in ("trunk", "head"):
        setattr(net, child_name, FSDP(getattr(net, child_name)))
    return net


def _distributed_worker(
    gpu_id, world_size, with_fsdp, freezing_method, tempfile_name, unused, rank_0_output, expected_state
):
    """Train a model with a frozen trunk for three steps, then check or record its final state.

    Args:
        gpu_id: CUDA device index for this process; it is also used as the rank.
        world_size: number of processes, forwarded to ``dist_init``.
        with_fsdp: wrap the model in FSDP when true, otherwise in DistributedDataParallel.
        freezing_method: ``"requires_grad"`` clears ``requires_grad`` on the trunk
            parameters before wrapping; ``"grad_to_none"`` resets the trunk gradients
            to ``None`` after every backward pass instead.
        tempfile_name: forwarded to ``dist_init``.
        unused: forwarded to ``dist_init``.
        rank_0_output: path where rank 0 saves the state dict in the non-FSDP run.
        expected_state: reference state compared against the FSDP run's state dict.
    """
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    # Keep the call outside the assert so it still runs when asserts are stripped.
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Seed before any tensor or parameter is created.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    inputs = torch.randn(size=(2, 3, 224, 224)).cuda()

    net = _create_model(with_fsdp).cuda()

    assert freezing_method in ["requires_grad", "grad_to_none"]
    freeze_via_flag = freezing_method == "requires_grad"
    drop_trunk_grads = freezing_method == "grad_to_none"

    # Freeze the trunk up front by turning off requires_grad on its parameters.
    if freeze_via_flag:
        for weight in net.trunk.parameters():
            weight.requires_grad = False

    net = FSDP(net) if with_fsdp else DistributedDataParallel(net, device_ids=[gpu_id])

    if rank == 0:
        print(net)

    labels = torch.LongTensor([0, 1]).cuda()
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

    for step in range(3):
        fake_loss = loss_fn(net(inputs), labels)
        print("Loss", step, ":", fake_loss.item())
        sgd.zero_grad()
        fake_loss.backward()
        # Alternative way of freezing: discard the trunk gradients before stepping.
        if drop_trunk_grads:
            for weight in net.trunk.parameters():
                weight.grad = None
        sgd.step()

    if not with_fsdp:
        # Non-FSDP run: only rank 0 writes the unwrapped module's state.
        if rank == 0:
            torch.save(net.module.cpu().state_dict(), rank_0_output)
    else:
        final_state = net.state_dict()
        # Move tensors to CPU (in place, same mapping object) to compare numerics.
        for name, tensor in final_state.items():
            final_state[name] = tensor.cpu()
        assert objects_are_equal(expected_state, final_state, raise_exception=True)

    teardown()


# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
    """Yield a tuple of 9 temporary file paths and remove the files after the test.

    Each file is created on disk and its handle is closed right away, so no
    file descriptors stay open while the test runs.
    """
    num = 9  # 1 DDP and 2 FSDP cases each needs 3 files.

    files = []
    for _ in range(num):
        # delete=False keeps the file on disk once the handle is closed; the
        # context manager closes the handle instead of leaking it.
        with tempfile.NamedTemporaryFile(delete=False) as handle:
            files.append(handle.name)

    yield tuple(files)

    # temp files could have been removed, so we use rmf.
    for name in files:
        rmf(name)


@skip_if_single_gpu
def tests1(temp_files):
    """Run the DDP reference first, then compare both FSDP freezing methods against it.

    The first three temp files go to the DDP run, whose rank 0 saves its final
    state to the third of them. Each FSDP case then gets its own three files.
    """
    world_size = 2

    # DDP
    ddp_files = temp_files[0:3]
    mp.spawn(
        _distributed_worker,
        (world_size, False, "requires_grad") + ddp_files + (None,),
        nprocs=world_size,
    )

    # FSDP, case 1 and 2.
    expected_state = torch.load(temp_files[2])
    for case_idx, freezing_method in enumerate(("requires_grad", "grad_to_none"), start=1):
        print(f"Testing FSDP with freezing method {freezing_method}")
        first = 3 * case_idx  # files 3-5 for the first case, 6-8 for the second
        case_files = temp_files[first : first + 3]
        mp.spawn(
            _distributed_worker,
            (world_size, True, freezing_method) + case_files + (expected_state,),
            nprocs=world_size,
        )