Unverified Commit 6db68518 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[fix]: support pytorch SyncBatchNorm under AMP & checkpointing with FSDP (#659)



* [test]: add a more general test case

- also rebalance the tests a bit

* added missing arg

* balance

* better checking

* balance

* make test smaller and faster

* make ddp results cached and enable sync_bn

* clean up

* fix tests

* changelog

* balance

* fix

* addressing comments
Co-authored-by: default avatarMin Xu <min.xu@acm.org>
parent f0a40046
...@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ...@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647)) - SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
### Added ### Added
- FSDP: added `force_input_to_fp32` flag for SyncBatchNorm
- FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633)) - FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633))
## [0.3.6] - 2021-04-26 ## [0.3.6] - 2021-04-26
......
...@@ -201,6 +201,11 @@ class FullyShardedDataParallel(nn.Module): ...@@ -201,6 +201,11 @@ class FullyShardedDataParallel(nn.Module):
GPU OOM during the forward pass. Setting this flag to true will help clearing this GPU OOM during the forward pass. Setting this flag to true will help clearing this
cache as inner FSDP instances finish part of the forward pass to save GPU memory. cache as inner FSDP instances finish part of the forward pass to save GPU memory.
Default: False Default: False
force_input_to_fp32 (bool):
Set to ``True`` to force input floating point tensors to be FP32 (if they are FP16)
when the FSDP instance is in full precision mode. This helps avoid issues of running
SyncBatchNorm with AMP and checkpoint_wrapper.
Default: False
verbose (bool): verbose (bool):
Set this to ``True`` to turn on verbose output for model's string representation. Set this to ``True`` to turn on verbose output for model's string representation.
Default: False Default: False
...@@ -223,6 +228,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -223,6 +228,7 @@ class FullyShardedDataParallel(nn.Module):
no_broadcast_optim_state: Optional[bool] = False, no_broadcast_optim_state: Optional[bool] = False,
state_dict_device: Optional[torch.device] = None, state_dict_device: Optional[torch.device] = None,
clear_autocast_cache: bool = False, clear_autocast_cache: bool = False,
force_input_to_fp32: bool = False,
verbose: bool = False, verbose: bool = False,
): ):
init_start = time.time() init_start = time.time()
...@@ -244,6 +250,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -244,6 +250,7 @@ class FullyShardedDataParallel(nn.Module):
self.no_broadcast_optim_state = no_broadcast_optim_state self.no_broadcast_optim_state = no_broadcast_optim_state
self.state_dict_device = state_dict_device or self.compute_device self.state_dict_device = state_dict_device or self.compute_device
self.clear_autocast_cache = clear_autocast_cache self.clear_autocast_cache = clear_autocast_cache
self.force_input_to_fp32 = force_input_to_fp32
self.verbose = verbose self.verbose = verbose
self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size) self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
...@@ -561,6 +568,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -561,6 +568,7 @@ class FullyShardedDataParallel(nn.Module):
f"move_grads_to_cpu={self.move_grads_to_cpu}, " f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, " f"bucket_cap_mb={self.bucket_cap_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}" f"clear_autocast_cache={self.clear_autocast_cache}"
f"force_input_to_fp32={self.force_input_to_fp32}"
) )
return repr return repr
...@@ -982,8 +990,16 @@ class FullyShardedDataParallel(nn.Module): ...@@ -982,8 +990,16 @@ class FullyShardedDataParallel(nn.Module):
# Start of a forward pass. # Start of a forward pass.
self.training_state = TrainingState.FORWARD self.training_state = TrainingState.FORWARD
# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
# the conversion).
if self._is_root and self.mixed_precision: if self._is_root and self.mixed_precision:
args, kwargs = cast_inputs_to_fp16(*args, **kwargs) args, kwargs = cast_floats_to_right_precision(True, True, *args, **kwargs)
# If enabled, convert the input to FP32 if we are in full precision.
# no_grad is not used because the input might be for a non-root instance,
# which means autograd needs to go through the conversion.
if self.force_input_to_fp32 and not self.mixed_precision:
args, kwargs = cast_floats_to_right_precision(False, False, *args, **kwargs)
# All-gather full parameters. This will also transfer FP32 parameters to # All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``). # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
...@@ -1676,17 +1692,24 @@ def _get_default_cuda_device(module: nn.Module) -> torch.device: ...@@ -1676,17 +1692,24 @@ def _get_default_cuda_device(module: nn.Module) -> torch.device:
return torch.device("cuda") return torch.device("cuda")
@torch.no_grad() def cast_floats_to_right_precision(to_fp16: bool, no_grad: bool, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
""" """
Cast any Tensors in *args or **kwargs to FP16. Cast floating point Tensors in *args or **kwargs to FP16 or FP32 if they are not.
""" """
def fn(x: torch.Tensor) -> torch.Tensor: def fn_fp16(x: torch.Tensor) -> torch.Tensor:
if x.dtype is torch.float32: if x.dtype is torch.float32:
return x.half() return x.half()
return x return x
def fn_fp32(x: torch.Tensor) -> torch.Tensor:
if x.dtype is torch.float16:
return x.float()
return x
fn = fn_fp16 if to_fp16 else fn_fp32
context = torch.no_grad() if no_grad else contextlib.suppress()
with context: # type: ignore
return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs) return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)
...@@ -1745,7 +1768,12 @@ def _pre_load_state_dict_hook( ...@@ -1745,7 +1768,12 @@ def _pre_load_state_dict_hook(
######################################################################################## ########################################################################################
def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group: ProcessGroup = None) -> nn.Module: def auto_wrap_bn(
module: nn.Module,
single_rank_pg: bool = False,
process_group: Optional[ProcessGroup] = None,
fsdp_config: Optional[Dict[str, Any]] = None,
) -> nn.Module:
""" """
Auto wrap all BatchNorm (BN) instances with a safer FSDP, esp. when convert Auto wrap all BatchNorm (BN) instances with a safer FSDP, esp. when convert
to sync BN is used and the outer FSDP is flattening. to sync BN is used and the outer FSDP is flattening.
...@@ -1762,6 +1790,10 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group: ...@@ -1762,6 +1790,10 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
single_rank_pg (bool): single_rank_pg (bool):
If true, put BNs in a single-rank process group. Default False. If true, put BNs in a single-rank process group. Default False.
This might be needed for Apex sync BN support. Still under construction. This might be needed for Apex sync BN support. Still under construction.
process_group (ProcessGroup):
Optional process group to be used.
fsdp_config (Dict):
Optional fsdp_config to be used.
Returns: Returns:
Processed module, where BNs are wrapped with a special FSDP instance. Processed module, where BNs are wrapped with a special FSDP instance.
...@@ -1786,6 +1818,7 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group: ...@@ -1786,6 +1818,7 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
else: else:
pg = process_group pg = process_group
if fsdp_config is None:
fsdp_config = { fsdp_config = {
"wrapper_cls": FullyShardedDataParallel, "wrapper_cls": FullyShardedDataParallel,
"process_group": pg, "process_group": pg,
...@@ -1797,6 +1830,9 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group: ...@@ -1797,6 +1830,9 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
"reshard_after_forward": False, "reshard_after_forward": False,
# No bucketing or small bucketing should be enough for BNs. # No bucketing or small bucketing should be enough for BNs.
"bucket_cap_mb": 0, "bucket_cap_mb": 0,
# Setting this for SyncBatchNorm. This may have a performance impact. If
# SyncBatchNorm is used, this can be enabled by passing in the `fsdp_config` argument.
"force_input_to_fp32": False,
} }
with enable_wrap(wrap_bn_only_policy, **fsdp_config): with enable_wrap(wrap_bn_only_policy, **fsdp_config):
......
tests/nn/data_parallel/test_fsdp_memory.py tests/nn/data_parallel/test_fsdp_memory.py
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_no_sync.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_fsdp.py tests/nn/data_parallel/test_fsdp.py
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/pipe/test_parity.py
tests/experimental/nn/test_sync_batchnorm.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_apply.py
tests/nn/data_parallel/test_fsdp_state_dict.py
tests/utils/test_reduce_scatter_bucketer.py tests/utils/test_reduce_scatter_bucketer.py
tests/utils/test_containers.py tests/utils/test_containers.py
tests/utils/test_parallel.py tests/utils/test_parallel.py
tests/utils/test_state_dict.py tests/utils/test_state_dict.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_param_bucket.py
tests/nn/wrap/test_wrap.py tests/nn/wrap/test_wrap.py
tests/nn/pipe_process/test_pipe.py tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_transparency.py tests/nn/pipe_process/test_transparency.py
...@@ -32,11 +33,10 @@ tests/nn/pipe/test_phony.py ...@@ -32,11 +33,10 @@ tests/nn/pipe/test_phony.py
tests/nn/pipe/test_deferred_batch_norm.py tests/nn/pipe/test_deferred_batch_norm.py
tests/nn/pipe/test_dependency.py tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py tests/nn/pipe/test_stream.py
tests/experimental/nn/test_multiprocess_pipe.py
tests/nn/moe/test_moe_layer.py tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py tests/nn/moe/test_top2gating.py
tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_sync_batchnorm.py
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py tests/experimental/nn/test_offload.py
tests/experimental/optim/test_dynamic_loss_scaler.py tests/experimental/optim/test_dynamic_loss_scaler.py
tests/nn/data_parallel/test_fsdp_apply.py
tests/nn/data_parallel/test_fsdp_state_dict.py
tests/nn/data_parallel/test_fsdp_regnet.py tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
tests/nn/data_parallel/test_fsdp_uneven.py tests/nn/misc/test_grad_bucket.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py tests/nn/misc/test_param_bucket.py
tests/nn/data_parallel/test_fsdp_no_sync.py tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_sharded_ddp_features.py tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/pipe/test_parity.py
tests/nn/pipe/skip/test_gpipe.py tests/nn/pipe/skip/test_gpipe.py
tests/nn/pipe/skip/test_verify_skippables.py tests/nn/pipe/skip/test_verify_skippables.py
tests/nn/pipe/skip/test_stash_pop.py tests/nn/pipe/skip/test_stash_pop.py
......
...@@ -111,7 +111,7 @@ def _distributed_worker( ...@@ -111,7 +111,7 @@ def _distributed_worker(
if gpu_id == 0: if gpu_id == 0:
print(model) print(model)
target = torch.LongTensor([0, 1]).cuda() target = torch.tensor([0, 1], dtype=torch.long).cuda()
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
......
...@@ -24,28 +24,23 @@ from fairscale.nn import checkpoint_wrapper ...@@ -24,28 +24,23 @@ from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn from fairscale.nn.data_parallel import auto_wrap_bn
from fairscale.nn.wrap import enable_wrap, wrap from fairscale.nn.wrap import enable_wrap, wrap
from fairscale.utils.testing import ( from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version
dist_init,
objects_are_equal,
skip_if_single_gpu,
teardown,
temp_files_ctx,
torch_version,
)
class Model(nn.Module): class Model(nn.Module):
"""Model to test FSDP(checkpoint())."""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True),) self.block1 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU(inplace=True))
self.block2 = nn.Sequential( self.block2 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3), nn.Conv2d(4, 8, kernel_size=3),
nn.BatchNorm2d(128), nn.BatchNorm2d(8),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.AdaptiveAvgPool2d(output_size=(1, 1)),
nn.Flatten(), nn.Flatten(),
) )
self.head = nn.Linear(128, 10) self.head = nn.Linear(8, 10)
def forward(self, x): def forward(self, x):
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
...@@ -55,14 +50,50 @@ class Model(nn.Module): ...@@ -55,14 +50,50 @@ class Model(nn.Module):
return torch.cat(ys, dim=0) return torch.cat(ys, dim=0)
def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter): class Model2(nn.Module):
model = Model() """Model to test FSDP(checkpoint(), checkpoint())."""
def __init__(self):
super().__init__()
self.block1 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU(inplace=True))
self.block2 = nn.Sequential(nn.Conv2d(4, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU(inplace=False))
self.block3 = nn.Sequential(nn.Conv2d(4, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU(inplace=True))
self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(8, 10))
def forward(self, x):
if isinstance(x, torch.Tensor):
return self.head(self.block3(self.block2(self.block1(x))))
elif isinstance(x, list):
ys = [self.head(self.block3(self.block2(self.block1(e)))) for e in x]
return torch.cat(ys, dim=0)
def _create_model(
with_model2, with_sync_bn, with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
):
model = Model2() if with_model2 else Model()
fsdp_config = None
if with_sync_bn:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
fsdp_config = {
"wrapper_cls": FSDP,
"mixed_precision": False,
"flatten_parameters": False,
"reshard_after_forward": False,
"bucket_cap_mb": 0,
"force_input_to_fp32": True, # SyncBN needs this.
}
if with_fsdp: if with_fsdp:
if wrap_bn: if wrap_bn:
model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False) model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False, fsdp_config=fsdp_config)
model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False) model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False, fsdp_config=fsdp_config)
if with_model2:
model.block3 = auto_wrap_bn(model.block3, single_rank_pg=False, fsdp_config=fsdp_config)
if with_checkpoint: if with_checkpoint:
model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True) model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True)
if with_model2:
model.block3 = checkpoint_wrapper(model.block3, maintain_forward_counter=True)
with enable_wrap( with enable_wrap(
wrapper_cls=FSDP, wrapper_cls=FSDP,
flatten_parameters=flatten, flatten_parameters=flatten,
...@@ -72,15 +103,29 @@ def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, ...@@ -72,15 +103,29 @@ def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn,
): ):
model.block1 = wrap(model.block1) model.block1 = wrap(model.block1)
model.block2 = wrap(model.block2) model.block2 = wrap(model.block2)
if with_model2:
model.block3 = wrap(model.block3)
model.head = wrap(model.head) model.head = wrap(model.head)
else: else:
if with_checkpoint: if with_checkpoint:
model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=False) model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=False)
if with_model2:
model.block3 = checkpoint_wrapper(model.block3, maintain_forward_counter=False)
return model return model
def _distributed_worker( def _distributed_worker(
gpu_id, world_size, with_fsdp, with_checkpoint, files, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter gpu_id,
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
files,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
): ):
filename, filename_rpc = files[:2] filename, filename_rpc = files[:2]
filename_loss = files[2:] filename_loss = files[2:]
...@@ -101,15 +146,17 @@ def _distributed_worker( ...@@ -101,15 +146,17 @@ def _distributed_worker(
# Ensure we have multiple forward passes. # Ensure we have multiple forward passes.
batch = [ batch = [
torch.randn(size=(2, 3, 224, 224)).cuda(), torch.randn(size=(2, 3, 16, 16)).cuda(),
torch.randn(size=(2, 3, 96, 96)).cuda(), torch.randn(size=(2, 3, 9, 9)).cuda(),
torch.randn(size=(2, 3, 96, 96)).cuda(), torch.randn(size=(2, 3, 9, 9)).cuda(),
] ]
if mixed_precision and not with_fsdp: if mixed_precision and not with_fsdp:
batch = [x.half() for x in batch] batch = [x.half() for x in batch]
model = create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter) model = _create_model(
with_model2, with_sync_bn, with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
)
model = model.cuda() model = model.cuda()
if with_fsdp: if with_fsdp:
...@@ -135,7 +182,7 @@ def _distributed_worker( ...@@ -135,7 +182,7 @@ def _distributed_worker(
if gpu_id == 0: if gpu_id == 0:
print(model) print(model)
target = torch.LongTensor([0, 1, 2, 3, 4, 5]).cuda() target = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.long).cuda()
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
...@@ -168,15 +215,91 @@ def _distributed_worker( ...@@ -168,15 +215,91 @@ def _distributed_worker(
teardown() teardown()
_result_cache = {}
def _get_cached_results(
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
):
""" Cache the training results to save time. For DDP, flatten, wrap_bn, etc. don't matter, so
the results can be cached.
"""
if not with_fsdp:
flatten = None
wrap_bn = None
fp32_reduce_scatter = None
key = (
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
)
global _result_cache
if key not in _result_cache:
# Get 4 files: 2 for dist_init and 2 for each rank to save the losses.
with temp_files_ctx(num=2 + world_size) as temp_files:
mp.spawn(
_distributed_worker,
(
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
temp_files,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
),
nprocs=world_size,
)
final_losses = {}
for rank in range(world_size):
with open(temp_files[2 + rank], "rb") as f:
for iter_key, loss in pickle.load(f).items():
final_losses[f"rank_{rank}_{iter_key}"] = loss
_result_cache[key] = final_losses
return _result_cache[key]
@skip_if_single_gpu @skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"]) @pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"]) @pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
@pytest.mark.parametrize("wrap_bn", ["auto_wrap_bn", "no_auto_wrap_bn"]) @pytest.mark.parametrize("wrap_bn", ["auto_wrap_bn", "no_auto_wrap_bn"])
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn): @pytest.mark.parametrize("model_type", ["model1", "model2"])
@pytest.mark.parametrize("bn_type", ["bn", "sync_bn"])
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type, bn_type):
mixed_precision = precision == "mixed" mixed_precision = precision == "mixed"
flatten = flatten == "flatten" flatten = flatten == "flatten"
wrap_bn = wrap_bn == "auto_wrap_bn" wrap_bn = wrap_bn == "auto_wrap_bn"
fp32_reduce_scatter = True if mixed_precision else None fp32_reduce_scatter = True if mixed_precision else None
with_model2 = model_type == "model2"
with_sync_bn = bn_type == "sync_bn"
if torch_version() >= (1, 7, 0) and torch_version() < (1, 8, 0) and with_sync_bn:
# SyncBN is buggy in 1.7, errors like:
# E File "/home/circleci/venv/lib/python3.8/site-packages/torch/nn/modules/_functions.py", line 13, in forward
# E dtype=running_mean.dtype,
# E AttributeError: 'NoneType' object has no attribute 'dtype'
pytest.skip("SyncBatchNorm in 1.7 is buggy")
if with_sync_bn and not wrap_bn:
pytest.skip("SyncBatchNorm requires auto_wrap_bn")
if torch_version() < (1, 8, 0) and flatten: if torch_version() < (1, 8, 0) and flatten:
# 1.6 and 1.7 throws this error: # 1.6 and 1.7 throws this error:
...@@ -190,28 +313,36 @@ def test_multiple_forward_checkpoint(precision, flatten, wrap_bn): ...@@ -190,28 +313,36 @@ def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
# Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt. # Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
for with_fsdp in [False, True]: for with_fsdp in [False, True]:
for with_checkpoint in [False, True]: for with_checkpoint in [False, True]:
# Get 4 files: 2 for dist_init and 2 for each rank to save the losses. if not with_fsdp and with_checkpoint:
with temp_files_ctx(num=2 + world_size) as temp_files: continue
mp.spawn( final_losses = _get_cached_results(
_distributed_worker,
(
world_size, world_size,
with_model2,
with_sync_bn,
with_fsdp, with_fsdp,
with_checkpoint, with_checkpoint,
temp_files,
mixed_precision, mixed_precision,
flatten, flatten,
wrap_bn, wrap_bn,
fp32_reduce_scatter, fp32_reduce_scatter,
),
nprocs=world_size,
) )
final_losses = {}
for rank in range(world_size):
with open(temp_files[2 + rank], "rb") as f:
final_losses[f"rank_{rank}"] = pickle.load(f)
if expected_losses is None: if expected_losses is None:
expected_losses = final_losses expected_losses = final_losses
else: else:
print(f"fsdp: {with_fsdp} ckpt: {with_checkpoint}") print(f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} with ddp+no_ckpt")
assert objects_are_equal(expected_losses, final_losses, raise_exception=True)
def check(exp, res):
assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
rtol = 1e-4
atol = 1e-5
if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
# On CI, with the longer model2, mixed precision and 1.9, even ddp vs. ddp+ckpt has
# larger errors.
rtol = 1e-3
atol = 1e-4
for key in exp.keys():
exp_loss = exp[key]
res_loss = res[key]
torch.testing.assert_allclose(exp_loss, res_loss, rtol=rtol, atol=atol)
check(expected_losses, final_losses)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment