Unverified Commit 8b59267b authored by Min Xu, committed by GitHub

[feat] FSDP: add auto_wrap_bn (#531)

* [feat] FSDP: add auto_wrap_bn

- add a utility function to handle wrapping of BN

* changelog
parent 2fc1f6d8
@@ -8,7 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Experimental: Add spectrain support ([#372](https://github.com/facebookresearch/fairscale/issues/372))
-- FSDP: enabling pytorch SyncBN (no asserting) ([#527](https://github.com/facebookresearch/fairscale/issues/527))
+- FSDP: enabled pytorch SyncBN (no asserting) ([#527](https://github.com/facebookresearch/fairscale/issues/527))
+- FSDP: added auto\_wrap\_bn utility function ([#531](https://github.com/facebookresearch/fairscale/pull/531))

 ### Fixed
 - OSS: fix a compatibility problem with lightning wrt optimizer state dict ([#510](https://github.com/facebookresearch/fairscale/issues/510))
...
@@ -3,5 +3,5 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
-from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState
+from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState, auto_wrap_bn
 from .sharded_ddp import ShardedDataParallel
@@ -20,6 +20,7 @@ from torch.nn import Parameter
 import torch.nn.functional as F

 from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
 from fairscale.optim.utils import calc_grad_norm
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
@@ -1337,3 +1338,46 @@ def _pre_load_state_dict_hook(
     state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
 ) -> None:
     replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
+
+
+########################################################################################
+# Below are APIs used together with FSDP, but not directly part of FSDP.
+########################################################################################
+
+
+def auto_wrap_bn(module: nn.Module) -> nn.Module:
+    """
+    Auto wrap all BatchNorm (BN) instances with a safer FSDP, especially when conversion
+    to sync BN is used and the outer FSDP is flattening.
+
+    We put BN in its own full-precision, unflattened, single-GPU-group FSDP. Note, SyncBNs still have
+    a group size == world_size. The input and output for BN are still FP16 in mixed precision mode.
+    See ``keep_batchnorm_fp32`` here: https://nvidia.github.io/apex/amp.html
+
+    This needs to be done at each rank, just as models are wrapped by FSDP at each rank.
+
+    Args:
+        module (nn.Module):
+            The model (or part of the model) whose BN modules are to be pre-wrapped.
+
+    Returns:
+        Processed module, where BNs are wrapped with a special FSDP instance.
+    """
+
+    def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int) -> bool:
+        is_bn = isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
+        if recurse:
+            return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES))  # type: ignore
+        else:
+            return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore
+
+    my_rank = dist.get_rank()
+    fsdp_config = {
+        "wrapper_cls": FullyShardedDataParallel,
+        "process_group": dist.new_group(ranks=[my_rank]),  # No sharding with this single-member group.
+        "mixed_precision": False,  # Keep the weights in FP32.
+        "flatten_parameters": False,  # Do not flatten.
+    }
+    with enable_wrap(wrap_bn_only_policy, **fsdp_config):
+        return auto_wrap(module)
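
A minimal usage sketch of the new utility (illustrative only, not part of the commit): it assumes torch.distributed has already been initialized, uses a throwaway Sequential model, and picks one arbitrary outer FSDP configuration; the point is only the ordering of auto_wrap_bn relative to the outer wrap, which mirrors the test change below.

import torch.nn as nn
from torch.nn import SyncBatchNorm

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn

# Toy model with a BatchNorm layer; any nn.Module containing BN works the same way.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model = auto_wrap_bn(model)  # BN gets its own FP32, unflattened, single-rank FSDP wrap.
model = SyncBatchNorm.convert_sync_batchnorm(model)  # Optional; order w.r.t. auto_wrap_bn should not matter.
model = FSDP(model, mixed_precision=True, flatten_parameters=True).cuda()  # Outer FSDP wrap for the rest.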
@@ -9,17 +9,17 @@
 """ Test FSDP with regnet-like model. """

+import random
 import tempfile

 import pytest
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
 from torch.optim import SGD

 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.nn.data_parallel import TrainingState
+from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
 from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
@@ -35,16 +35,7 @@ def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
             # TODO (Min): for now, we just test pytorch sync_bn here.
             # this will grow into regnet; testing apex sync_bn, etc.
             self.conv = Conv2d(2, 2, (1, 1))
-            # Put BN in its own FP32, unflattened, single GPU group FSDP.
-            # Note, SyncBNs still have a group size == world_size.
-            # The input and output for BN are still FP16. See ``keep_batchnorm_fp32``
-            # here: https://nvidia.github.io/apex/amp.html
-            self.bn = FSDP(
-                BatchNorm2d(2),
-                mixed_precision=False,
-                process_group=dist.new_group(ranks=[rank]),
-                flatten_parameters=False,
-            )
+            self.bn = BatchNorm2d(2)

         def forward(self, x):
             x = self.conv(x)
@@ -54,7 +45,16 @@ def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
     # TODO (Min): check DDP equivalency.
     model = Model()
-    model = SyncBatchNorm.convert_sync_batchnorm(model)
+    # Note, different ranks may wrap in a different order due to different random
+    # seeds, but the results should be the same.
+    if random.randint(0, 1) == 0:
+        print("auto_wrap_bn, then convert_sync_batchnorm")
+        model = auto_wrap_bn(model)
+        model = SyncBatchNorm.convert_sync_batchnorm(model)
+    else:
+        print("convert_sync_batchnorm, then auto_wrap_bn")
+        model = SyncBatchNorm.convert_sync_batchnorm(model)
+        model = auto_wrap_bn(model)
     model = FSDP(model, **fsdp_config).cuda()
     optim = SGD(model.parameters(), lr=0.1)
...