test_fsdp_memory.py 14.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with GPU memory usage. """

12
import contextlib
13
14
15
16
17
18
19
20

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim

21
22
23
from fair_dev.testing.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
from fairscale.internal import torch_version
from fairscale.internal.parallel import get_process_group_cached
24
25
26
27
28
from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn


29
def to_fsdp(module, fsdp_config):
    """Wrap ``module`` in ``FSDP`` on the group from ``get_process_group_cached()``.

    Every entry of ``fsdp_config`` is forwarded as a keyword argument to the
    ``FSDP`` constructor.
    """
    process_group = get_process_group_cached()
    return FSDP(module, process_group=process_group, **fsdp_config)
31
32
33
34
35
36
37
38


def get_cur_mem(rank, result, prefix):
    """Store the current ``torch.cuda.memory_allocated()`` reading, in whole MB, at ``result[prefix]``.

    ``rank`` is accepted for call-site symmetry only; it is not used.
    """
    allocated_mb = torch.cuda.memory_allocated() / (1024 * 1024)
    result[prefix] = round(allocated_mb)


class Model(nn.Module):
    """Small conv net: ``stem`` -> ``blocks`` -> 10-way linear ``head``.

    Args:
        hidden_dim: channel width of the three 5x5 conv stages in ``blocks``,
            which is also the input width of ``head``.
    """

    def __init__(self, hidden_dim):
        super().__init__()

        # TODO (Min): for both fast and memory efficient conv kernels, we should be using
        #     AMP/fp16 + channel_last input format. Otherwise, cudnn internally does conversion
        #     to channel_last when it is fp16 weights. Leave this knowledge here and perhaps
        #     future test can cover it.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # Three Conv -> BN -> ReLU stages (the first maps 64 -> hidden_dim channels),
        # then global average pooling flattened to (N, hidden_dim). Modules are created
        # in a fixed order, so seeded initialization and the Sequential indices
        # (blocks.0 ... blocks.10) stay deterministic.
        layers = []
        in_channels = 64
        for _ in range(3):
            layers += [
                nn.Conv2d(in_channels, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
            ]
            in_channels = hidden_dim
        layers += [nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten()]
        self.blocks = nn.Sequential(*layers)

        self.head = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        features = self.stem(x)
        features = self.blocks(features)
        return self.head(features)


65
66
def create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config):
    """Build a ``Model`` and optionally wrap its sub-modules.

    The wrapping order is fixed:
      1. if ``with_fsdp``: ``auto_wrap_bn`` on ``stem`` and ``blocks``;
      2. if ``with_checkpoint``: ``checkpoint_wrapper`` on ``blocks``;
      3. if ``with_fsdp``: ``to_fsdp`` on ``stem``, ``blocks`` and ``head``.

    Args:
        with_fsdp: apply the BN auto-wrapping and per-child FSDP wrapping.
        with_checkpoint: wrap ``blocks`` with ``checkpoint_wrapper``.
        model_hidden_dim: passed to ``Model``.
        fsdp_config: FSDP keyword arguments, forwarded to ``to_fsdp``.

    Returns:
        The (possibly wrapped) ``Model`` instance. Moving it to a device is left to the caller.
    """
    model = Model(model_hidden_dim)

    if with_fsdp:
        model.stem = auto_wrap_bn(model.stem, single_rank_pg=False)
        model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False)

    if with_checkpoint:
        model.blocks = checkpoint_wrapper(model.blocks)

    if not with_fsdp:
        return model

    # Each child gets its own FSDP wrapper, in stem -> blocks -> head order.
    for child_name in ("stem", "blocks", "head"):
        setattr(model, child_name, to_fsdp(getattr(model, child_name), fsdp_config))
    return model


81
82
83
def _distributed_worker(
    gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected, model_hidden_dim, fsdp_config
):
    """Per-process body of ``test_fsdp_memory``, launched via ``mp.spawn``.

    Builds the model (FSDP- or DDP-wrapped) and runs 4 SGD iterations on a fixed random batch.
    At six points per iteration it records ``torch.cuda.memory_allocated()`` (rounded to MB)
    via ``get_cur_mem``. It then asserts that every reading is within 1 MB of ``expected``.

    Args:
        gpu_id: process index passed by ``mp.spawn``; used both as the CUDA device
            and as the distributed rank.
        world_size: number of participating processes.
        with_fsdp: wrap the whole model with ``to_fsdp`` if True, otherwise with
            ``DistributedDataParallel``.
        with_checkpoint: forwarded to ``create_model``.
        filename: first temp-file path handed to ``dist_init``.
        filename_rpc: second temp-file path handed to ``dist_init`` (presumably for RPC
            init — confirm against ``dist_init``).
        expected: dict mapping labels such as ``"iter 0: after fwd"`` to MB values;
            its key set must equal the labels recorded here.
        model_hidden_dim: forwarded to ``create_model``.
        fsdp_config: FSDP keyword arguments; a truthy ``"mixed_precision"`` entry also
            runs forward + loss under ``torch.cuda.amp.autocast``.

    Raises:
        AssertionError: if ``dist_init`` returns a falsy value, the recorded labels differ
            from ``expected``'s keys, or any reading differs by more than 1 MB.
    """
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # Note that FSDP auto-cast the input in AMP mode. So we don't need to call half() here.
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config)
    model = model.cuda()
    if with_fsdp:
        # Outer FSDP wrapper around the (already per-child wrapped) model.
        model = to_fsdp(model, fsdp_config)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

    # We enable momentum so that after the first iteration, the optimizer state is added
    # to the total memory used.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    # Set AMP context if needed.
    # contextlib.suppress() with no exception types suppresses nothing, i.e. it is a
    # no-op context manager when mixed precision is off.
    context = contextlib.suppress()
    if "mixed_precision" in fsdp_config and fsdp_config["mixed_precision"]:
        context = torch.cuda.amp.autocast(enabled=True)

    # We have observed that sometimes after 3rd iteration, 4th one can fail (not on this
    # test but on much bigger scale tests). We run 4 iterations here just in case it happens.
    iterations = 4

    results = {}  # results of memory stats
    for iteration in range(iterations):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        with context:
            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            # Reduce to a scalar using only out[0] (the first row of the output). Rebinding
            # `out` drops this frame's reference to the full output tensor; the memory
            # snapshots below presumably rely on that, so keep the single name.
            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory.
        # Older torch versions lack the `set_to_none` argument, so clear grads manually there.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    dump_all_tensors(gpu_id)
    print(results)

    def cmp(results, expected):
        # Returns one "<label>: got X, expected Y" line per reading that is off by more
        # than 1 MB; an empty string means everything matched.
        ret = ""
        assert results.keys() == expected.keys(), f"{list(results.keys())} vs. {list(expected.keys())}"
        for k, v in results.items():
            exp = expected[k]
            if abs(exp - v) > 1:  # allow 1MB rounding differences
                ret += f"{k}: got {v}, expected {exp}\n"
        return ret

    output = cmp(results, expected)
    assert not output, output

    # NOTE(review): teardown() is only reached when every assertion above passes.
    teardown()


@skip_if_single_gpu
@pytest.mark.timeout(120)
@pytest.mark.parametrize("ckpt", ["no_ckpt", "ckpt"])
@pytest.mark.parametrize("fsdp", ["ddp", "fsdp", "fsdp_amp_default", "fsdp_amp_compute_dtype32"])
def test_fsdp_memory(fsdp, ckpt):
    """Compare per-stage GPU memory (MB) of a 2-process run against recorded baselines.

    ``fsdp`` selects the wrapper / precision variant and ``ckpt`` toggles checkpointing
    of ``blocks``. Each baseline gives six readings for iteration 0 and six for every
    later iteration, in stage order: start, after fwd, after loss, after bwd, after step,
    done.
    """
    stages = ("start", "after fwd", "after loss", "after bwd", "after step", "done")

    def profile(first_iter, later_iters):
        # Expand into {"iter <i>: <stage>": MB} for the 4 iterations the worker runs.
        return {
            f"iter {i}: {stage}": mb
            for i in range(4)
            for stage, mb in zip(stages, first_iter if i == 0 else later_iters)
        }

    baselines = {
        ("ddp", "no_ckpt"): profile((9, 346, 346, 14, 17, 13), (13, 350, 350, 17, 17, 13)),
        ("fsdp", "no_ckpt"): profile((3, 340, 340, 16, 18, 5), (5, 342, 342, 18, 18, 5)),
        ("fsdp_amp_default", "no_ckpt"): profile((28, 630, 630, 67, 93, 54), (54, 657, 657, 93, 93, 54)),
        ("fsdp_amp_compute_dtype32", "no_ckpt"): profile((28, 657, 657, 67, 93, 54), (54, 684, 684, 93, 93, 54)),
        ("ddp", "ckpt"): profile((9, 57, 57, 14, 17, 13), (13, 61, 61, 17, 17, 13)),
        ("fsdp", "ckpt"): profile((3, 51, 51, 16, 18, 5), (5, 53, 53, 18, 18, 5)),
        ("fsdp_amp_default", "ckpt"): profile((28, 52, 52, 67, 93, 54), (54, 79, 79, 93, 93, 54)),
        ("fsdp_amp_compute_dtype32", "ckpt"): profile((28, 52, 52, 67, 93, 54), (54, 79, 79, 93, 93, 54)),
    }
    expected = baselines[(fsdp, ckpt)]

    # FSDP constructor kwargs derived from the variant name.
    fsdp_config = {}
    if "amp" in fsdp:
        fsdp_config["mixed_precision"] = True
    # compute_dtype=FP32 is paired with clear_autocast_cache; fp32_reduce_scatter and
    # verbose are enabled for extra code coverage.
    if "compute_dtype32" in fsdp:
        fsdp_config.update(
            compute_dtype=torch.float32,
            fp32_reduce_scatter=True,
            clear_autocast_cache=True,
            verbose=True,
        )

    # AMP variants use a wider model (hidden_dim 512, ~55MB, vs. 128, ~4MB) so that
    # bugs in parameter handling show up; the base case stays small to keep the test
    # fast. hidden_dim 1024 (~200MB) seemed too big for CI.
    model_hidden_dim = 512 if "amp" in fsdp else 128

    with_fsdp = "fsdp" in fsdp  # true for every variant except "ddp"
    with_ckpt = ckpt == "ckpt"

    world_size = 2
    with temp_files_ctx(num=2) as temp_files:
        init_file, rpc_file = temp_files[0], temp_files[1]
        worker_args = (world_size, with_fsdp, with_ckpt, init_file, rpc_file, expected, model_hidden_dim, fsdp_config)
        mp.spawn(_distributed_worker, worker_args, nprocs=world_size)