# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with some params frozen. """


from enum import Enum
from itertools import product
import tempfile

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown


class FreezeModel(nn.Module):
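    """Model whose trunk is wrapped in FSDP; the head is left for the outer (root) FSDP wrapper."""
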
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.head = nn.Linear(64, 10)

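        # Wrap only the trunk here; the head's parameters are left for the
        # root FSDP wrapper applied in the worker.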
        self.trunk = FSDP(self.trunk)

    def forward(self, x):
        return self.head(self.trunk(x))


def _freeze_distributed_worker(
    gpu_id, world_size, tempfile_name, unused,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    # The use case for this test is where the weights in the submodule
    # are not frozen but the leftover weights or those contained by the
    # root module are frozen. Refer to issue #758 for a real world example.
    model = FreezeModel()
    model = model.cuda()

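    # Freeze the head, i.e. the leftover weights owned by the root FSDP wrapper.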
    for param in model.head.parameters():
        param.requires_grad = False

    model = FSDP(model)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()

    teardown()


@skip_if_single_gpu
def test_submodule_freezing_weights(temp_files):
    world_size = 2
    mp.spawn(
        _freeze_distributed_worker, (world_size, temp_files[0], temp_files[1]), nprocs=world_size,
    )


class Model(nn.Module):
    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.head = nn.Linear(64, 10)
        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap()

    def fsdp_wrap(self):
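        # Wrap trunk and head in separate FSDP instances so each can be
        # frozen independently of the other.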
        self.trunk = FSDP(self.trunk)
        self.head = FSDP(self.head)

    def forward(self, x):
        return self.head(self.trunk(x))


class NestedTrunkModel(nn.Module):
    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
        super().__init__()
        self.trunk = nn.Sequential(
            self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
            self._create_block(64, 64, with_fsdp, freeze_after_wrap_fsdp),
        )
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(64, 10),)
        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap()

    def fsdp_wrap(self):
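        # Wrap each trunk block, then the trunk itself, to exercise nested
        # FSDP instances; the head gets its own wrapper.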
        for name, child in self.trunk.named_children():
            wrapped_child = FSDP(child)
            setattr(self.trunk, name, wrapped_child)
        self.trunk = FSDP(self.trunk)
        self.head = FSDP(self.head)

    def forward(self, x):
        return self.head(self.trunk(x))

    def _create_block(self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp):
        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True),)
        return block


def _create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp):
    if with_nested_trunk:
        model = NestedTrunkModel(with_fsdp, freeze_after_wrap_fsdp)
    else:
        model = Model(with_fsdp, freeze_after_wrap_fsdp)
    return model


class FreezingMethod(str, Enum):
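    """How to freeze weights: turn off requires_grad up front, or set .grad to None each step."""
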
    GradToNone = "grad_to_none"
    RequiresGrad = "requires_grad"


def _distributed_worker(
    gpu_id,
    world_size,
    with_fsdp,
    with_nested_trunk,
    freezing_method,
    freeze_after_wrap_fsdp,
    tempfile_name,
    unused,
    rank_0_output,
    expected_state,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp)
    model = model.cuda()

    # Freeze the trunk. The requires_grad method freezes it up front here; the
    # grad-to-None method is applied inside the training loop below.
    if freezing_method == FreezingMethod.RequiresGrad:
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
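        # If freezing happened before wrapping, the submodules still need
        # their FSDP wrappers; apply them now.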
        if not freeze_after_wrap_fsdp:
            model.fsdp_wrap()
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == FreezingMethod.GradToNone:
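            # Emulate freezing by dropping the trunk's gradients before the
            # optimizer step.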
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

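    # Check correctness: FSDP runs compare their state dict with the DDP
    # reference that rank 0 of the DDP run saved to disk.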
    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state, fsdp_state, raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()


# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
    num = 15  # 1 DDP case and 4 FSDP cases, each needing 3 files.
    files = [tempfile.mkstemp()[1] for _ in range(num)]

    yield tuple(files)

    # temp files could have been removed, so we use rmf.
    for name in files:
        rmf(name)


@skip_if_single_gpu
@pytest.mark.parametrize("nested_trunk", ["nested_trunk", "simple_trunk"])
def test_freezing_weights(temp_files, nested_trunk):
    with_nested_trunk = nested_trunk == "nested_trunk"

    world_size = 2
    # DDP reference run: rank 0 saves its final state for the FSDP cases to compare against.
    with_fsdp = False
    freezing_method = FreezingMethod.RequiresGrad
    mp.spawn(
        _distributed_worker,
        (world_size, with_fsdp, with_nested_trunk, freezing_method, True) + temp_files[0:3] + (None,),
        nprocs=world_size,
    )
    # FSDP: all four combinations of freezing method and wrap order.
    with_fsdp = True
    expected_state = torch.load(temp_files[2])
    temp_file_idx = 3
    for freezing_method, freeze_after_wrap_fsdp in product(
        [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone], [True, False]
    ):
        print(f"Testing FSDP with freezing method {freezing_method}")
        mp.spawn(
            _distributed_worker,
            (world_size, with_fsdp, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp)
            + temp_files[temp_file_idx : temp_file_idx + 3]
            + (expected_state,),
            nprocs=world_size,
        )
        temp_file_idx += 3