Unverified Commit 85962b97 authored by Benjamin Lefaudeux, committed by GitHub

[SDP] removing an assert which does not seem always accurate (#625)

parent b0048b28
@@ -1664,9 +1664,7 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
         if recurse:
             return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES))  # type: ignore
         else:
-            return is_bn and not isinstance(
-                module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)
-            )  # type: ignore
+            return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore

     pg = None
     if single_rank_pg:
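The collapsed return keeps the same policy logic: wrap a module only if it is a BatchNorm layer and not in the excluded set. Below is a small standalone illustration of that predicate, not fairscale code; EXCLUDE_WRAP_MODULES here is a made-up stand-in for default_auto_wrap_policy.EXCLUDE_WRAP_MODULES.

import torch.nn as nn

# Illustrative stand-in for default_auto_wrap_policy.EXCLUDE_WRAP_MODULES.
EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}

def should_wrap(module: nn.Module) -> bool:
    # Same shape as the collapsed return: wrap only BatchNorm layers
    # that are not explicitly excluded.
    is_bn = isinstance(module, nn.modules.batchnorm._BatchNorm)
    return is_bn and not isinstance(module, tuple(EXCLUDE_WRAP_MODULES))

print(should_wrap(nn.BatchNorm2d(8)))  # True
print(should_wrap(nn.Linear(8, 8)))    # False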
@@ -269,9 +269,9 @@ class ShardedDataParallel(nn.Module):
         """ If the module trainability has changed, update all the assumptions """

         # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
-        assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), (
-            "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)
-            + "\nIf this is on purpose (grad accumulation), please use a no_sync() context"
-        )
+        if functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False):
+            logging.warning(
+                "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
+            )

         self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
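This is the change the commit title refers to: pending un-reduced gradients are a normal state during gradient accumulation under no_sync(), so the assert is downgraded to a warning. A minimal single-process sketch of that accumulation pattern, assuming a "gloo" backend and an illustrative model and optimizer (not taken from this commit):

import tempfile

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-rank process group, only so ShardedDataParallel/OSS can be constructed.
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 2)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
ddp = ShardedDataParallel(model, optimizer)

batches = [torch.randn(4, 8) for _ in range(3)]

# Accumulate gradients locally for the first batches; they stay un-reduced on purpose,
# which is exactly the state the removed assert used to flag.
with ddp.no_sync():
    for batch in batches[:-1]:
        ddp(batch).sum().backward()

# The last backward, outside no_sync(), triggers the reduction before the step.
ddp(batches[-1]).sum().backward()
optimizer.step()

dist.destroy_process_group()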
@@ -262,9 +262,9 @@ def test_mixed_types():
     dist.destroy_process_group()


-def test_train_eval_change():
+def run_test_train_eval_change(rank, world_size, file):
     # Check that ShardedDDP handles the switch from training to eval properly
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    dist.init_process_group(init_method="file://" + file, backend="gloo", rank=rank, world_size=world_size)

     model = _get_mlp()
     model.train()
@@ -288,6 +288,14 @@ def test_train_eval_change():
     dist.destroy_process_group()


+def test_train_eval_change():
+    world_size = 4
+    temp_file_name = tempfile.mkstemp()[1]
+    mp.spawn(
+        run_test_train_eval_change, args=(world_size, temp_file_name), nprocs=world_size, join=True,
+    )
+
+
 def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
     # Check that the wrapped module can change devices
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)