# test_fsdp_freezing_weights.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with some params frozen. """


13
from enum import Enum
14
from itertools import product
15
16
17
18
19
20
21
22
23
import tempfile

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim

24
from fair_dev.testing.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown
25
26
27
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


class FreezeModel(nn.Module):
    """Conv trunk sharded by its own nested FSDP instance, plus a linear head.

    ``_freeze_distributed_worker`` freezes ``head`` (which is left to the root
    module) while the nested FSDP ``trunk`` stays trainable.
    """

    def __init__(self):
        super().__init__()
        conv_stack = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        linear_head = nn.Linear(64, 10)
        # Register trunk before head (same child order as before) and wrap the
        # trunk so its parameters belong to a nested FSDP instance.
        self.trunk = FSDP(conv_stack)
        self.head = linear_head

    def forward(self, x):
        features = self.trunk(x)
        logits = self.head(features)
        return logits


def _freeze_distributed_worker(
    gpu_id,
    world_size,
    tempfile_name,
    unused,
):
    """Train ``FreezeModel`` for 3 steps with the root-owned head frozen.

    Spawned once per GPU by ``test_submodule_freezing_weights``. The check is
    that training runs to completion when the parameters of the nested FSDP
    submodule (the trunk) are trainable but those left to the root FSDP
    instance (the head) have ``requires_grad=False``.

    Args:
        gpu_id: CUDA device index for this process; also used as the rank.
        world_size: Number of spawned processes.
        tempfile_name: File passed to ``dist_init`` (presumably used for
            file-based rendezvous — confirm against ``dist_init``).
        unused: Forwarded unchanged to ``dist_init``.
    """
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    # The use case for this test is where the weights in the submodule
    # are not frozen but the leftover weights or those contained by the
    # root module are frozen. Refer to issue #758 for a real world example.
    model = FreezeModel()
    model = model.cuda()

    for param in model.head.parameters():
        param.requires_grad = False

    model = FSDP(model)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()

    teardown()


@skip_if_single_gpu
def test_submodule_freezing_weights(temp_files):
    """FSDP must train when only the root module's own parameters are frozen.

    Spawns ``_freeze_distributed_worker`` on 2 processes, using the first two
    temp files for process-group initialization.
    """
    world_size = 2
    mp.spawn(
        _freeze_distributed_worker,
        (world_size, temp_files[0], temp_files[1]),
        nprocs=world_size,
    )


class Model(nn.Module):
    """Flat conv trunk followed by a linear head, optionally FSDP-wrapped.

    Args:
        with_fsdp: Whether this run uses FSDP. Wrapping happens in
            ``__init__`` only when ``freeze_after_wrap_fsdp`` is also true.
        freeze_after_wrap_fsdp: If true (and ``with_fsdp``), wrap the children
            here, so any freezing the caller does happens after wrapping.
            If false, the caller is expected to call ``fsdp_wrap`` itself
            (after freezing).
    """

    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.head = nn.Linear(64, 10)

        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap()

    def fsdp_wrap(self):
        """Wrap ``trunk`` and ``head`` each in their own FSDP instance."""
        self.trunk = FSDP(self.trunk)
        self.head = FSDP(self.head)

    def forward(self, x):
        return self.head(self.trunk(x))


class NestedTrunkModel(nn.Module):
    """Two conv blocks as the trunk plus a pooling/linear head.

    With FSDP, each trunk block, the trunk itself and the head get their own
    FSDP instance (see ``fsdp_wrap``), giving a nested wrapping.

    Args:
        with_fsdp: Whether this run uses FSDP. Wrapping happens in
            ``__init__`` only when ``freeze_after_wrap_fsdp`` is also true.
        freeze_after_wrap_fsdp: If true (and ``with_fsdp``), wrap here so the
            caller freezes parameters after wrapping. If false, the caller is
            expected to call ``fsdp_wrap`` itself (after freezing).
    """

    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
        super().__init__()
        self.trunk = nn.Sequential(
            self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
            self._create_block(64, 64, with_fsdp, freeze_after_wrap_fsdp),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10),
        )

        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap()

    def fsdp_wrap(self):
        """Wrap every trunk block, then the trunk and the head, with FSDP."""
        # Snapshot the children first: the loop replaces entries of the
        # container it would otherwise be iterating over.
        for name, child in list(self.trunk.named_children()):
            wrapped_child = FSDP(child)
            setattr(self.trunk, name, wrapped_child)
        self.trunk = FSDP(self.trunk)
        self.head = FSDP(self.head)

    def forward(self, x):
        return self.head(self.trunk(x))

    def _create_block(self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp):
        """Return a Conv2d(k=3) + ReLU block.

        ``with_fsdp`` and ``freeze_after_wrap_fsdp`` are accepted for call-site
        symmetry but currently unused; wrapping is done in ``fsdp_wrap``.
        """
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )
        return block


def _create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp):
    """Build the model under test.

    Returns a ``NestedTrunkModel`` when ``with_nested_trunk`` is true, else a
    ``Model``; both receive ``with_fsdp`` and ``freeze_after_wrap_fsdp``.
    """
    model_cls = NestedTrunkModel if with_nested_trunk else Model
    return model_cls(with_fsdp, freeze_after_wrap_fsdp)


