Unverified Commit 13bd09dd authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Prevent tests from leaking their respective RNG (#4497)



* Add autouse fixture to save and reset RNG in tests

* Add other RNG generators

* delete freeze_rng_state

* Hopefully fix GaussianBlur test

* Alternative fix, probably better

* revert changes to test_models
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 5e8a2116
from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG
import torch
import numpy as np
import random
import pytest
...@@ -80,3 +82,26 @@ def pytest_sessionfinish(session, exitstatus): ...@@ -80,3 +82,26 @@ def pytest_sessionfinish(session, exitstatus):
# To avoid this, we transform this 5 into a 0 to make testpilot happy. # To avoid this, we transform this 5 into a 0 to make testpilot happy.
if exitstatus == 5: if exitstatus == 5:
session.exitstatus = 0 session.exitstatus = 0
@pytest.fixture(autouse=True)
def prevent_leaking_rng():
    """Save the global RNG states before each test and restore them after.

    Prevents a test that calls torch.manual_seed(), random.seed() or
    np.random.seed() from leaking its seeded global state into all the
    other tests.

    Note: the numpy RNG should never leak anyway, as we never use
    np.random.seed() and instead rely on np.random.RandomState instances
    (see issue #4247). We still save/restore it as an extra precaution.
    """
    torch_rng_state = torch.get_rng_state()
    builtin_rng_state = random.getstate()
    # Fixed typo: this local was previously spelled "nunmpy_rng_state".
    numpy_rng_state = np.random.get_state()
    if torch.cuda.is_available():
        cuda_rng_state = torch.cuda.get_rng_state()

    yield

    torch.set_rng_state(torch_rng_state)
    random.setstate(builtin_rng_state)
    np.random.set_state(numpy_rng_state)
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_rng_state)
...@@ -714,6 +714,7 @@ def test_random_apply(device): ...@@ -714,6 +714,7 @@ def test_random_apply(device):
@pytest.mark.parametrize('channels', [1, 3])
def test_gaussian_blur(device, channels, meth_kwargs):
    # NOTE(review): `device` and `meth_kwargs` presumably come from
    # parametrize decorators above this hunk — confirm against the full file.
    tol = 1.0 + 1e-10
    # Seed so the transform's randomly sampled sigma is deterministic; the
    # autouse RNG fixture restores the global state after the test.
    torch.manual_seed(12)
    _test_class_op(
        T.GaussianBlur, meth_kwargs=meth_kwargs, channels=channels,
        test_exact_match=False, device=device, agg_method="max", tol=tol
    )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment