# test_fsdp_freezing_weights.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with some params frozen. """


13
from enum import Enum
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import tempfile

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown


class Model(nn.Module):
    """Small conv network: a convolutional ``trunk`` followed by a linear ``head``.

    The trunk turns a 3-channel input into a 64-dim feature vector
    (3x3 conv -> ReLU -> global average pool -> flatten); the head maps those
    64 features to 10 outputs. The attribute names ``trunk``/``head`` matter:
    callers freeze ``trunk`` by name and compare state-dict keys.
    """

    def __init__(self):
        super().__init__()
        # Layers are instantiated conv-first, then the linear head, so that
        # parameter initialization under a fixed seed is reproducible.
        trunk_layers = [
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        ]
        self.trunk = nn.Sequential(*trunk_layers)
        self.head = nn.Linear(64, 10)

    def forward(self, x):
        features = self.trunk(x)
        return self.head(features)


class NestedTrunkModel(nn.Module):
    """Two-block conv trunk plus a pooling/linear head, optionally FSDP-wrapped.

    Args:
        with_fsdp: when true, each trunk block is wrapped in ``FSDP`` and the
            trunk and head containers are then wrapped again, producing nested
            FSDP instances inside ``trunk``.
    """

    def __init__(self, with_fsdp):
        super().__init__()
        self.trunk = nn.Sequential(self._create_block(3, 64, with_fsdp), self._create_block(64, 64, with_fsdp),)
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(64, 10),)
        if with_fsdp:
            # The outer containers are wrapped after their children, so the
            # trunk's FSDP instance contains the per-block FSDP instances.
            self.trunk = FSDP(self.trunk)
            self.head = FSDP(self.head)

    def forward(self, x):
        return self.head(self.trunk(x))

    def _create_block(self, in_channels, out_channels, with_fsdp):
        """Return a 3x3 conv + in-place ReLU block, FSDP-wrapped when ``with_fsdp``."""
        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True),)
        if with_fsdp:
            block = FSDP(block)
        return block


def _create_model(with_fsdp, with_nested_trunk):
    """Build the model under test.

    Args:
        with_fsdp: wrap the ``trunk`` and ``head`` submodules in ``FSDP``
            (the worker additionally wraps the returned model as a whole).
        with_nested_trunk: build ``NestedTrunkModel`` instead of ``Model``.

    Returns:
        An ``nn.Module`` exposing ``trunk`` and ``head`` submodules.
    """
    if with_nested_trunk:
        return NestedTrunkModel(with_fsdp)

    model = Model()
    if with_fsdp:
        model.trunk = FSDP(model.trunk)
        model.head = FSDP(model.head)
    return model


class FreezingMethod(str, Enum):
    """How the trunk parameters are frozen during training.

    The ``str`` mixin makes each member compare equal to its string value.
    """

    # Keep requires_grad on, but set the trunk gradients to None after
    # backward(), leaving no gradient for optimizer.step() to apply.
    GradToNone = "grad_to_none"
    # Set requires_grad = False on the trunk parameters before training.
    RequiresGrad = "requires_grad"


def _distributed_worker(
    gpu_id,
    world_size,
    with_fsdp,
    with_nested_trunk,
    freezing_method,
    tempfile_name,
    unused,
    rank_0_output,
    expected_state,
):
    """Train for a few steps on one rank with the model trunk frozen.

    Launched once per GPU via ``mp.spawn``, which passes ``gpu_id`` as the
    first argument; ``gpu_id`` is also used as the distributed rank.

    Args:
        gpu_id: CUDA device index and rank of this process.
        world_size: total number of processes.
        with_fsdp: train with FSDP and assert the final state dict equals
            ``expected_state``; otherwise train with ``DistributedDataParallel``
            and, on rank 0, save the final state dict to ``rank_0_output``.
        with_nested_trunk: use the nested-trunk model variant.
        freezing_method: a ``FreezingMethod`` selecting how the trunk is frozen.
        tempfile_name, unused: temp-file paths forwarded to ``dist_init``.
        rank_0_output: path where the DDP run stores its reference state dict.
        expected_state: reference state dict the FSDP run must match
            (the test passes ``None`` for the DDP run).
    """
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Seed and force deterministic cuDNN so the DDP and FSDP runs see the same
    # input batch and initial weights, making their final states comparable.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp, with_nested_trunk)
    model = model.cuda()

    # freezing the trunk using requires_grad.
    if freezing_method == FreezingMethod.RequiresGrad:
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        # Freezing by dropping gradients: clear the trunk's grads after
        # backward() so step() only has gradients for the head.
        if freezing_method == FreezingMethod.GradToNone:
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state, fsdp_state, raise_exception=True)
    elif rank == 0:
        # The DDP run is the reference: rank 0 persists its final weights.
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()


# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
    """Yield a tuple of nine temporary file paths and remove them afterwards.

    The files are created empty on disk. Cleanup after ``yield`` uses ``rmf``
    because a worker may already have removed some of them.
    """
    num = 9  # 1 DDP and 2 FSDP cases each needs 3 files.
    files = []
    for _ in range(num):
        # mkstemp() returns an open OS-level descriptor; discarding it leaked
        # one fd per file. NamedTemporaryFile(delete=False) closes its handle
        # when the ``with`` exits but leaves the file on disk.
        with tempfile.NamedTemporaryFile(delete=False) as handle:
            files.append(handle.name)

    yield tuple(files)

    # temp files could have been removed, so we use rmf.
    for name in files:
        rmf(name)


@skip_if_single_gpu
@pytest.mark.parametrize("nested_trunk", ["nested_trunk", "simple_trunk"])
def test_freezing_weights(temp_files, nested_trunk):
    """Check that FSDP matches DDP when the model trunk is frozen.

    A DDP run (trunk frozen via ``requires_grad``) produces a reference state
    dict. FSDP is then run once per ``FreezingMethod``, and every rank asserts
    that its final state dict equals that reference.
    """
    with_nested_trunk = nested_trunk == "nested_trunk"
    world_size = 2

    # Each run consumes three temp files: (tempfile_name, unused, rank_0_output).
    ddp_files = temp_files[0:3]
    fsdp_files = (temp_files[3:6], temp_files[6:9])

    # DDP reference run; rank 0 saves its final state dict to ddp_files[2].
    with_fsdp = False
    freezing_method = FreezingMethod.RequiresGrad
    mp.spawn(
        _distributed_worker,
        (world_size, with_fsdp, with_nested_trunk, freezing_method) + ddp_files + (None,),
        nprocs=world_size,
    )

    # FSDP, case 1 and 2: one run per freezing method, each checked against DDP.
    with_fsdp = True
    expected_state = torch.load(ddp_files[2])
    freezing_methods = [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]
    for freezing_method, run_files in zip(freezing_methods, fsdp_files):
        print(f"Testing FSDP with freezing method {freezing_method}")
        mp.spawn(
            _distributed_worker,
            (world_size, with_fsdp, with_nested_trunk, freezing_method) + run_files + (expected_state,),
            nprocs=world_size,
        )