class FreezingMethod(str, Enum):
    """Strategy used by ``_distributed_worker`` to keep the trunk from training.

    ``RequiresGrad`` sets ``requires_grad = False`` on the trunk parameters
    before wrapping. ``GradToNone`` leaves them trainable but sets their
    ``.grad`` to ``None`` between ``backward()`` and ``optimizer.step()``.
    """

    # Member order is kept stable: iteration order is part of the enum's behavior.
    GradToNone = "grad_to_none"
    RequiresGrad = "requires_grad"


def _distributed_worker(
    gpu_id,
    world_size,
    with_fsdp,
    with_nested_trunk,
    freezing_method,
    freeze_after_wrap_fsdp,
    tempfile_name,
    unused,
    rank_0_output,
    expected_state,
):
    """Train one configuration for 3 steps with the trunk frozen, then check it.

    With DDP, rank 0 saves its final state dict (on CPU) to ``rank_0_output``
    to serve as the reference. With FSDP, every rank compares its final state
    dict (moved to CPU) against ``expected_state``.

    Args:
        gpu_id: CUDA device index for this process; also used as the rank.
        world_size: Number of spawned processes.
        with_fsdp: Use FSDP if true, otherwise DistributedDataParallel.
        with_nested_trunk: Use ``NestedTrunkModel`` instead of ``Model``.
        freezing_method: A ``FreezingMethod`` selecting how the trunk is frozen.
        freeze_after_wrap_fsdp: FSDP only — if true the model wraps itself at
            construction (before freezing); otherwise ``fsdp_wrap`` is called
            here after freezing.
        tempfile_name: File passed to ``dist_init`` (presumably for
            file-based rendezvous — confirm against ``dist_init``).
        unused: Forwarded unchanged to ``dist_init``.
        rank_0_output: Path where rank 0 saves the DDP state dict.
        expected_state: Reference state dict for FSDP runs (``None`` for DDP).
    """
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp)
    model = model.cuda()

    # Freeze the trunk using requires_grad.
    if freezing_method == FreezingMethod.RequiresGrad:
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        if not freeze_after_wrap_fsdp:
            model.fsdp_wrap()
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    # DDP does not forward attribute lookups to the wrapped module, so
    # ``model.trunk`` would raise AttributeError there; go through ``.module``.
    # FSDP exposes ``model.trunk`` directly (the FSDP GradToNone runs rely on it).
    trunk_owner = model if with_fsdp else model.module

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == FreezingMethod.GradToNone:
            # Parameters whose grad is None are skipped by the optimizer step,
            # so the trunk stays at its initial values.
            for param in trunk_owner.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state, fsdp_state, raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()


# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
    """Yield a tuple of 15 temp file paths and remove the files afterwards.

    ``test_freezing_weights`` consumes 3 files per spawned case (1 DDP case and
    4 FSDP cases); ``test_submodule_freezing_weights`` uses the first two.
    """
    import os  # only needed to close the descriptors opened by mkstemp

    num = 15  # 1 DDP and 4 FSDP cases each needs 3 files.
    files = []
    for _ in range(num):
        fd, name = tempfile.mkstemp()
        # mkstemp returns an already-open OS-level descriptor; only the path is
        # needed, so close it instead of leaking it for the whole test.
        os.close(fd)
        files.append(name)

    yield tuple(files)

    # temp files could have been removed, so we use rmf.
    for name in files:
        rmf(name)


@skip_if_single_gpu
@pytest.mark.parametrize("nested_trunk", ["nested_trunk", "simple_trunk"])
def test_freezing_weights(temp_files, nested_trunk):
    """Frozen-trunk training under FSDP must match a DDP reference.

    A DDP run (trunk frozen via ``requires_grad``) saves rank 0's final state
    dict to ``temp_files[2]``. Each FSDP case then asserts its state dict
    equals that reference. The FSDP cases cover both freezing methods, with
    freezing done before or after FSDP wrapping.
    """
    with_nested_trunk = nested_trunk == "nested_trunk"

    world_size = 2
    # DDP reference run (freeze_after_wrap_fsdp is irrelevant without FSDP).
    with_fsdp = False
    freezing_method = FreezingMethod.RequiresGrad
    mp.spawn(
        _distributed_worker,
        (world_size, with_fsdp, with_nested_trunk, freezing_method, True) + temp_files[0:3] + (None,),
        nprocs=world_size,
    )

    # FSDP: 2 freezing methods x freeze after/before wrapping = 4 cases,
    # each using the next 3 temp files.
    with_fsdp = True
    expected_state = torch.load(temp_files[2])
    temp_file_idx = 3
    for freezing_method, freeze_after_wrap_fsdp in product(
        [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone], [True, False]
    ):
        print(f"Testing FSDP with freezing method {freezing_method}")
        mp.spawn(
            _distributed_worker,
            (world_size, with_fsdp, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp)
            + temp_files[temp_file_idx : temp_file_idx + 3]
            + (expected_state,),
            nprocs=world_size,
        )
        temp_file_idx += 3