Unverified commit 85962b97 authored by Benjamin Lefaudeux, committed by GitHub

[SDP] removing an assert which does not seem always accurate (#625)

parent b0048b28
@@ -1664,9 +1664,7 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
         if recurse:
             return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES))  # type: ignore
         else:
-            return is_bn and not isinstance(
-                module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)
-            )  # type: ignore
+            return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore
 
     pg = None
     if single_rank_pg:
...
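Note for readers outside the fairscale codebase: the predicate in the hunk above is the auto-wrap policy that auto_wrap_bn uses to decide, per module, whether to keep recursing or to wrap a leaf. Below is a minimal, self-contained sketch of the same shape of policy; the set contents and the bn_wrap_policy name are illustrative assumptions, not the exact library API.

# Hedged sketch of a BatchNorm-only wrap policy mirroring the hunk above.
# FORCE_LEAF_MODULES / EXCLUDE_WRAP_MODULES contents and the function name
# are illustrative assumptions, not the fairscale defaults.
import torch.nn as nn

FORCE_LEAF_MODULES = {nn.MultiheadAttention}            # never recurse into these
EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}   # never wrap these directly


def bn_wrap_policy(module: nn.Module, recurse: bool) -> bool:
    is_bn = isinstance(module, nn.modules.batchnorm._BatchNorm)
    if recurse:
        # While traversing, keep going unless the module is a forced leaf.
        return not isinstance(module, tuple(FORCE_LEAF_MODULES))
    # At the leaf decision, wrap only BatchNorm layers that are not excluded.
    return is_bn and not isinstance(module, tuple(EXCLUDE_WRAP_MODULES))


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
    for m in model.modules():
        print(type(m).__name__, bn_wrap_policy(m, recurse=False))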
@@ -269,10 +269,10 @@ class ShardedDataParallel(nn.Module):
         """ If the module trainability has changed, update all the assumptions """
 
         # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
-        assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), (
-            "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)
-            + "\nIf this is on purpose (grad accumulation), please use a no_sync() context"
-        )
+        if functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False):
+            logging.warning(
+                "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
+            )
 
         self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
         self._trainable_params.sort(key=lambda x: x.numel())
...
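The assert-turned-warning above points users at no_sync() for gradient accumulation. A hedged sketch of that pattern follows, using a single-process gloo group like the tests in this commit; the model, learning rate and accumulation window are illustrative.

# Hedged sketch: gradient accumulation with ShardedDataParallel's no_sync(),
# as the warning above suggests. Model, lr and step counts are illustrative.
import tempfile

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

# Single-process gloo group, mirroring the test setup in this commit.
dist.init_process_group(
    init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1
)

model = torch.nn.Linear(8, 2)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=1e-3)
ddp_model = ShardedDDP(model, optimizer)

accumulation_steps = 4
for step in range(8):
    inputs = torch.randn(4, 8)
    if (step + 1) % accumulation_steps != 0:
        # Intermediate micro-batch: delay the gradient reduction.
        with ddp_model.no_sync():
            ddp_model(inputs).sum().backward()
    else:
        # Last micro-batch of the window: reduce, step and reset.
        ddp_model(inputs).sum().backward()
        optimizer.step()
        optimizer.zero_grad()

dist.destroy_process_group()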
@@ -262,9 +262,9 @@ def test_mixed_types():
     dist.destroy_process_group()
 
 
-def test_train_eval_change():
+def run_test_train_eval_change(rank, world_size, file):
     # Check that ShardedDDP handles the switch from training to eval properly
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    dist.init_process_group(init_method="file://" + file, backend="gloo", rank=rank, world_size=world_size)
 
     model = _get_mlp()
     model.train()
@@ -288,6 +288,14 @@ def test_train_eval_change():
     dist.destroy_process_group()
 
 
+def test_train_eval_change():
+    world_size = 4
+    temp_file_name = tempfile.mkstemp()[1]
+    mp.spawn(
+        run_test_train_eval_change, args=(world_size, temp_file_name), nprocs=world_size, join=True,
+    )
+
+
 def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
     # Check that the wrapped module can change devices
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
...
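The test refactor above converts a single-process test into a multi-process one driven by torch.multiprocessing.spawn. For reference, here is a minimal, self-contained sketch of that spawn-plus-file-init pattern; the worker body is an illustrative all_reduce check, not the actual ShardedDDP test.

# Minimal sketch of the mp.spawn test pattern used above; the worker body
# (a simple all_reduce check) is illustrative, not the actual test.
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_worker(rank: int, world_size: int, file: str) -> None:
    # Every spawned process joins the same file-initialized gloo group.
    dist.init_process_group(
        init_method="file://" + file, backend="gloo", rank=rank, world_size=world_size
    )
    t = torch.ones(1)
    dist.all_reduce(t)  # each rank contributes 1
    assert t.item() == world_size
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 4
    temp_file_name = tempfile.mkstemp()[1]
    # mp.spawn passes the rank as the first positional argument to the worker.
    mp.spawn(run_worker, args=(world_size, temp_file_name), nprocs=world_size, join=True)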