# test_fsdp_freezing_weights.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with some params frozen. """


from enum import Enum
from itertools import product
import os
import tempfile

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown


class Model(nn.Module):
    """Conv "trunk" followed by a linear "head"; the trunk is what the tests freeze.

    Args:
        with_fsdp: whether this model will be used in an FSDP run.
        freeze_after_wrap_fsdp: when both this and ``with_fsdp`` are True, the trunk
            and head are wrapped with FSDP here in the constructor, so any freezing
            the caller does afterwards happens on the already-wrapped modules.
            Otherwise wrapping (if any) is left to an explicit ``fsdp_wrap()`` call.
    """

    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.head = nn.Linear(64, 10)

        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap()

    def fsdp_wrap(self):
        """Wrap the trunk and the head each in their own FSDP instance."""
        self.trunk = FSDP(self.trunk)
        self.head = FSDP(self.head)

    def forward(self, x):
        # The trunk pools and flattens to 64 features per sample; the head maps them to 10 logits.
        return self.head(self.trunk(x))


class NestedTrunkModel(nn.Module):
    """Like ``Model`` but with a trunk made of two conv blocks, giving nested FSDP wrapping.

    Args:
        with_fsdp: whether this model will be used in an FSDP run.
        freeze_after_wrap_fsdp: when both this and ``with_fsdp`` are True, the model is
            wrapped with FSDP (``fsdp_wrap``) here in the constructor; otherwise wrapping
            (if any) is left to an explicit ``fsdp_wrap()`` call.
    """

    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
        super().__init__()
        self.trunk = nn.Sequential(
            self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
            self._create_block(64, 64, with_fsdp, freeze_after_wrap_fsdp),
        )
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(64, 10),)

        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap()

    def fsdp_wrap(self):
        """Wrap every trunk block in FSDP, then the trunk itself, then the head."""
        # Iterate over a snapshot because the children are replaced while walking them.
        for name, child in list(self.trunk.named_children()):
            setattr(self.trunk, name, FSDP(child))
        self.trunk = FSDP(self.trunk)
        self.head = FSDP(self.head)

    def forward(self, x):
        return self.head(self.trunk(x))

    def _create_block(self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp):
        """Return a 3x3 conv + ReLU block.

        ``with_fsdp`` and ``freeze_after_wrap_fsdp`` are accepted but currently unused;
        per-block FSDP wrapping is done in ``fsdp_wrap``.
        """
        return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True),)


def _create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp):
    """Build a ``NestedTrunkModel`` or a plain ``Model``, forwarding the FSDP flags."""
    model_cls = NestedTrunkModel if with_nested_trunk else Model
    return model_cls(with_fsdp, freeze_after_wrap_fsdp)


class FreezingMethod(str, Enum):
    """Strategy used by the tests to keep the trunk's weights fixed.

    Members also subclass ``str``, so each compares equal to its string value.
    """

    # Let the trunk produce gradients, then reset them to None before optimizer.step().
    GradToNone = "grad_to_none"
    # Set requires_grad=False on the trunk parameters so no gradients are produced.
    RequiresGrad = "requires_grad"


def _distributed_worker(
    gpu_id,
    world_size,
    with_fsdp,
    with_nested_trunk,
    freezing_method,
    freeze_after_wrap_fsdp,
    tempfile_name,
    unused,
    rank_0_output,
    expected_state,
):
    """Train the model for a few steps with a frozen trunk on one rank.

    Meant to be launched once per GPU via ``mp.spawn``; ``gpu_id`` doubles as the rank.

    Args:
        gpu_id: CUDA device index and distributed rank of this process.
        world_size: number of processes in the group.
        with_fsdp: wrap with FSDP when True, otherwise with DistributedDataParallel.
        with_nested_trunk: use ``NestedTrunkModel`` instead of ``Model``.
        freezing_method: a ``FreezingMethod`` selecting how the trunk is frozen.
        freeze_after_wrap_fsdp: with FSDP, True wraps the submodules (in the model's
            constructor) before the trunk is frozen; False freezes first and wraps after.
        tempfile_name, unused: file names forwarded to ``dist_init``.
        rank_0_output: path where rank 0 of the DDP run saves its final state dict.
        expected_state: state dict the FSDP run must match (``None`` for the DDP run).
    """
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Same seed on every rank: identical initial weights and input batch.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp)
    model = model.cuda()

    # freezing the trunk using requires_grad.
    if freezing_method == FreezingMethod.RequiresGrad:
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        if not freeze_after_wrap_fsdp:
            model.fsdp_wrap()
        model = FSDP(model)
        # FSDP forwards missing attribute lookups to the wrapped module.
        trunk = model.trunk
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])
        # DDP does not forward attribute lookups, so reach the trunk through .module.
        trunk = model.module.trunk

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == FreezingMethod.GradToNone:
            # SGD skips parameters whose .grad is None, so the trunk is not updated.
            for param in trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics (in place, keeping the state_dict's type).
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state, fsdp_state, raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()


# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
    """Yield a tuple of 15 temporary file paths and remove the files afterwards."""
    # 1 DDP run + 4 FSDP runs (2 freezing methods x 2 wrap orders), 3 files each.
    num = 15
    files = []
    for _ in range(num):
        fd, name = tempfile.mkstemp()
        # mkstemp also returns an open OS-level descriptor; close it so it doesn't leak.
        os.close(fd)
        files.append(name)

    yield tuple(files)

    # temp files could have been removed, so we use rmf.
    for name in files:
        rmf(name)


@skip_if_single_gpu
@pytest.mark.parametrize("nested_trunk", ["nested_trunk", "simple_trunk"])
def test_freezing_weights(temp_files, nested_trunk):
    """Check that FSDP training with a frozen trunk matches a DDP reference run.

    A DDP run (trunk frozen via requires_grad) saves its final state dict on rank 0;
    each FSDP configuration is then compared against it with ``objects_are_equal``.
    """
    with_nested_trunk = nested_trunk == "nested_trunk"

    world_size = 2
    # DDP: writes the reference state dict to temp_files[2].
    with_fsdp = False
    freezing_method = FreezingMethod.RequiresGrad
    mp.spawn(
        _distributed_worker,
        (world_size, with_fsdp, with_nested_trunk, freezing_method, True) + temp_files[0:3] + (None,),
        nprocs=world_size,
    )

    # FSDP: 4 cases, each freezing method combined with freezing after/before wrapping.
    with_fsdp = True
    expected_state = torch.load(temp_files[2])
    temp_file_idx = 3
    for freezing_method, freeze_after_wrap_fsdp in product(
        [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone], [True, False]
    ):
        print(f"Testing FSDP with freezing method {freezing_method}")
        mp.spawn(
            _distributed_worker,
            (world_size, with_fsdp, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp)
            + temp_files[temp_file_idx : temp_file_idx + 3]
            + (expected_state,),
            nprocs=world_size,
        )
        temp_file_idx += 3