Unverified Commit c90a9214 authored by Peter St. John, committed by GitHub
Browse files

Add tests that reset_parameters doesn't change parameter initial value ranges (#2550)



* Add tests for 2528 and 2529
Signed-off-by: Peter St. John <pstjohn@nvidia.com>

* Update tests/pytorch/test_deferred_init.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update tests/pytorch/test_deferred_init.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4f364c8e
......@@ -28,7 +28,6 @@ dtype = torch.bfloat16
class TestDeferredInit:
@staticmethod
def get_module_args(module):
hidden_size = num_heads * head_dim
......@@ -82,3 +81,45 @@ class TestDeferredInit:
"on CUDA device"
)
del module
@pytest.mark.parametrize("module_type", _core_modules)
def test_reset_parameters_doesnt_change_parameter_stats(
    self,
    module_type: "type[torch.nn.Module]",
) -> None:
    """Check that ``reset_parameters()`` keeps each parameter's mean and std.

    Regression test for GitHub issues #2528 and #2529. The module is built
    directly on the ``"cuda"`` device, per-parameter mean/std are recorded,
    ``reset_parameters()`` is called, and the statistics are recorded again.
    Every parameter's mean and std must match the pre-reset value within
    ``atol=1e-3`` / ``rtol=1e-3``.

    Args:
        module_type: Module *class* under test (one entry of ``_core_modules``);
            it is instantiated with the arguments returned by
            ``TestDeferredInit.get_module_args``.
    """

    def snapshot_stats(mod):
        """Return ``{param_name: {"mean": ..., "std": ...}}`` for ``mod``."""
        # Only the values are needed; no_grad avoids building autograd graphs
        # for the reductions.
        with torch.no_grad():
            return {
                name: {"mean": param.mean(), "std": param.std()}
                for name, param in mod.named_parameters()
            }

    args, kwargs = TestDeferredInit.get_module_args(module_type)
    kwargs["device"] = "cuda"
    module = module_type(*args, **kwargs)

    param_stats = snapshot_stats(module)
    with torch.no_grad():
        module.reset_parameters()
    param_stats_after = snapshot_stats(module)

    # Fail loudly if reset_parameters() added or dropped parameters instead of
    # hitting a bare KeyError or silently skipping the missing ones below.
    assert param_stats_after.keys() == param_stats.keys(), (
        "reset_parameters changed the set of parameter names: "
        f"before={sorted(param_stats)}, after={sorted(param_stats_after)}"
    )

    for name, stats in param_stats_after.items():
        for stat_name in ("mean", "std"):
            torch.testing.assert_close(
                stats[stat_name],
                param_stats[name][stat_name],
                atol=1e-3,
                rtol=1e-3,
                # A callable msg keeps torch's own mismatch report (actual vs.
                # expected values) after the headline instead of replacing it.
                msg=lambda details, name=name, stat_name=stat_name: (
                    f"{name} {stat_name} changed after reset_parameters\n{details}"
                ),
            )
    del module
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment