Unverified Commit 3a17e339 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Remove p_value test for RandomHorizontalFlipVideo (#4765)

parent 0f770ac9
...@@ -160,34 +160,16 @@ class TestVideoTransforms: ...@@ -160,34 +160,16 @@ class TestVideoTransforms:
trans.__repr__() trans.__repr__()
@pytest.mark.skipif(stats is None, reason="scipy.stats not available") @pytest.mark.parametrize("p", (0, 1))
def test_random_horizontal_flip_video(self): def test_random_horizontal_flip_video(self, p):
random_state = random.getstate()
random.seed(42)
clip = torch.rand((3, 4, 112, 112), dtype=torch.float) clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
hclip = clip.flip((-1)) hclip = clip.flip((-1))
num_samples = 250 out = transforms.RandomHorizontalFlipVideo(p=p)(clip)
num_horizontal = 0 if p == 0:
for _ in range(num_samples): torch.testing.assert_close(out, clip)
out = transforms.RandomHorizontalFlipVideo()(clip) elif p == 1:
if torch.all(torch.eq(out, hclip)): torch.testing.assert_close(out, hclip)
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlipVideo(p=0.7)(clip)
if torch.all(torch.eq(out, hclip)):
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state)
assert p_value > 0.0001
transforms.RandomHorizontalFlipVideo().__repr__() transforms.RandomHorizontalFlipVideo().__repr__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment