Unverified Commit 6db68518 authored by Min Xu, committed by GitHub

[fix]: support pytorch SyncBatchNorm under AMP & checkpointing with FSDP (#659)



* [test]: add a more general test case

- also rebalance the tests a bit

* added missing arg

* balance

* better checking

* balance

* make test smaller and faster

* make ddp results cached and enable sync_bn

* clean up

* fix tests

* changelog

* balance

* fix

* addressing comments
Co-authored-by: Min Xu <min.xu@acm.org>
parent f0a40046
......@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
### Added
- FSDP: added `force_input_to_fp32` flag for SyncBatchNorm
- FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633))
## [0.3.6] - 2021-04-26
......
......@@ -201,6 +201,11 @@ class FullyShardedDataParallel(nn.Module):
GPU OOM during the forward pass. Setting this flag to true helps clear this
cache as inner FSDP instances finish their part of the forward pass, saving GPU memory.
Default: False
force_input_to_fp32 (bool):
Set to ``True`` to force floating point input tensors to FP32 (if they are FP16)
when the FSDP instance is in full precision mode. This helps avoid issues when
running SyncBatchNorm together with AMP and checkpoint_wrapper.
Default: False
verbose (bool):
Set this to ``True`` to turn on verbose output for model's string representation.
Default: False
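For illustration only (editorial sketch, not part of the diff), here is a minimal way the new flag might be used, assuming a distributed process group is already initialized; it mirrors the docstring above and the tests further down:

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Hypothetical BN block that gets converted to SyncBatchNorm.
block = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU(inplace=True))
block = nn.SyncBatchNorm.convert_sync_batchnorm(block)

# Full-precision FSDP around the BN block. force_input_to_fp32=True casts any FP16
# activations coming from an outer mixed-precision FSDP back to FP32, which is what
# SyncBatchNorm needs under AMP and checkpoint_wrapper.
bn_fsdp = FSDP(block, mixed_precision=False, flatten_parameters=False, force_input_to_fp32=True)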
......@@ -223,6 +228,7 @@ class FullyShardedDataParallel(nn.Module):
no_broadcast_optim_state: Optional[bool] = False,
state_dict_device: Optional[torch.device] = None,
clear_autocast_cache: bool = False,
force_input_to_fp32: bool = False,
verbose: bool = False,
):
init_start = time.time()
......@@ -244,6 +250,7 @@ class FullyShardedDataParallel(nn.Module):
self.no_broadcast_optim_state = no_broadcast_optim_state
self.state_dict_device = state_dict_device or self.compute_device
self.clear_autocast_cache = clear_autocast_cache
self.force_input_to_fp32 = force_input_to_fp32
self.verbose = verbose
self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
......@@ -561,6 +568,7 @@ class FullyShardedDataParallel(nn.Module):
f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}"
f"force_input_to_fp32={self.force_input_to_fp32}"
)
return repr
......@@ -982,8 +990,16 @@ class FullyShardedDataParallel(nn.Module):
# Start of a forward pass.
self.training_state = TrainingState.FORWARD
# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
# the conversion).
if self._is_root and self.mixed_precision:
args, kwargs = cast_inputs_to_fp16(*args, **kwargs)
args, kwargs = cast_floats_to_right_precision(True, True, *args, **kwargs)
# If enabled, convert the input to FP32 if we are in full precision.
# no_grad is not used because the input might be fed to a non-root instance,
# which means autograd needs to go through the conversion.
if self.force_input_to_fp32 and not self.mixed_precision:
args, kwargs = cast_floats_to_right_precision(False, False, *args, **kwargs)
# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
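To make the no_grad distinction above concrete, here is a small self-contained sketch (editorial, not part of the diff) showing why the FP32 cast must stay inside autograd when the input comes from earlier layers rather than from the data loader:

import torch

x = torch.randn(2, 2, dtype=torch.float16, requires_grad=True)

# Cast inside autograd: the backward pass flows through the conversion back to x.
y = x.float()
y.sum().backward()
assert x.grad is not None

# Cast under no_grad: the FP32 copy is detached from the graph, so gradients from
# anything computed on z would never reach x. That is fine for root inputs, but not
# for activations produced by a preceding (non-root) part of the model.
with torch.no_grad():
    z = x.float()
assert not z.requires_grad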
......@@ -1676,18 +1692,25 @@ def _get_default_cuda_device(module: nn.Module) -> torch.device:
return torch.device("cuda")
@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
def cast_floats_to_right_precision(to_fp16: bool, no_grad: bool, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
"""
Cast any Tensors in *args or **kwargs to FP16.
Cast floating point tensors in *args and **kwargs to FP16 (when ``to_fp16`` is True) or to FP32 (otherwise), if they are not already in that precision.
"""
def fn(x: torch.Tensor) -> torch.Tensor:
def fn_fp16(x: torch.Tensor) -> torch.Tensor:
if x.dtype is torch.float32:
return x.half()
return x
return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)
def fn_fp32(x: torch.Tensor) -> torch.Tensor:
if x.dtype is torch.float16:
return x.float()
return x
fn = fn_fp16 if to_fp16 else fn_fp32
context = torch.no_grad() if no_grad else contextlib.suppress()
with context: # type: ignore
return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)
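For reference, a minimal standalone re-implementation of the same idea (editorial sketch, not the library helper; it walks lists, tuples, and dicts directly instead of using apply_to_tensors):

import contextlib
from typing import Any, Tuple

import torch


def cast_floats(to_fp16: bool, no_grad: bool, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """Sketch: selectively cast float tensors nested inside args/kwargs."""

    def cast(x: Any) -> Any:
        if isinstance(x, torch.Tensor):
            if to_fp16 and x.dtype is torch.float32:
                return x.half()
            if not to_fp16 and x.dtype is torch.float16:
                return x.float()
            return x
        if isinstance(x, (list, tuple)):
            return type(x)(cast(e) for e in x)
        if isinstance(x, dict):
            return {k: cast(v) for k, v in x.items()}
        return x

    context = torch.no_grad() if no_grad else contextlib.nullcontext()
    with context:
        return cast(args), cast(kwargs)


# Example: the FP32 positional tensor is halved; the FP16 keyword tensor is untouched.
out_args, out_kwargs = cast_floats(True, True, torch.ones(2), scale=torch.ones(2, dtype=torch.float16))
assert out_args[0].dtype is torch.float16 and out_kwargs["scale"].dtype is torch.float16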
def free_storage_(data: torch.Tensor) -> None:
......@@ -1745,7 +1768,12 @@ def _pre_load_state_dict_hook(
########################################################################################
def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group: ProcessGroup = None) -> nn.Module:
def auto_wrap_bn(
module: nn.Module,
single_rank_pg: bool = False,
process_group: Optional[ProcessGroup] = None,
fsdp_config: Optional[Dict[str, Any]] = None,
) -> nn.Module:
"""
Auto wrap all BatchNorm (BN) instances with a safer FSDP, especially when conversion
to SyncBatchNorm is used and the outer FSDP is flattening parameters.
......@@ -1762,6 +1790,10 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
single_rank_pg (bool):
If true, put BNs in a single-rank process group. Default False.
This might be needed for Apex sync BN support. Still under construction.
process_group (ProcessGroup):
Optional process group to be used for the BN instances.
fsdp_config (Dict):
Optional FSDP config to be used for wrapping the BN instances, overriding the defaults.
Returns:
Processed module, where BNs are wrapped with a special FSDP instance.
......@@ -1786,18 +1818,22 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
else:
pg = process_group
fsdp_config = {
"wrapper_cls": FullyShardedDataParallel,
"process_group": pg,
"mixed_precision": False, # Keep the weights in FP32.
"flatten_parameters": False, # Do not flatten.
# Reshard==False is good for performance. When FSDP(checkpoint(FSDP(bn))) is used, this
# **must** be False because BN's FSDP wrapper's pre-backward callback isn't called
# within the checkpoint's outer backward when multiple forward passes are used.
"reshard_after_forward": False,
# No bucketing or small bucketing should be enough for BNs.
"bucket_cap_mb": 0,
}
if fsdp_config is None:
fsdp_config = {
"wrapper_cls": FullyShardedDataParallel,
"process_group": pg,
"mixed_precision": False, # Keep the weights in FP32.
"flatten_parameters": False, # Do not flatten.
# Reshard==False is good for performance. When FSDP(checkpoint(FSDP(bn))) is used, this
# **must** be False because BN's FSDP wrapper's pre-backward callback isn't called
# within the checkpoint's outer backward when multiple forward passes are used.
"reshard_after_forward": False,
# No bucketing or small bucketing should be enough for BNs.
"bucket_cap_mb": 0,
# This setting is for SyncBatchNorm and may have a performance impact, so it is off
# by default. If SyncBatchNorm is used, enable it by passing in a custom `fsdp_config`.
"force_input_to_fp32": False,
}
with enable_wrap(wrap_bn_only_policy, **fsdp_config):
return auto_wrap(module)
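A hedged usage sketch (editorial, not part of the diff) of the extended auto_wrap_bn signature, mirroring how the test below passes a custom fsdp_config so that the BN wrappers get force_input_to_fp32=True; it assumes a distributed process group is already initialized:

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn

block = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU(inplace=True))
block = nn.SyncBatchNorm.convert_sync_batchnorm(block)

# Custom config for the BN wrappers; force_input_to_fp32=True is what SyncBatchNorm
# needs when the surrounding model runs under AMP and checkpoint_wrapper.
fsdp_config = {
    "wrapper_cls": FSDP,
    "mixed_precision": False,
    "flatten_parameters": False,
    "reshard_after_forward": False,
    "bucket_cap_mb": 0,
    "force_input_to_fp32": True,
}
block = auto_wrap_bn(block, single_rank_pg=False, fsdp_config=fsdp_config)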
tests/nn/data_parallel/test_fsdp_memory.py
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_no_sync.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/pipe/test_parity.py
tests/experimental/nn/test_sync_batchnorm.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_apply.py
tests/nn/data_parallel/test_fsdp_state_dict.py
tests/utils/test_reduce_scatter_bucketer.py
tests/utils/test_containers.py
tests/utils/test_parallel.py
tests/utils/test_state_dict.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_param_bucket.py
tests/nn/wrap/test_wrap.py
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_transparency.py
......@@ -32,11 +33,10 @@ tests/nn/pipe/test_phony.py
tests/nn/pipe/test_deferred_batch_norm.py
tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/experimental/nn/test_multiprocess_pipe.py
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py
tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_sync_batchnorm.py
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
tests/experimental/optim/test_dynamic_loss_scaler.py
tests/nn/data_parallel/test_fsdp_apply.py
tests/nn/data_parallel/test_fsdp_state_dict.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_no_sync.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_param_bucket.py
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/pipe/test_parity.py
tests/nn/pipe/skip/test_gpipe.py
tests/nn/pipe/skip/test_verify_skippables.py
tests/nn/pipe/skip/test_stash_pop.py
......
......@@ -111,7 +111,7 @@ def _distributed_worker(
if gpu_id == 0:
print(model)
target = torch.LongTensor([0, 1]).cuda()
target = torch.tensor([0, 1], dtype=torch.long).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
......
......@@ -24,28 +24,23 @@ from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn
from fairscale.nn.wrap import enable_wrap, wrap
from fairscale.utils.testing import (
dist_init,
objects_are_equal,
skip_if_single_gpu,
teardown,
temp_files_ctx,
torch_version,
)
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version
class Model(nn.Module):
"""Model to test FSDP(checkpoint())."""
def __init__(self):
super().__init__()
self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True),)
self.block1 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU(inplace=True))
self.block2 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3),
nn.BatchNorm2d(128),
nn.Conv2d(4, 8, kernel_size=3),
nn.BatchNorm2d(8),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
nn.Flatten(),
)
self.head = nn.Linear(128, 10)
self.head = nn.Linear(8, 10)
def forward(self, x):
if isinstance(x, torch.Tensor):
......@@ -55,14 +50,50 @@ class Model(nn.Module):
return torch.cat(ys, dim=0)
def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter):
model = Model()
class Model2(nn.Module):
"""Model to test FSDP(checkpoint(), checkpoint())."""
def __init__(self):
super().__init__()
self.block1 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU(inplace=True))
self.block2 = nn.Sequential(nn.Conv2d(4, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU(inplace=False))
self.block3 = nn.Sequential(nn.Conv2d(4, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU(inplace=True))
self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(8, 10))
def forward(self, x):
if isinstance(x, torch.Tensor):
return self.head(self.block3(self.block2(self.block1(x))))
elif isinstance(x, list):
ys = [self.head(self.block3(self.block2(self.block1(e)))) for e in x]
return torch.cat(ys, dim=0)
def _create_model(
with_model2, with_sync_bn, with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
):
model = Model2() if with_model2 else Model()
fsdp_config = None
if with_sync_bn:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
fsdp_config = {
"wrapper_cls": FSDP,
"mixed_precision": False,
"flatten_parameters": False,
"reshard_after_forward": False,
"bucket_cap_mb": 0,
"force_input_to_fp32": True, # SyncBN needs this.
}
if with_fsdp:
if wrap_bn:
model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False)
model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False)
model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False, fsdp_config=fsdp_config)
model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False, fsdp_config=fsdp_config)
if with_model2:
model.block3 = auto_wrap_bn(model.block3, single_rank_pg=False, fsdp_config=fsdp_config)
if with_checkpoint:
model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True)
if with_model2:
model.block3 = checkpoint_wrapper(model.block3, maintain_forward_counter=True)
with enable_wrap(
wrapper_cls=FSDP,
flatten_parameters=flatten,
......@@ -72,15 +103,29 @@ def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn,
):
model.block1 = wrap(model.block1)
model.block2 = wrap(model.block2)
if with_model2:
model.block3 = wrap(model.block3)
model.head = wrap(model.head)
else:
if with_checkpoint:
model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=False)
if with_model2:
model.block3 = checkpoint_wrapper(model.block3, maintain_forward_counter=False)
return model
def _distributed_worker(
gpu_id, world_size, with_fsdp, with_checkpoint, files, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
gpu_id,
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
files,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
):
filename, filename_rpc = files[:2]
filename_loss = files[2:]
......@@ -101,15 +146,17 @@ def _distributed_worker(
# Ensure we have multiple forward passes.
batch = [
torch.randn(size=(2, 3, 224, 224)).cuda(),
torch.randn(size=(2, 3, 96, 96)).cuda(),
torch.randn(size=(2, 3, 96, 96)).cuda(),
torch.randn(size=(2, 3, 16, 16)).cuda(),
torch.randn(size=(2, 3, 9, 9)).cuda(),
torch.randn(size=(2, 3, 9, 9)).cuda(),
]
if mixed_precision and not with_fsdp:
batch = [x.half() for x in batch]
model = create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter)
model = _create_model(
with_model2, with_sync_bn, with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
)
model = model.cuda()
if with_fsdp:
......@@ -135,7 +182,7 @@ def _distributed_worker(
if gpu_id == 0:
print(model)
target = torch.LongTensor([0, 1, 2, 3, 4, 5]).cuda()
target = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.long).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
......@@ -168,15 +215,91 @@ def _distributed_worker(
teardown()
_result_cache = {}
def _get_cached_results(
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
):
""" Cache the training to save time. For DDP, flatten, wrap_bn etc. doesn't matter, so
the results can be cached.
"""
if not with_fsdp:
flatten = None
wrap_bn = None
fp32_reduce_scatter = None
key = (
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
)
global _result_cache
if key not in _result_cache:
# Get 2 + world_size files: 2 for dist_init and one per rank to save the losses.
with temp_files_ctx(num=2 + world_size) as temp_files:
mp.spawn(
_distributed_worker,
(
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
temp_files,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
),
nprocs=world_size,
)
final_losses = {}
for rank in range(world_size):
with open(temp_files[2 + rank], "rb") as f:
for iter_key, loss in pickle.load(f).items():
final_losses[f"rank_{rank}_{iter_key}"] = loss
_result_cache[key] = final_losses
return _result_cache[key]
@skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
@pytest.mark.parametrize("wrap_bn", ["auto_wrap_bn", "no_auto_wrap_bn"])
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
@pytest.mark.parametrize("model_type", ["model1", "model2"])
@pytest.mark.parametrize("bn_type", ["bn", "sync_bn"])
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type, bn_type):
mixed_precision = precision == "mixed"
flatten = flatten == "flatten"
wrap_bn = wrap_bn == "auto_wrap_bn"
fp32_reduce_scatter = True if mixed_precision else None
with_model2 = model_type == "model2"
with_sync_bn = bn_type == "sync_bn"
if torch_version() >= (1, 7, 0) and torch_version() < (1, 8, 0) and with_sync_bn:
# SyncBN is buggy in 1.7, errors like:
# E File "/home/circleci/venv/lib/python3.8/site-packages/torch/nn/modules/_functions.py", line 13, in forward
# E dtype=running_mean.dtype,
# E AttributeError: 'NoneType' object has no attribute 'dtype'
pytest.skip("SyncBatchNorm in 1.7 is buggy")
if with_sync_bn and not wrap_bn:
pytest.skip("SyncBatchNorm requires auto_wrap_bn")
if torch_version() < (1, 8, 0) and flatten:
# 1.6 and 1.7 throws this error:
......@@ -190,28 +313,36 @@ def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
# Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
for with_fsdp in [False, True]:
for with_checkpoint in [False, True]:
# Get 4 files: 2 for dist_init and 2 for each rank to save the losses.
with temp_files_ctx(num=2 + world_size) as temp_files:
mp.spawn(
_distributed_worker,
(
world_size,
with_fsdp,
with_checkpoint,
temp_files,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
),
nprocs=world_size,
)
final_losses = {}
for rank in range(world_size):
with open(temp_files[2 + rank], "rb") as f:
final_losses[f"rank_{rank}"] = pickle.load(f)
if expected_losses is None:
expected_losses = final_losses
else:
print(f"fsdp: {with_fsdp} ckpt: {with_checkpoint}")
assert objects_are_equal(expected_losses, final_losses, raise_exception=True)
if not with_fsdp and with_checkpoint:
continue
final_losses = _get_cached_results(
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
)
if expected_losses is None:
expected_losses = final_losses
else:
print(f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} with ddp+no_ckpt")
def check(exp, res):
assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
rtol = 1e-4
atol = 1e-5
if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
# On CI, with the longer model2, mixed precision and torch 1.9, even ddp vs. ddp+ckpt has
# larger errors.
rtol = 1e-3
atol = 1e-4
for key in exp.keys():
exp_loss = exp[key]
res_loss = res[key]
torch.testing.assert_allclose(exp_loss, res_loss, rtol=rtol, atol=atol)
check(expected_losses, final_losses)