Unverified Commit bbfda424 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fix flakiness on StochasticDepth test (#4758)

* Fix flakiness on the TestStochasticDepth test.

* Fix minor bug when p=1.0

* Remove device and dtype setting.
parent 5ea23483
......@@ -1149,13 +1149,15 @@ class TestMasksToBoxes:
class TestStochasticDepth:
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
@pytest.mark.parametrize("mode", ["batch", "row"])
def test_stochastic_depth(self, mode, p):
def test_stochastic_depth_random(self, seed, mode, p):
torch.manual_seed(seed)
stats = pytest.importorskip("scipy.stats")
batch_size = 5
x = torch.ones(size=(batch_size, 3, 4, 4))
layer = ops.StochasticDepth(p=p, mode=mode).to(device=x.device, dtype=x.dtype)
layer = ops.StochasticDepth(p=p, mode=mode)
layer.__repr__()
trials = 250
......@@ -1173,7 +1175,22 @@ class TestStochasticDepth:
num_samples += batch_size
p_value = stats.binom_test(counts, num_samples, p=p)
assert p_value > 0.0001
assert p_value > 0.01
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", (0, 1))
@pytest.mark.parametrize("mode", ["batch", "row"])
def test_stochastic_depth(self, seed, mode, p):
torch.manual_seed(seed)
batch_size = 5
x = torch.ones(size=(batch_size, 3, 4, 4))
layer = ops.StochasticDepth(p=p, mode=mode)
out = layer(x)
if p == 0:
assert out.equal(x)
elif p == 1:
assert out.equal(torch.zeros_like(x))
class TestUtils:
......
......@@ -34,7 +34,9 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
else:
size = [1] * input.ndim
noise = torch.empty(size, dtype=input.dtype, device=input.device)
noise = noise.bernoulli_(survival_rate).div_(survival_rate)
noise = noise.bernoulli_(survival_rate)
if survival_rate > 0.0:
noise.div_(survival_rate)
return input * noise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment