# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with regnet-like model. """

import contextlib
import os
import random
import tempfile

import pytest
import torch
from torch.cuda.amp import GradScaler
import torch.multiprocessing as mp
from torch.nn import (
    AdaptiveAvgPool2d,
    BatchNorm2d,
    Conv2d,
    CrossEntropyLoss,
    Linear,
    Module,
    ReLU,
    Sequential,
    Sigmoid,
    SyncBatchNorm,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import (
    dist_init,
    objects_are_equal,
    rmf,
    skip_if_single_gpu,
    state_dict_norm,
    teardown,
    torch_cuda_version,
    torch_version,
)

# Const test params.
#   Reduce iterations to 1 for debugging.
#   Change world_size to 8 on beefy machines for better test coverage.
_world_size = 2
_iterations = 5

# Cover both ReLU flavors (in-place and out-of-place), picked at random when
# the module is imported. This can cause the DDP and FSDP models to use
# different ReLUs, since they may see different random flags.
_relu_inplace = random.randint(0, 1) != 0

# TODO (Min): test apex BN when available in the future.
try:
    import apex
except ImportError:
    apex_bn_converter = None
else:
    apex_bn_converter = apex.parallel.convert_syncbn_model
pytorch_bn_converter = SyncBatchNorm.convert_sync_batchnorm  # type: ignore
# BN converter applied to the model by the FSDP workers.
_bn_converter = pytorch_bn_converter
# Second argument passed to auto_wrap_bn by the FSDP workers.
_single_rank_pg = False


class ResBlock(Module):
    """Conv block in regnet with residual connection.

    Computes ``relu(bn(proj(x)) + f(x))`` where:

    * the shortcut is a 1x1 conv with stride 2 (``proj``) followed by ``bn``;
    * the main path ``f`` is block a (1x1 conv, BN, ReLU), block b (3x3 conv
      with stride 2, padding 1 and ``groups=2``, BN, ReLU), block se (adaptive
      average pool to 1x1, 1x1 conv down to 2 channels, ReLU, 1x1 conv back to
      ``width_out`` channels, Sigmoid), block c (1x1 conv) and a final BN.

    All convs use ``bias=False``. The ReLUs inside ``f`` use the module-level
    ``_relu_inplace`` flag; the final ``self.relu`` is always out-of-place.
    """

    def __init__(self, width_in: int, width_out: int):
        """
        Args:
            width_in: number of input channels.
            width_out: number of output channels. Must be divisible by 2,
                because block b is a grouped conv with ``groups=2``.
        """
        super().__init__()
        # Shortcut branch: strided 1x1 projection + BN.
        self.proj = Conv2d(width_in, width_out, (1, 1), (2, 2), bias=False)
        self.bn = BatchNorm2d(width_out)
        # NOTE: submodule construction order determines state_dict key order
        # and the order parameters are initialized in; keep it stable.
        self.f = Sequential(
            Sequential(  # block a
                Conv2d(width_in, width_out, (1, 1), (1, 1), bias=False), BatchNorm2d(width_out), ReLU(_relu_inplace),
            ),
            Sequential(  # block b
                Conv2d(width_out, width_out, (3, 3), (2, 2), (1, 1), groups=2, bias=False),
                BatchNorm2d(width_out),
                ReLU(_relu_inplace),
            ),
            # The se output is fed straight into block c; it is NOT multiplied
            # back onto its input. NOTE(review): presumably a deliberate
            # simplification for this test model.
            Sequential(  # block se
                AdaptiveAvgPool2d((1, 1)),
                Sequential(
                    Conv2d(width_out, 2, (1, 1), (1, 1), bias=False),
                    ReLU(_relu_inplace),
                    Conv2d(2, width_out, (1, 1), (1, 1), bias=False),
                    Sigmoid(),
                ),
            ),
            Conv2d(width_out, width_out, (1, 1), (1, 1), bias=False),  # block c
            BatchNorm2d(width_out),  # final_bn
        )
        self.relu = ReLU()
        # Flag marking this block for FSDP wrapping; it is not read in the
        # code shown here, so presumably it is consumed by wrapping logic elsewhere.
        self.need_fsdp_wrap = True

    def forward(self, x):
        # After block se's pooling, f(x) is spatially 1x1, so the sum
        # broadcasts it over the shortcut's spatial dimensions.
        x = self.bn(self.proj(x)) + self.f(x)
        return self.relu(x)

class Model(Module):
    """SSL model with trunk and head.

    ``trunk`` is a stem (3x3 conv with stride 2, BN, ReLU) followed by one
    ``ResBlock(4, 8)``. ``head`` is a bias-free projection MLP
    (16 -> 16 -> 8) followed by a bias-free 8 -> 15 "prototypes" layer.
    """

    def __init__(self):
        super().__init__()
        print(f"Using relu inplace: {_relu_inplace}")

        # NOTE: submodule construction order determines state_dict key order
        # and the order parameters are initialized in; keep it stable.
        self.trunk = Sequential()
        self.trunk.need_fsdp_wrap = True  # Set a flag for later wrapping.
        stem = Sequential(Conv2d(2, 4, (3, 3), (2, 2), (1, 1), bias=False), BatchNorm2d(4), ReLU(_relu_inplace))
        any_stage_block1_0 = ResBlock(4, 8)
        self.trunk.add_module("stem", stem)
        self.trunk.add_module("any_stage_block1", Sequential(any_stage_block1_0))

        self.head = Sequential(
            # TODO (Min): FSDP-mixed_precision doesn't compute the same ways as DDP AMP when bias=True.
            #     so, we use bias=False for now in the projection_head.
            #     The Conv2d layers above does not use bias in regnet, but even if they use
            #     bias, FSDP and DDP seem to agree on how it is computed.
            Sequential(  # projection_head
                Linear(16, 16, bias=False),
                ReLU(),
                Linear(16, 8, bias=False),
            ),
            Linear(8, 15, bias=False),  # prototypes0
        )

    def forward(self, x):
        """Run the trunk, flatten everything (batch included), then run the head.

        Returns a 1-D tensor of 15 scores.
        """
        # reshape(-1) also flattens the batch dimension, so the whole batch becomes
        # one feature vector. The head's first Linear expects 16 features: for the
        # (2, 2, 2, 2) inputs used by the DDP reference, the trunk yields
        # (2, 8, 1, 1) -> 16 values.
        x = self.trunk(x).reshape(-1)
        x = self.head(x)
        return x


# We get a bit fancy here. Since the scope is `module`, this is run only
# once no matter how many tests variations for FSDP are requested to run
# to compare with the DDP reference. For example, a single DDP
# reference run is needed for both flatten and non-flatten param FSDP.
#
# Note, this runs DDP twice with and without mixed precision and asserts
# the resulting weights are different.
#
# This fixture captures and returns:
#
#   - model state_dict before training
#   - model data inputs
#   - model state_dict after training
@pytest.fixture(scope="module")
def ddp_ref():
    """Compute the DDP reference results shared by all ``test_regnet`` variants.

    Returns:
        ``(state_before, inputs, state_after_fp, state_after_mp)``: the initial
        model state_dict; per-rank lists of ``_iterations`` random
        ``(2, 2, 2, 2)`` input tensors; and rank 0's state_dict after DDP
        training in full precision and in mixed precision, respectively.
    """
    # Get a reference model state
    model = Model()
    state_before = model.state_dict()

    # Get reference inputs per rank (drawn rank-major, one tensor per iteration).
    world_size = _world_size
    iterations = _iterations
    print(f"Getting DDP reference for world_size {world_size} and iterations {iterations}")
    inputs = [[torch.rand(2, 2, 2, 2) for _ in range(iterations)] for _ in range(world_size)]

    def _new_temp_path():
        # mkstemp() also returns an open OS-level file descriptor; close it
        # right away so it is not leaked. Only the path is needed.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        return path

    # Run DDP training twice, fp and mp.
    states_after = {}
    for precision in ["full", "mixed"]:
        temp_file_name = _new_temp_path()
        unused = _new_temp_path()
        rank_0_output = _new_temp_path()
        try:
            fsdp_config = None  # This means we use DDP in _distributed_worker.
            mp.spawn(
                _distributed_worker,
                args=(
                    world_size,
                    fsdp_config,
                    None,  # fsdp_wrap_bn: unused by DDP.
                    precision == "mixed",
                    temp_file_name,
                    unused,
                    state_before,
                    inputs,
                    rank_0_output,
                    None,  # state_after: unused by DDP.
                ),
                nprocs=world_size,
                join=True,
            )
            # Rank 0 saved its final state_dict here.
            states_after[precision] = torch.load(rank_0_output)
        finally:
            rmf(temp_file_name)
            rmf(unused)
            rmf(rank_0_output)

    state_after_fp = states_after["full"]
    state_after_mp = states_after["mixed"]
    assert state_dict_norm(state_after_fp) != state_dict_norm(state_after_mp)

    return state_before, inputs, state_after_fp, state_after_mp


# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
    """Yield ``(temp_file_name, unused)``: two fresh, empty temp file paths.

    Both files are removed (if still present) after the test.
    """
    paths = []
    for _ in range(2):
        # mkstemp() also returns an open OS-level file descriptor; close it so
        # it is not leaked. Only the path is needed.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        paths.append(path)
    temp_file_name, unused = paths

    yield temp_file_name, unused

    # temp files could have been removed, so we use rmf.
    rmf(temp_file_name)
    rmf(unused)


def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
):
    """Train one rank of either the DDP reference or the FSDP model under test.

    Started via ``mp.spawn``, which supplies ``rank``. If ``fsdp_config`` is
    falsy, the model is trained with DDP and rank 0 saves its final state_dict
    to ``rank_0_output``. Otherwise the model is wrapped in FSDP, with
    ``fsdp_config`` passed as keyword arguments, and its final state_dict is
    asserted equal to ``state_after``.

    Args:
        rank: process rank (injected by ``mp.spawn``).
        world_size: number of processes.
        fsdp_config: ``None`` for DDP, or a dict of FSDP kwargs that must
            contain ``"mixed_precision"``. It is not modified.
        fsdp_wrap_bn: FSDP only; whether to apply ``auto_wrap_bn``.
        ddp_mixed_precision: DDP only; train under autocast with a GradScaler.
        tempfile_name, unused: file names forwarded to ``dist_init``.
        state_before: state_dict loaded into the model before training.
        inputs: per-rank lists of input tensors; ``inputs[rank]`` is used.
        rank_0_output: DDP only; path where rank 0 saves its final state_dict.
        state_after: FSDP only; the expected final state_dict.
    """
    torch.backends.cudnn.deterministic = True

    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        # Work on a copy so the caller's config dict is never mutated.
        fsdp_config = dict(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model()
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        """No-op stand-in for a grad scaler, used in full precision."""

        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        model = SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note, different rank may wrap in different order due to different random
        # seeds. But results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then convert_sync_batchnorm")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            model = _bn_converter(model)
        else:
            print(f"convert_sync_batchnorm, then auto_wrap_bn {fsdp_wrap_bn}")
            model = _bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()
        # Print the model for verification.
        if rank == 0:
            print(model)
    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()
    # Loop-invariant target: every step trains toward class 0.
    fake_label = torch.zeros(1, dtype=torch.long).cuda()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()  # Used as a no-op context manager.
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure final state equals to the state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics. Values are replaced in place
        # so the state_dict's own mapping type is kept for the comparison.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        assert objects_are_equal(state_after, fsdp_state, raise_exception=True)

    teardown()


# We use strings for precision and flatten params instead of bool to
# make the pytest output more readable.
@skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
def test_regnet(temp_files, ddp_ref, precision, flatten):
    """Train the regnet-like model with FSDP and compare it to the DDP reference.

    ``precision`` ("full"/"mixed") selects both the FSDP mixed-precision mode
    and which DDP reference state to compare against. ``flatten``
    ("flatten"/"no_flatten") toggles FSDP's ``flatten_parameters``.
    """
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    state_before, inputs, state_after_fp, state_after_mp = ddp_ref
    state_after = state_after_fp if precision == "full" else state_after_mp

    fsdp_config = {
        "mixed_precision": precision == "mixed",
        "flatten_parameters": flatten == "flatten",
    }

    if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")

    # Wrap BN half of the time in full precision mode; always wrap BN in mixed
    # precision mode. randint() is evaluated first, unconditionally, so the RNG
    # stream is consumed the same way in both modes.
    wrap_bn = random.randint(0, 1) != 0 or fsdp_config["mixed_precision"]

    world_size = _world_size
    mp.spawn(
        _distributed_worker,
        args=(
            world_size,
            fsdp_config,
            wrap_bn,
            None,  # ddp_mixed_precision: unused by FSDP.
            temp_files[0],
            temp_files[1],
            state_before,
            inputs,
            None,  # rank_0_output: unused by FSDP.
            state_after,
        ),
        nprocs=world_size,
        join=True,
    )