test_fsdp_regnet.py 13.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with regnet-like model. """

12
import contextlib
13
from itertools import product
Min Xu's avatar
Min Xu committed
14
import random
15
16
17
18
import tempfile

import pytest
import torch
19
from torch.cuda.amp import GradScaler
20
import torch.multiprocessing as mp
21
22
23
24
25
26
27
28
29
30
31
32
from torch.nn import (
    AdaptiveAvgPool2d,
    BatchNorm2d,
    Conv2d,
    CrossEntropyLoss,
    Linear,
    Module,
    ReLU,
    Sequential,
    Sigmoid,
    SyncBatchNorm,
)
33
from torch.nn.parallel import DistributedDataParallel as DDP
34
35
36
from torch.optim import SGD

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
Min Xu's avatar
Min Xu committed
37
from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
38
from fairscale.optim.grad_scaler import ShardedGradScaler
39
from fairscale.utils import torch_version
40
41
42
43
44
45
46
from fairscale.utils.testing import (
    dist_init,
    objects_are_equal,
    rmf,
    skip_if_single_gpu,
    state_dict_norm,
    teardown,
47
    torch_cuda_version,
48
49
)

50
51
52
53
54
55
# Const test params.
#   Reduce iterations to 1 for debugging.
#   Change world_size to 8 on beefy machines for better test coverage.
_world_size = 2
_iterations = 5

56
57
# Cover different ReLU flavors. Different workers may have different values since
# this is a file level global. This is intensional to cover different behaviors.
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
_relu_inplace = True
if random.randint(0, 1) == 0:
    _relu_inplace = False

# TODO (Min): test apex BN when available in the future.
try:
    import apex

    apex_bn_converter = apex.parallel.convert_syncbn_model
except ImportError:
    apex_bn_converter = None
pytorch_bn_converter = SyncBatchNorm.convert_sync_batchnorm  # type: ignore
_single_rank_pg = False


class ResBlock(Module):
    """Conv block in regnet with residual connection."""

    def __init__(self, width_in, width_out):
        super().__init__()
        # Strided 1x1 projection on the shortcut path, normalized before the add.
        self.proj = Conv2d(width_in, width_out, (1, 1), (2, 2), bias=False)
        self.bn = BatchNorm2d(width_out)
        block_a = Sequential(
            Conv2d(width_in, width_out, (1, 1), (1, 1), bias=False),
            BatchNorm2d(width_out),
            ReLU(_relu_inplace),
        )
        block_b = Sequential(
            Conv2d(width_out, width_out, (3, 3), (2, 2), (1, 1), groups=2, bias=False),
            BatchNorm2d(width_out),
            ReLU(_relu_inplace),
        )
        block_se = Sequential(
            AdaptiveAvgPool2d((1, 1)),
            Sequential(
                Conv2d(width_out, 2, (1, 1), (1, 1), bias=False),
                ReLU(_relu_inplace),
                Conv2d(2, width_out, (1, 1), (1, 1), bias=False),
                Sigmoid(),
            ),
        )
        block_c = Conv2d(width_out, width_out, (1, 1), (1, 1), bias=False)
        final_bn = BatchNorm2d(width_out)
        # Same nesting as before so state_dict keys are unchanged.
        self.f = Sequential(block_a, block_b, block_se, block_c, final_bn)
        self.relu = ReLU()
        # Flag read later to decide FSDP wrapping for this block.
        self.need_fsdp_wrap = True

    def forward(self, x):
        # Residual add of projected shortcut and main branch, then final ReLU.
        return self.relu(self.bn(self.proj(x)) + self.f(x))

108
109

class Model(Module):
    """SSL model with trunk and head."""

    def __init__(self, conv_bias, linear_bias):
        super().__init__()
        print(f"relu inplace: {_relu_inplace}, conv bias: {conv_bias}, linear bias: {linear_bias}")
        self.trunk = Sequential()
        self.trunk.need_fsdp_wrap = True  # Set a flag for later wrapping.
        stem = Sequential(
            Conv2d(2, 4, (3, 3), (2, 2), (1, 1), bias=conv_bias),
            BatchNorm2d(4),
            ReLU(_relu_inplace),
        )
        res_block = ResBlock(4, 8)
        self.trunk.add_module("stem", stem)
        self.trunk.add_module("any_stage_block1", Sequential(res_block))
        self.head = Sequential(
            Sequential(Linear(16, 16, bias=linear_bias), ReLU(), Linear(16, 8, bias=linear_bias)),  # projection_head
            Linear(8, 15, bias=False),  # prototypes0
        )

    def forward(self, x):
        # Flatten the trunk output before the (linear) head.
        return self.head(self.trunk(x).reshape(-1))


# We get a bit fancy here. Since the scope is `module`, this is run only
# once no matter how many tests variations for FSDP are requested to run
# to compare with the DDP reference. For example, a single DDP
# reference run is needed for both flatten and non-flatten param FSDP.
#
# Note, this runs DDP twice with and without mixed precision and asserts
# the resulting weights are different.
#
# This fixture captures and returns:
#
#   - model state_dict before training
#   - model data inputs
#   - model state_dict after training
@pytest.fixture(scope="module")
def ddp_ref():
    # Cover different bias flavors. Use random instead of parameterize them to reduce
    # the test runtime. Otherwise, we would have covered all cases exhaustively.
    conv_bias = random.randint(0, 1) == 1
    linear_bias = random.randint(0, 1) == 1

    # Get a reference model state.
    model = Model(conv_bias, linear_bias)
    state_before = model.state_dict()

    # Get reference inputs per rank.
    world_size = _world_size
    iterations = _iterations
    print(f"Getting DDP reference for world_size {world_size} and iterations {iterations}")
    inputs = [[torch.rand(2, 2, 2, 2) for _ in range(iterations)] for _ in range(world_size)]

    # Run reference DDP training 4 times, fp and mp, sync_bn or not.
    state_after = {}
    for precision, sync_bn in product(["full", "mixed"], ["none", "pytorch"]):
        temp_file_name = tempfile.mkstemp()[1]
        unused = tempfile.mkstemp()[1]
        rank_0_output = tempfile.mkstemp()[1]
        try:
            fsdp_config = None  # This means we use DDP in _distributed_worker.
            mp.spawn(
                _distributed_worker,
                args=(
                    world_size,
                    fsdp_config,
                    None,
                    precision == "mixed",
                    temp_file_name,
                    unused,
                    state_before,
                    inputs,
                    rank_0_output,
                    None,
                    sync_bn,
                    conv_bias,
                    linear_bias,
                ),
                nprocs=world_size,
                join=True,
            )
            state_after[(precision, sync_bn)] = torch.load(rank_0_output)
        finally:
            rmf(temp_file_name)
            rmf(unused)
            rmf(rank_0_output)

    # Sanity check DDP's final states: different precision/sync_bn combos must
    # not converge to identical weights.
    states = list(state_after.values())
    for state in states[1:]:
        assert state_dict_norm(states[0]) != state_dict_norm(state)

    return state_before, inputs, conv_bias, linear_bias, state_after


# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
    """Yield a pair of temp file paths; remove them after the test."""
    file_a = tempfile.mkstemp()[1]
    file_b = tempfile.mkstemp()[1]

    yield file_a, file_b

    # temp files could have been removed, so we use rmf.
    rmf(file_a)
    rmf(file_b)


Min Xu's avatar
Min Xu committed
226
def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
    sync_bn,
    conv_bias,
    linear_bias,
):
    """Per-rank worker used both for the DDP reference run and the FSDP run.

    When ``fsdp_config`` is None, trains with DDP and (on rank 0) saves the final
    state_dict to ``rank_0_output``. When ``fsdp_config`` is a dict, trains with
    FSDP and asserts the final state matches the supplied DDP ``state_after``.
    """
    torch.backends.cudnn.deterministic = True

    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    # Start every rank from the same reference weights.
    model = Model(conv_bias, linear_bias)
    model.load_state_dict(state_before)
    model = model.cuda()

    # No-op scaler used for the full-precision paths so the training loop below
    # can be written once for all precision modes.
    class DummyScaler:
        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        if sync_bn == "pytorch":
            model = pytorch_bn_converter(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note, different rank may wrap in different order due to different random
        # seeds. But results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
        else:
            print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()
        # Print the model for verification.
        if rank == 0:
            print(model)
    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        # Pick the right autocast context (or a no-op) for this precision mode.
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure final state equals to the state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        # If sync_bn is used, all ranks should have the same state, so we can compare with
        # rank 0 state on every rank. Otherwise, only compare rank 0 with rank 0.
        if sync_bn != "none" or rank == 0:
            assert objects_are_equal(state_after, fsdp_state, raise_exception=True)

    teardown()


346
# We use strings for precision and flatten params instead of bool to
# make the pytest output more readable.
@skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
@pytest.mark.parametrize("sync_bn", ["none", "pytorch"])
def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
    """Train the regnet-like model under FSDP and compare against the DDP reference."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    state_before, inputs, conv_bias, linear_bias, state_after = ddp_ref
    state_after = state_after[(precision, sync_bn)]

    fsdp_config = {
        "mixed_precision": precision == "mixed",
        "flatten_parameters": flatten == "flatten",
    }

    # When linear bias is True, DDP's AMP O1 and FSDP's default AMP O1.5 is different,
    # we force FSDP to use AMP O1 here by setting compute_dtype to float32.
    if linear_bias:
        fsdp_config["compute_dtype"] = torch.float32

    if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")

    # Wrap BN half of the time. Except, always wrap BN in mixed precision +
    # sync_bn mode, due to error of sync_bn wrapping, regardless of compute_dtype.
    wrap_bn = random.randint(0, 1) == 1
    if fsdp_config["mixed_precision"] and sync_bn != "none":
        wrap_bn = True

    # When BN is not wrapped (i.e. not in full precision), FSDP's compute_dtype needs to
    # be fp32 to match DDP (otherwise, numerical errors happen on BN's running_mean/running_var
    # buffers).
    if fsdp_config["mixed_precision"] and not wrap_bn:
        fsdp_config["compute_dtype"] = torch.float32

    world_size = _world_size
    mp.spawn(
        _distributed_worker,
        args=(
            world_size,
            fsdp_config,
            wrap_bn,
            None,
            temp_files[0],
            temp_files[1],
            state_before,
            inputs,
            None,
            state_after,
            sync_bn,
            conv_bias,
            linear_bias,
        ),
        nprocs=world_size,
        join=True,
    )