# test_fsdp_memory.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with GPU memory usage. """

12
import contextlib
13
14
15
16
17
18
19
20

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim

21
from fairscale.fair_dev.testing.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
22
23
from fairscale.internal import torch_version
from fairscale.internal.parallel import get_process_group_cached
24
25
26
27
28
from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn


29
def to_fsdp(module, fsdp_config):
    """Wrap ``module`` in FSDP.

    Args:
        module: the module to wrap.
        fsdp_config: dict of extra keyword arguments forwarded to the FSDP constructor.

    Returns:
        The FSDP instance wrapping ``module``, built with the process group returned by
        ``get_process_group_cached()``.
    """
    return FSDP(module, process_group=get_process_group_cached(), **fsdp_config)
31
32
33
34
35
36
37
38


def get_cur_mem(rank, result, prefix):
    """Store the currently allocated CUDA memory, rounded to whole MB, in ``result[prefix]``.

    ``rank`` is accepted for call-site symmetry but is not used by this function.
    """
    allocated_bytes = torch.cuda.memory_allocated()
    allocated_mb = allocated_bytes / 1024 / 1024
    result[prefix] = round(allocated_mb)


class Model(nn.Module):
    """Small conv net used by the memory test.

    Structure: a conv/BN/ReLU ``stem``, three conv/BN/ReLU stages followed by
    global average pooling and flatten in ``blocks``, and a linear ``head``
    producing 10 outputs.

    Args:
        hidden_dim: number of channels used by the conv layers in ``blocks`` and
            the input width of ``head``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # TODO (Min): for both fast and memory efficient conv kernels, we should be using
        #     AMP/fp16 + channel_last input format. Otherwise, cudnn internally does conversion
        #     to channel_last when it is fp16 weights. Leave this knowledge here and perhaps
        #     future test can cover it.
        self.stem = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.blocks = nn.Sequential(
            nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.head = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        """Run ``x`` through stem, blocks and head, in that order."""
        return self.head(self.blocks(self.stem(x)))


65
66
def create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config):
    """Build a ``Model`` and optionally wrap its children for FSDP and/or checkpointing.

    Args:
        with_fsdp: if true, apply ``auto_wrap_bn`` to ``stem`` and ``blocks`` and wrap
            ``stem``, ``blocks`` and ``head`` individually with ``to_fsdp``.
        with_checkpoint: if true, wrap ``blocks`` with ``checkpoint_wrapper``.
        model_hidden_dim: hidden dimension passed to ``Model``.
        fsdp_config: dict of FSDP keyword arguments forwarded to ``to_fsdp``.

    Returns:
        The (possibly wrapped) model. The root module itself is not wrapped here.
    """
    model = Model(model_hidden_dim)
    if with_fsdp:
        model.stem = auto_wrap_bn(model.stem, single_rank_pg=False)
        model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False)
        # The checkpoint wrapper is applied before to_fsdp, so FSDP ends up as the
        # outermost wrapper around ``blocks``.
        if with_checkpoint:
            model.blocks = checkpoint_wrapper(model.blocks)
        model.stem = to_fsdp(model.stem, fsdp_config)
        model.blocks = to_fsdp(model.blocks, fsdp_config)
        model.head = to_fsdp(model.head, fsdp_config)
    elif with_checkpoint:
        model.blocks = checkpoint_wrapper(model.blocks)
    return model


81
82
83
def _distributed_worker(
    gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected, model_hidden_dim, fsdp_config
):
    """Per-process worker: train for a few iterations and check allocated CUDA memory.

    Args:
        gpu_id: index of the CUDA device to use; also used as the distributed rank.
        world_size: number of participating processes.
        with_fsdp: wrap the model with FSDP if true, otherwise with DistributedDataParallel.
        with_checkpoint: forwarded to ``create_model``.
        filename: first file name forwarded to ``dist_init``.
        filename_rpc: second file name forwarded to ``dist_init``.
        expected: dict mapping checkpoint labels (e.g. ``"iter 0: after fwd"``) to the
            expected allocated memory in MB.
        model_hidden_dim: hidden dimension forwarded to ``create_model``.
        fsdp_config: dict of FSDP keyword arguments.

    Raises:
        AssertionError: if distributed init fails, if the recorded labels differ from the
            expected ones, or if any recorded value is more than 1MB off its expected value.
    """
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # Note that FSDP auto-casts the input in AMP mode, so we don't need to call half() here.
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config)
    model = model.cuda()
    if with_fsdp:
        model = to_fsdp(model, fsdp_config)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

    # We enable momentum so that after the first iteration, the optimizer state is added
    # to the total memory used.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    # Set AMP context if needed; otherwise use a do-nothing context manager.
    if fsdp_config.get("mixed_precision"):
        context = torch.cuda.amp.autocast(enabled=True)
    else:
        context = contextlib.nullcontext()

    # We have observed that sometimes after 3rd iteration, 4th one can fail (not on this
    # test but on much bigger scale tests). We run 4 iterations here just in case it happens.
    iterations = 4

    results = {}  # results of memory stats
    for iteration in range(iterations):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        # Only the forward pass and the loss computation run under the AMP context.
        with context:
            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    dump_all_tensors(gpu_id)
    print(results)

    def cmp(results, expected):
        """Return one report line per mismatching label, or "" when everything matches."""
        assert results.keys() == expected.keys(), f"{list(results.keys())} vs. {list(expected.keys())}"
        mismatches = [
            f"{k}: got {v}, expected {expected[k]}\n"
            for k, v in results.items()
            if abs(expected[k] - v) > 1  # allow 1MB rounding differences
        ]
        return "".join(mismatches)

    output = cmp(results, expected)
    assert not output, output

    teardown()


@skip_if_single_gpu
@pytest.mark.timeout(120)
@pytest.mark.parametrize("ckpt", ["no_ckpt", "ckpt"])
@pytest.mark.parametrize("fsdp", ["ddp", "fsdp", "fsdp_amp_default", "fsdp_amp_compute_dtype32"])
@pytest.mark.skipif(
    torch_version() >= (1, 14, 0),
    reason="Tests broke in Pytorch pre-release version 1.14",
)
def test_fsdp_memory(fsdp, ckpt):
    """Check allocated GPU memory at fixed points of training for each wrapper/checkpoint combo.

    Args:
        fsdp: one of ``"ddp"``, ``"fsdp"``, ``"fsdp_amp_default"``, ``"fsdp_amp_compute_dtype32"``;
            selects the data-parallel wrapper and the FSDP configuration.
        ckpt: ``"ckpt"`` to wrap the model blocks with the checkpoint wrapper, ``"no_ckpt"`` otherwise.

    The expected values below are in MB and keyed by the labels recorded in ``_distributed_worker``.
    """
    expected = {
        ("ddp", "no_ckpt"): {
            "iter 0: start": 9,
            "iter 0: after fwd": 346,
            "iter 0: after loss": 346,
            "iter 0: after bwd": 14,
            "iter 0: after step": 17,
            "iter 0: done": 13,
            "iter 1: start": 13,
            "iter 1: after fwd": 350,
            "iter 1: after loss": 350,
            "iter 1: after bwd": 17,
            "iter 1: after step": 17,
            "iter 1: done": 13,
            "iter 2: start": 13,
            "iter 2: after fwd": 350,
            "iter 2: after loss": 350,
            "iter 2: after bwd": 17,
            "iter 2: after step": 17,
            "iter 2: done": 13,
            "iter 3: start": 13,
            "iter 3: after fwd": 350,
            "iter 3: after loss": 350,
            "iter 3: after bwd": 17,
            "iter 3: after step": 17,
            "iter 3: done": 13,
        },
        ("fsdp", "no_ckpt"): {
            "iter 0: start": 3,
            "iter 0: after fwd": 340,
            "iter 0: after loss": 340,
            "iter 0: after bwd": 16,
            "iter 0: after step": 18,
            "iter 0: done": 5,
            "iter 1: start": 5,
            "iter 1: after fwd": 342,
            "iter 1: after loss": 342,
            "iter 1: after bwd": 18,
            "iter 1: after step": 18,
            "iter 1: done": 5,
            "iter 2: start": 5,
            "iter 2: after fwd": 342,
            "iter 2: after loss": 342,
            "iter 2: after bwd": 18,
            "iter 2: after step": 18,
            "iter 2: done": 5,
            "iter 3: start": 5,
            "iter 3: after fwd": 342,
            "iter 3: after loss": 342,
            "iter 3: after bwd": 18,
            "iter 3: after step": 18,
            "iter 3: done": 5,
        },
        ("fsdp_amp_default", "no_ckpt"): {
            "iter 0: start": 28,
            "iter 0: after fwd": 630,
            "iter 0: after loss": 630,
            "iter 0: after bwd": 67,
            "iter 0: after step": 93,
            "iter 0: done": 54,
            "iter 1: start": 54,
            "iter 1: after fwd": 657,
            "iter 1: after loss": 657,
            "iter 1: after bwd": 93,
            "iter 1: after step": 93,
            "iter 1: done": 54,
            "iter 2: start": 54,
            "iter 2: after fwd": 657,
            "iter 2: after loss": 657,
            "iter 2: after bwd": 93,
            "iter 2: after step": 93,
            "iter 2: done": 54,
            "iter 3: start": 54,
            "iter 3: after fwd": 657,
            "iter 3: after loss": 657,
            "iter 3: after bwd": 93,
            "iter 3: after step": 93,
            "iter 3: done": 54,
        },
        ("fsdp_amp_compute_dtype32", "no_ckpt"): {
            "iter 0: start": 28,
            "iter 0: after fwd": 657,
            "iter 0: after loss": 657,
            "iter 0: after bwd": 67,
            "iter 0: after step": 93,
            "iter 0: done": 54,
            "iter 1: start": 54,
            "iter 1: after fwd": 684,
            "iter 1: after loss": 684,
            "iter 1: after bwd": 93,
            "iter 1: after step": 93,
            "iter 1: done": 54,
            "iter 2: start": 54,
            "iter 2: after fwd": 684,
            "iter 2: after loss": 684,
            "iter 2: after bwd": 93,
            "iter 2: after step": 93,
            "iter 2: done": 54,
            "iter 3: start": 54,
            "iter 3: after fwd": 684,
            "iter 3: after loss": 684,
            "iter 3: after bwd": 93,
            "iter 3: after step": 93,
            "iter 3: done": 54,
        },
        ("ddp", "ckpt"): {
            "iter 0: start": 9,
            "iter 0: after fwd": 57,
            "iter 0: after loss": 57,
            "iter 0: after bwd": 14,
            "iter 0: after step": 17,
            "iter 0: done": 13,
            "iter 1: start": 13,
            "iter 1: after fwd": 61,
            "iter 1: after loss": 61,
            "iter 1: after bwd": 17,
            "iter 1: after step": 17,
            "iter 1: done": 13,
            "iter 2: start": 13,
            "iter 2: after fwd": 61,
            "iter 2: after loss": 61,
            "iter 2: after bwd": 17,
            "iter 2: after step": 17,
            "iter 2: done": 13,
            "iter 3: start": 13,
            "iter 3: after fwd": 61,
            "iter 3: after loss": 61,
            "iter 3: after bwd": 17,
            "iter 3: after step": 17,
            "iter 3: done": 13,
        },
        ("fsdp", "ckpt"): {
            "iter 0: start": 3,
            "iter 0: after fwd": 51,
            "iter 0: after loss": 51,
            "iter 0: after bwd": 16,
            "iter 0: after step": 18,
            "iter 0: done": 5,
            "iter 1: start": 5,
            "iter 1: after fwd": 53,
            "iter 1: after loss": 53,
            "iter 1: after bwd": 18,
            "iter 1: after step": 18,
            "iter 1: done": 5,
            "iter 2: start": 5,
            "iter 2: after fwd": 53,
            "iter 2: after loss": 53,
            "iter 2: after bwd": 18,
            "iter 2: after step": 18,
            "iter 2: done": 5,
            "iter 3: start": 5,
            "iter 3: after fwd": 53,
            "iter 3: after loss": 53,
            "iter 3: after bwd": 18,
            "iter 3: after step": 18,
            "iter 3: done": 5,
        },
        ("fsdp_amp_default", "ckpt"): {
            "iter 0: start": 28,
            "iter 0: after fwd": 52,
            "iter 0: after loss": 52,
            "iter 0: after bwd": 67,
            "iter 0: after step": 93,
            "iter 0: done": 54,
            "iter 1: start": 54,
            "iter 1: after fwd": 79,
            "iter 1: after loss": 79,
            "iter 1: after bwd": 93,
            "iter 1: after step": 93,
            "iter 1: done": 54,
            "iter 2: start": 54,
            "iter 2: after fwd": 79,
            "iter 2: after loss": 79,
            "iter 2: after bwd": 93,
            "iter 2: after step": 93,
            "iter 2: done": 54,
            "iter 3: start": 54,
            "iter 3: after fwd": 79,
            "iter 3: after loss": 79,
            "iter 3: after bwd": 93,
            "iter 3: after step": 93,
            "iter 3: done": 54,
        },
        ("fsdp_amp_compute_dtype32", "ckpt"): {
            "iter 0: start": 28,
            "iter 0: after fwd": 52,
            "iter 0: after loss": 52,
            "iter 0: after bwd": 67,
            "iter 0: after step": 93,
            "iter 0: done": 54,
            "iter 1: start": 54,
            "iter 1: after fwd": 79,
            "iter 1: after loss": 79,
            "iter 1: after bwd": 93,
            "iter 1: after step": 93,
            "iter 1: done": 54,
            "iter 2: start": 54,
            "iter 2: after fwd": 79,
            "iter 2: after loss": 79,
            "iter 2: after bwd": 93,
            "iter 2: after step": 93,
            "iter 2: done": 54,
            "iter 3: start": 54,
            "iter 3: after fwd": 79,
            "iter 3: after loss": 79,
            "iter 3: after bwd": 93,
            "iter 3: after step": 93,
            "iter 3: done": 54,
        },
    }[(fsdp, ckpt)]

    # Compute the FSDP config.
    fsdp_config = {}

    # Set mixed precision.
    if "amp" in fsdp:
        fsdp_config["mixed_precision"] = True

    # When compute_dtype is FP32, make sure we use clear_autocast_cache.
    # Setting fp32_reduce_scatter and verbose for more code coverage.
    if "compute_dtype32" in fsdp:
        fsdp_config["compute_dtype"] = torch.float32
        fsdp_config["fp32_reduce_scatter"] = True
        fsdp_config["clear_autocast_cache"] = True
        fsdp_config["verbose"] = True

    # Use a bigger hidden dimension for AMP to increase the model size
    # so that bugs in handling params will show up, but don't do that
    # in the base case, to keep the test fast.
    #   - hidden_dim 128: model size ~4MB
    #   - hidden_dim 512: model size ~55MB
    #   - hidden_dim 1024: model size ~200MB (seems to be too big for CI tests though)
    model_hidden_dim = 128
    if "amp" in fsdp:
        model_hidden_dim = 512

    # Get the fsdp and checkpoint flags.
    with_fsdp = "fsdp" in fsdp
    with_ckpt = ckpt == "ckpt"

    world_size = 2
    with temp_files_ctx(num=2) as temp_files:
        # mp.spawn passes the process index as the first argument, which becomes ``gpu_id``.
        mp.spawn(
            _distributed_worker,
            (world_size, with_fsdp, with_ckpt, temp_files[0], temp_files[1], expected, model_hidden_dim, fsdp_config),
            nprocs=world_size,
        )