Unverified Commit 2fc1f6d8 authored by Min Xu, committed by GitHub

[feature] FSDP: enable pytorch SyncBN (#527)

* [feature] FSDP: enable pytorch SyncBN

- not fully validated yet, but it no longer asserts
- this enables VISSL to move forward with its next PR

* add the test file

* changelog and lint

* addressed comment
parent 142cfdcc
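
As context for this change, here is a minimal, illustrative sketch (not part of the diff) of the usage pattern it enables, mirroring the test added below; the toy model is hypothetical:

```python
# Illustrative only: wrap a model containing pytorch SyncBatchNorm with FSDP.
# Assumes torch.distributed is already initialized and each rank has a CUDA device.
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # BatchNorm2d -> SyncBatchNorm
model = FSDP(model).cuda()  # with this change, the SyncBN layers no longer assert
```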
@@ -8,10 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Experimental: Add spectrain support ([#372](https://github.com/facebookresearch/fairscale/issues/372))
+- FSDP: enable pytorch SyncBN (no longer asserts) ([#527](https://github.com/facebookresearch/fairscale/issues/527))

 ### Fixed
 - OSS: fix a compatibility problem with lightning wrt optimizer state dict ([#510](https://github.com/facebookresearch/fairscale/issues/510))
 - FSDP: fixed a bug when part of the autograd graph is traversed multiple times in mixed precision mode ([#513](https://github.com/facebookresearch/fairscale/pull/513))

 ## [0.3.1] - 2021-03-09
 ### Added
...
@@ -22,7 +22,7 @@ import torch.nn.functional as F
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.optim.utils import calc_grad_norm
 from fairscale.utils.containers import apply_to_tensors
-from fairscale.utils.parallel import chunk_and_pad, validate_process_group
+from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
@@ -186,6 +186,7 @@ class FullyShardedDataParallel(nn.Module):
         compute_device = torch.device("cuda") if self.cpu_offload else next(module.parameters()).device
         validate_process_group(compute_device, self.process_group)
+        enable_pytorch_sync_bn(module)

         # Only handle params which are not already sharded. This enables
         # sharding individual layers of a Module, with an outer wrapper to
...
@@ -43,3 +43,15 @@ def validate_process_group(device: torch.device, process_group: ProcessGroup) ->
             f"found {torch.cat(output).sum()} devices in process group but "
             f"world_size={world_size}. Check torch.cuda.set_device is called properly"
         )
+
+
+def enable_pytorch_sync_bn(module: torch.nn.Module) -> None:
+    """Call _specify_ddp_gpu_num for all pytorch SyncBN layers so that they
+    run correctly even without DDP (e.g. when wrapped by FSDP).
+    """
+    for layer in module.modules():
+        if isinstance(layer, torch.nn.modules.SyncBatchNorm):
+            # The number "1" below is meant to be the number of GPUs for each DDP worker
+            # (i.e. "device_ids" in DDP). As far as we can tell, the value is not actually
+            # used, but the call needs to be made to avoid an exception.
+            layer._specify_ddp_gpu_num(1)  # type: ignore
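
For background, a minimal sketch (not part of the diff) of the PyTorch behavior this helper works around; it relies on the private `_specify_ddp_gpu_num` method, which exists on the PyTorch versions contemporary with this change:

```python
# Sketch only: why enable_pytorch_sync_bn calls _specify_ddp_gpu_num.
# On the targeted PyTorch versions, SyncBatchNorm.forward() refuses to run in training
# mode unless DDP has claimed the layer, raising an error saying that SyncBatchNorm is
# only supported within torch.nn.parallel.DistributedDataParallel.
import torch

bn = torch.nn.SyncBatchNorm(4)
bn._specify_ddp_gpu_num(1)  # mark the layer as DDP would; FSDP now does this automatically
```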
@@ -4,6 +4,7 @@ tests/nn/data_parallel/test_fsdp_no_sync.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/data_parallel/test_fsdp_input.py
 tests/nn/data_parallel/test_fsdp_multiple_forward.py
+tests/nn/data_parallel/test_fsdp_regnet.py
 tests/nn/data_parallel/test_sharded_ddp_features.py
 tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
 tests/nn/pipe/skip/test_gpipe.py
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with regnet-like model. """

import tempfile

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
from torch.optim import SGD

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version


def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            # TODO (Min): for now, we just test pytorch sync_bn here.
            #             this will grow into regnet; testing apex sync_bn, etc.
            self.conv = Conv2d(2, 2, (1, 1))
            # Put BN in its own FP32, unflattened, single-GPU-group FSDP wrapper.
            # Note, SyncBNs still have a group size == world_size.
            # The input and output for BN are still FP16. See ``keep_batchnorm_fp32``
            # here: https://nvidia.github.io/apex/amp.html
            self.bn = FSDP(
                BatchNorm2d(2),
                mixed_precision=False,
                process_group=dist.new_group(ranks=[rank]),
                flatten_parameters=False,
            )

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    # TODO (Min): check DDP equivalency.

    model = Model()
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    model = FSDP(model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(2, 2, 2, 2).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()


# We use strings for precision and flatten instead of bool to
# make the pytest output more readable.
@skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
def test1(precision, flatten):
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    fsdp_config = {}
    fsdp_config["mixed_precision"] = precision == "mixed"
    fsdp_config["flatten_parameters"] = flatten == "flatten"

    # Some bugs only show up when we are in world_size > 1 due to sharding changing
    # the tensor dimensions.
    world_size = 2
    mp.spawn(
        _test_func, args=(world_size, fsdp_config, temp_file_name, unused), nprocs=world_size, join=True,
    )
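
Note (not part of the commit): the new test would presumably be run with the repository's usual pytest workflow, e.g. `python -m pytest tests/nn/data_parallel/test_fsdp_regnet.py`, and it is skipped on hosts with fewer than two GPUs because of the `@skip_if_single_gpu` decorator and `world_size = 2`.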