# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with GPU memory usage. """

import contextlib

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel

from fair_dev.testing.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
from fairscale.internal import torch_version
from fairscale.internal.parallel import get_process_group_cached
from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn


def to_fsdp(module, fsdp_config):
    return FSDP(module, process_group=get_process_group_cached(), **fsdp_config)


def get_cur_mem(rank, result, prefix):
    """Collect the currently allocated GPU memory (in MB) into a result dict."""
    result[prefix] = round(torch.cuda.memory_allocated() / 1024 / 1024)


class Model(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()

        # TODO (Min): for both fast and memory efficient conv kernels, we should be using
        #             AMP/fp16 + channels_last input format. Otherwise, cudnn internally converts
        #             to channels_last when the weights are fp16. Leave this knowledge here and
        #             perhaps a future test can cover it.
        self.stem = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True))

        self.blocks = nn.Sequential(
            nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )

        self.head = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        return self.head(self.blocks(self.stem(x)))


def create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config):
    model = Model(model_hidden_dim)
    if with_fsdp:
        model.stem = auto_wrap_bn(model.stem, single_rank_pg=False)
        model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False)
        if with_checkpoint:
            model.blocks = checkpoint_wrapper(model.blocks)
        model.stem = to_fsdp(model.stem, fsdp_config)
        model.blocks = to_fsdp(model.blocks, fsdp_config)
        model.head = to_fsdp(model.head, fsdp_config)
    else:
        if with_checkpoint:
            model.blocks = checkpoint_wrapper(model.blocks)
    return model


def _distributed_worker(
    gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected, model_hidden_dim, fsdp_config
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
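    # The same random batch is reused for every iteration below, so the memory numbers
    # are driven purely by the model and the allocator, not by input variation.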
    # Note that FSDP auto-casts the input in AMP mode, so we don't need to call half() here.
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config)
    model = model.cuda()
    if with_fsdp:
        model = to_fsdp(model, fsdp_config)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

    # We enable momentum so that after the first iteration, the optimizer state is added
    # to the total memory used.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    # Set an AMP context if needed; otherwise fall back to a no-op context manager.
    context = contextlib.suppress()
    if "mixed_precision" in fsdp_config and fsdp_config["mixed_precision"]:
        context = torch.cuda.amp.autocast(enabled=True)

    # We have observed that sometimes the 4th iteration can fail after the 3rd succeeds
    # (not in this test, but in much larger-scale tests). We run 4 iterations here just
    # in case it happens.
    iterations = 4

    results = {}  # results of memory stats
    for iteration in range(iterations):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        with context:
            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            # Collapse the output into a scalar and compute a fake loss against a zero target.
            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad(), to reclaim
        # the gradient memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    dump_all_tensors(gpu_id)
    print(results)
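    # Compare against the expected values with a small tolerance, collecting all the
    # mismatches into one report instead of failing on the first difference.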
{list(expected.keys())}" for k, v in results.items(): exp = expected[k] if abs(exp - v) > 1: # allow 1MB rounding differences ret += f"{k}: got {v}, expected {exp}\n" return ret output = cmp(results, expected) assert not output, output teardown() @skip_if_single_gpu @pytest.mark.timeout(120) @pytest.mark.parametrize("ckpt", ["no_ckpt", "ckpt"]) @pytest.mark.parametrize("fsdp", ["ddp", "fsdp", "fsdp_amp_default", "fsdp_amp_compute_dtype32"]) def test_fsdp_memory(fsdp, ckpt): expected = { ("ddp", "no_ckpt"): { "iter 0: start": 9, "iter 0: after fwd": 346, "iter 0: after loss": 346, "iter 0: after bwd": 14, "iter 0: after step": 17, "iter 0: done": 13, "iter 1: start": 13, "iter 1: after fwd": 350, "iter 1: after loss": 350, "iter 1: after bwd": 17, "iter 1: after step": 17, "iter 1: done": 13, "iter 2: start": 13, "iter 2: after fwd": 350, "iter 2: after loss": 350, "iter 2: after bwd": 17, "iter 2: after step": 17, "iter 2: done": 13, "iter 3: start": 13, "iter 3: after fwd": 350, "iter 3: after loss": 350, "iter 3: after bwd": 17, "iter 3: after step": 17, "iter 3: done": 13, }, ("fsdp", "no_ckpt"): { "iter 0: start": 3, "iter 0: after fwd": 340, "iter 0: after loss": 340, "iter 0: after bwd": 16, "iter 0: after step": 18, "iter 0: done": 5, "iter 1: start": 5, "iter 1: after fwd": 342, "iter 1: after loss": 342, "iter 1: after bwd": 18, "iter 1: after step": 18, "iter 1: done": 5, "iter 2: start": 5, "iter 2: after fwd": 342, "iter 2: after loss": 342, "iter 2: after bwd": 18, "iter 2: after step": 18, "iter 2: done": 5, "iter 3: start": 5, "iter 3: after fwd": 342, "iter 3: after loss": 342, "iter 3: after bwd": 18, "iter 3: after step": 18, "iter 3: done": 5, }, ("fsdp_amp_default", "no_ckpt"): { "iter 0: start": 28, "iter 0: after fwd": 630, "iter 0: after loss": 630, "iter 0: after bwd": 67, "iter 0: after step": 93, "iter 0: done": 54, "iter 1: start": 54, "iter 1: after fwd": 657, "iter 1: after loss": 657, "iter 1: after bwd": 93, "iter 1: after step": 93, "iter 1: done": 54, "iter 2: start": 54, "iter 2: after fwd": 657, "iter 2: after loss": 657, "iter 2: after bwd": 93, "iter 2: after step": 93, "iter 2: done": 54, "iter 3: start": 54, "iter 3: after fwd": 657, "iter 3: after loss": 657, "iter 3: after bwd": 93, "iter 3: after step": 93, "iter 3: done": 54, }, ("fsdp_amp_compute_dtype32", "no_ckpt"): { "iter 0: start": 28, "iter 0: after fwd": 657, "iter 0: after loss": 657, "iter 0: after bwd": 67, "iter 0: after step": 93, "iter 0: done": 54, "iter 1: start": 54, "iter 1: after fwd": 684, "iter 1: after loss": 684, "iter 1: after bwd": 93, "iter 1: after step": 93, "iter 1: done": 54, "iter 2: start": 54, "iter 2: after fwd": 684, "iter 2: after loss": 684, "iter 2: after bwd": 93, "iter 2: after step": 93, "iter 2: done": 54, "iter 3: start": 54, "iter 3: after fwd": 684, "iter 3: after loss": 684, "iter 3: after bwd": 93, "iter 3: after step": 93, "iter 3: done": 54, }, ("ddp", "ckpt"): { "iter 0: start": 9, "iter 0: after fwd": 57, "iter 0: after loss": 57, "iter 0: after bwd": 14, "iter 0: after step": 17, "iter 0: done": 13, "iter 1: start": 13, "iter 1: after fwd": 61, "iter 1: after loss": 61, "iter 1: after bwd": 17, "iter 1: after step": 17, "iter 1: done": 13, "iter 2: start": 13, "iter 2: after fwd": 61, "iter 2: after loss": 61, "iter 2: after bwd": 17, "iter 2: after step": 17, "iter 2: done": 13, "iter 3: start": 13, "iter 3: after fwd": 61, "iter 3: after loss": 61, "iter 3: after bwd": 17, "iter 3: after step": 17, "iter 3: done": 13, }, ("fsdp", "ckpt"): { 
"iter 0: start": 3, "iter 0: after fwd": 51, "iter 0: after loss": 51, "iter 0: after bwd": 16, "iter 0: after step": 18, "iter 0: done": 5, "iter 1: start": 5, "iter 1: after fwd": 53, "iter 1: after loss": 53, "iter 1: after bwd": 18, "iter 1: after step": 18, "iter 1: done": 5, "iter 2: start": 5, "iter 2: after fwd": 53, "iter 2: after loss": 53, "iter 2: after bwd": 18, "iter 2: after step": 18, "iter 2: done": 5, "iter 3: start": 5, "iter 3: after fwd": 53, "iter 3: after loss": 53, "iter 3: after bwd": 18, "iter 3: after step": 18, "iter 3: done": 5, }, ("fsdp_amp_default", "ckpt"): { "iter 0: start": 28, "iter 0: after fwd": 52, "iter 0: after loss": 52, "iter 0: after bwd": 67, "iter 0: after step": 93, "iter 0: done": 54, "iter 1: start": 54, "iter 1: after fwd": 79, "iter 1: after loss": 79, "iter 1: after bwd": 93, "iter 1: after step": 93, "iter 1: done": 54, "iter 2: start": 54, "iter 2: after fwd": 79, "iter 2: after loss": 79, "iter 2: after bwd": 93, "iter 2: after step": 93, "iter 2: done": 54, "iter 3: start": 54, "iter 3: after fwd": 79, "iter 3: after loss": 79, "iter 3: after bwd": 93, "iter 3: after step": 93, "iter 3: done": 54, }, ("fsdp_amp_compute_dtype32", "ckpt"): { "iter 0: start": 28, "iter 0: after fwd": 52, "iter 0: after loss": 52, "iter 0: after bwd": 67, "iter 0: after step": 93, "iter 0: done": 54, "iter 1: start": 54, "iter 1: after fwd": 79, "iter 1: after loss": 79, "iter 1: after bwd": 93, "iter 1: after step": 93, "iter 1: done": 54, "iter 2: start": 54, "iter 2: after fwd": 79, "iter 2: after loss": 79, "iter 2: after bwd": 93, "iter 2: after step": 93, "iter 2: done": 54, "iter 3: start": 54, "iter 3: after fwd": 79, "iter 3: after loss": 79, "iter 3: after bwd": 93, "iter 3: after step": 93, "iter 3: done": 54, }, }[(fsdp, ckpt)] # Compute the FSDP config. fsdp_config = {} # Set mixed precision. if "amp" in fsdp: fsdp_config["mixed_precision"] = True # When compute_dtype is FP32, make sure we use clear_autocast_cache. # Setting fp32_reduce_scatter and verbose for more code coverage. if "compute_dtype32" in fsdp: fsdp_config["compute_dtype"] = torch.float32 fsdp_config["fp32_reduce_scatter"] = True fsdp_config["clear_autocast_cache"] = True fsdp_config["verbose"] = True # Using bigger hidden dimension for AMP to increase the model size # so that bug in handling params will show up but we don't do that # in the base case to keep the test fast. # - hidden_dim 128: model size ~4MB # - hidden_dim 512: model size ~55MB # - hidden_dim 1024: model size ~200MB (seems to be too big for CI tests though) model_hidden_dim = 128 if "amp" in fsdp: model_hidden_dim = 512 # Get the fsdp and checkpoint flags. with_fsdp = "fsdp" in fsdp with_ckpt = ckpt == "ckpt" world_size = 2 with temp_files_ctx(num=2) as temp_files: mp.spawn( _distributed_worker, (world_size, with_fsdp, with_ckpt, temp_files[0], temp_files[1], expected, model_hidden_dim, fsdp_config), nprocs=world_size, )