Unverified Commit 99163d4f authored by Freddy Snijder, committed by GitHub

Added warn_on_trainable_params_changed constructor parameter to allow the user to suppress the warning on trainable parameters changed (#886)

* Added the warn_on_trainable_params_changed constructor parameter so the user can suppress the warning logged when trainable parameters change; the default is True, so the default behavior is unchanged

* Added parameter documentation
parent 56add6d5
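
For context, here is a minimal usage sketch of the new flag. It assumes a fairscale build that includes this change and an already-initialized torch.distributed process group; the model and optimizer choices are illustrative and not part of the commit.

import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# Illustrative model; any nn.Module works. A distributed process group must
# already be initialized (e.g. via torch.distributed.init_process_group).
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

# OSS shards the optimizer state across ranks; ShardedDataParallel is designed
# to be paired with it.
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

# Passing warn_on_trainable_params_changed=False suppresses the warning that is
# otherwise logged whenever parameter trainability changes; the default (True)
# keeps the previous behavior.
ddp_model = ShardedDataParallel(
    model,
    optimizer,
    warn_on_trainable_params_changed=False,
)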
@@ -63,6 +63,9 @@ class ShardedDataParallel(nn.Module):
        reduce_fp16 (bool):
            cast the grads to fp16 before reducing. Not needed if the model is already fp16, but will probably improve performance
            for multi node jobs using PyTorch AMP. The effect is similar to DDP's fp16_compress_hook_ and will also save some memory.
+       warn_on_trainable_params_changed (bool):
+           When set to False no warning will be logged whenever a parameter trainability change has been detected.
+           Default is True.

    .. _fp16_compress_hook: https://pytorch.org/docs/1.8.0/ddp_comm_hooks.html?highlight=fp16#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
@@ -100,6 +103,7 @@ class ShardedDataParallel(nn.Module):
        reduce_buffer_size: int = 2 ** 23,
        auto_refresh_trainable: bool = True,
        reduce_fp16: bool = False,
+       warn_on_trainable_params_changed: bool = True,
    ):
        super().__init__()
@@ -116,6 +120,8 @@ class ShardedDataParallel(nn.Module):
                "fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
            )

+       self._warn_on_trainable_params_changed = warn_on_trainable_params_changed
+
        # Handle a no_sync() context which prevents the gradient synchronization,
        # accumulate in place
        self._should_accumulate_grads = False
@@ -654,9 +660,10 @@ class ShardedDataParallel(nn.Module):
        # - the whole model is not trainable but we still have grad hooks
        trainability_changed |= not self.training and len(self._grad_hooks) > 0

-       if trainability_changed:
+       if self._warn_on_trainable_params_changed and trainability_changed:
            logging.warning(
-               "ShardedDDP detected that the trainable params changed, either because of eval/train mode or parameter freezing/unfreeze."
+               "ShardedDDP detected that the trainable params changed, "
+               "either because of eval/train mode or parameter freezing/unfreeze."
            )
        self._reference_trainable_mask = trainable_mask
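
As a sketch of the scenario this check guards against, the loop below freezes a layer partway through training: the changed requires_grad flags alter the trainable mask, and with the new flag set to False the hooks are still refreshed but the warning is skipped. The synthetic data and the specific layer frozen are assumptions; ddp_model, model and optimizer are the objects from the earlier sketch.

for step in range(200):
    inputs = torch.randn(8, 32)

    if step == 100:
        # Freezing a sub-module flips requires_grad on its parameters, so the
        # trainable mask no longer matches the stored reference mask.
        # With warn_on_trainable_params_changed=False the grad hooks are still
        # refreshed on the next forward, but the warning above is not logged.
        for p in model[0].parameters():
            p.requires_grad = False

    loss = ddp_model(inputs).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()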