Unverified Commit 0e7ae64b authored by DevPranjal's avatar DevPranjal Committed by GitHub
Browse files

Port tests in test_transforms_video.py to pytest (#4040)

parent fb2598b8
import torch
from torchvision.transforms import Compose
import unittest
import pytest
import random
import numpy as np
import warnings
......@@ -17,7 +17,7 @@ with warnings.catch_warnings(record=True):
import torchvision.transforms._transforms_video as transforms
class TestVideoTransforms(unittest.TestCase):
class TestVideoTransforms():
def test_random_crop_video(self):
numFrames = random.randint(4, 128)
......@@ -30,8 +30,8 @@ class TestVideoTransforms(unittest.TestCase):
transforms.ToTensorVideo(),
transforms.RandomCropVideo((oheight, owidth)),
])(clip)
self.assertEqual(result.size(2), oheight)
self.assertEqual(result.size(3), owidth)
assert result.size(2) == oheight
assert result.size(3) == owidth
transforms.RandomCropVideo((oheight, owidth)).__repr__()
......@@ -46,8 +46,8 @@ class TestVideoTransforms(unittest.TestCase):
transforms.ToTensorVideo(),
transforms.RandomResizedCropVideo((oheight, owidth)),
])(clip)
self.assertEqual(result.size(2), oheight)
self.assertEqual(result.size(3), owidth)
assert result.size(2) == oheight
assert result.size(3) == owidth
transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()
......@@ -70,7 +70,7 @@ class TestVideoTransforms(unittest.TestCase):
msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(result.sum().item(), 0, msg)
assert result.sum().item() == 0, msg
oheight += 1
owidth += 1
......@@ -82,7 +82,7 @@ class TestVideoTransforms(unittest.TestCase):
msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(sum1.item() > 1, True, msg)
assert sum1.item() > 1, msg
oheight += 1
owidth += 1
......@@ -94,28 +94,29 @@ class TestVideoTransforms(unittest.TestCase):
msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertTrue(sum2.item() > 1, msg)
self.assertTrue(sum2.item() > sum1.item(), msg)
assert sum2.item() > 1, msg
assert sum2.item() > sum1.item(), msg
@unittest.skipIf(stats is None, 'scipy.stats is not available')
def test_normalize_video(self):
@pytest.mark.skipif(stats is None, reason='scipy.stats is not available')
@pytest.mark.parametrize('channels', [1, 3])
def test_normalize_video(self, channels):
def samples_from_standard_normal(tensor):
p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
return p_value > 0.0001
random_state = random.getstate()
random.seed(42)
for channels in [1, 3]:
numFrames = random.randint(4, 128)
height = random.randint(32, 256)
width = random.randint(32, 256)
mean = random.random()
std = random.random()
clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
mean = [clip[c].mean().item() for c in range(channels)]
std = [clip[c].std().item() for c in range(channels)]
normalized = transforms.NormalizeVideo(mean, std)(clip)
self.assertTrue(samples_from_standard_normal(normalized))
numFrames = random.randint(4, 128)
height = random.randint(32, 256)
width = random.randint(32, 256)
mean = random.random()
std = random.random()
clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
mean = [clip[c].mean().item() for c in range(channels)]
std = [clip[c].std().item() for c in range(channels)]
normalized = transforms.NormalizeVideo(mean, std)(clip)
assert samples_from_standard_normal(normalized)
random.setstate(random_state)
# Checking the optional in-place behaviour
......@@ -129,11 +130,11 @@ class TestVideoTransforms(unittest.TestCase):
numFrames, height, width = 64, 4, 4
trans = transforms.ToTensorVideo()
with self.assertRaises(TypeError):
with pytest.raises(TypeError):
trans(np.random.rand(numFrames, height, width, 1).tolist())
trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
trans(torch.ones((height, width, 3), dtype=torch.uint8))
trans(torch.ones((width, 3), dtype=torch.uint8))
......@@ -141,7 +142,7 @@ class TestVideoTransforms(unittest.TestCase):
trans.__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available')
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
def test_random_horizontal_flip_video(self):
random_state = random.getstate()
random.seed(42)
......@@ -157,7 +158,7 @@ class TestVideoTransforms(unittest.TestCase):
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
assert p_value > 0.0001
num_samples = 250
num_horizontal = 0
......@@ -168,10 +169,10 @@ class TestVideoTransforms(unittest.TestCase):
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
assert p_value > 0.0001
transforms.RandomHorizontalFlipVideo().__repr__()
if __name__ == '__main__':
unittest.main()
pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment