Unverified Commit 05a3941f authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_datasets_samplers.py (#3874)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent c4685e81
...@@ -14,6 +14,7 @@ from torchvision.datasets.video_utils import VideoClips, unfold ...@@ -14,6 +14,7 @@ from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend from torchvision import get_video_backend
from common_utils import get_tmp_dir from common_utils import get_tmp_dir
from _assert_utils import assert_equal
@contextlib.contextmanager @contextlib.contextmanager
...@@ -47,8 +48,8 @@ class Tester(unittest.TestCase): ...@@ -47,8 +48,8 @@ class Tester(unittest.TestCase):
indices = torch.tensor(list(iter(sampler))) indices = torch.tensor(list(iter(sampler)))
videos = torch.div(indices, 5, rounding_mode='floor') videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True) v_idxs, count = torch.unique(videos, return_counts=True)
self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2]))) assert_equal(v_idxs, torch.tensor([0, 1, 2]))
self.assertTrue(count.equal(torch.tensor([3, 3, 3]))) assert_equal(count, torch.tensor([3, 3, 3]))
def test_random_clip_sampler_unequal(self): def test_random_clip_sampler_unequal(self):
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list: with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
...@@ -64,8 +65,8 @@ class Tester(unittest.TestCase): ...@@ -64,8 +65,8 @@ class Tester(unittest.TestCase):
indices = torch.tensor(indices) - 2 indices = torch.tensor(indices) - 2
videos = torch.div(indices, 5, rounding_mode='floor') videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True) v_idxs, count = torch.unique(videos, return_counts=True)
self.assertTrue(v_idxs.equal(torch.tensor([0, 1]))) assert_equal(v_idxs, torch.tensor([0, 1]))
self.assertTrue(count.equal(torch.tensor([3, 3]))) assert_equal(count, torch.tensor([3, 3]))
def test_uniform_clip_sampler(self): def test_uniform_clip_sampler(self):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list: with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
...@@ -75,9 +76,9 @@ class Tester(unittest.TestCase): ...@@ -75,9 +76,9 @@ class Tester(unittest.TestCase):
indices = torch.tensor(list(iter(sampler))) indices = torch.tensor(list(iter(sampler)))
videos = torch.div(indices, 5, rounding_mode='floor') videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True) v_idxs, count = torch.unique(videos, return_counts=True)
self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2]))) assert_equal(v_idxs, torch.tensor([0, 1, 2]))
self.assertTrue(count.equal(torch.tensor([3, 3, 3]))) assert_equal(count, torch.tensor([3, 3, 3]))
self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))) assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))
def test_uniform_clip_sampler_insufficient_clips(self): def test_uniform_clip_sampler_insufficient_clips(self):
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list: with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
...@@ -85,7 +86,7 @@ class Tester(unittest.TestCase): ...@@ -85,7 +86,7 @@ class Tester(unittest.TestCase):
sampler = UniformClipSampler(video_clips, 3) sampler = UniformClipSampler(video_clips, 3)
self.assertEqual(len(sampler), 3 * 3) self.assertEqual(len(sampler), 3 * 3)
indices = torch.tensor(list(iter(sampler))) indices = torch.tensor(list(iter(sampler)))
self.assertTrue(indices.equal(torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))) assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
def test_distributed_sampler_and_uniform_clip_sampler(self): def test_distributed_sampler_and_uniform_clip_sampler(self):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list: with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
...@@ -100,7 +101,7 @@ class Tester(unittest.TestCase): ...@@ -100,7 +101,7 @@ class Tester(unittest.TestCase):
) )
indices = torch.tensor(list(iter(distributed_sampler_rank0))) indices = torch.tensor(list(iter(distributed_sampler_rank0)))
self.assertEqual(len(distributed_sampler_rank0), 6) self.assertEqual(len(distributed_sampler_rank0), 6)
self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 10, 12, 14]))) assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))
distributed_sampler_rank1 = DistributedSampler( distributed_sampler_rank1 = DistributedSampler(
clip_sampler, clip_sampler,
...@@ -110,7 +111,7 @@ class Tester(unittest.TestCase): ...@@ -110,7 +111,7 @@ class Tester(unittest.TestCase):
) )
indices = torch.tensor(list(iter(distributed_sampler_rank1))) indices = torch.tensor(list(iter(distributed_sampler_rank1)))
self.assertEqual(len(distributed_sampler_rank1), 6) self.assertEqual(len(distributed_sampler_rank1), 6)
self.assertTrue(indices.equal(torch.tensor([5, 7, 9, 0, 2, 4]))) assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment