"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "32b85dfa8d4a5fa54469ddc72be89d827c1ee9d6"
Unverified Commit 093757db authored by Vivek Kumar's avatar Vivek Kumar Committed by GitHub
Browse files

Port test_datasets_samplers.py to pytest (#4037)

parent 13ed657d
...@@ -2,7 +2,7 @@ import contextlib ...@@ -2,7 +2,7 @@ import contextlib
import sys import sys
import os import os
import torch import torch
import unittest import pytest
from torchvision import io from torchvision import io
from torchvision.datasets.samplers import ( from torchvision.datasets.samplers import (
...@@ -38,13 +38,13 @@ def get_list_of_videos(num_videos=5, sizes=None, fps=None): ...@@ -38,13 +38,13 @@ def get_list_of_videos(num_videos=5, sizes=None, fps=None):
yield names yield names
@unittest.skipIf(not io.video._av_available(), "this test requires av") @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
class Tester(unittest.TestCase): class TestDatasetsSamplers:
def test_random_clip_sampler(self): def test_random_clip_sampler(self):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list: with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5) video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3) sampler = RandomClipSampler(video_clips, 3)
self.assertEqual(len(sampler), 3 * 3) assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler))) indices = torch.tensor(list(iter(sampler)))
videos = torch.div(indices, 5, rounding_mode='floor') videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True) v_idxs, count = torch.unique(videos, return_counts=True)
...@@ -55,10 +55,10 @@ class Tester(unittest.TestCase): ...@@ -55,10 +55,10 @@ class Tester(unittest.TestCase):
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list: with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5) video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3) sampler = RandomClipSampler(video_clips, 3)
self.assertEqual(len(sampler), 2 + 3 + 3) assert len(sampler) == 2 + 3 + 3
indices = list(iter(sampler)) indices = list(iter(sampler))
self.assertIn(0, indices) assert 0 in indices
self.assertIn(1, indices) assert 1 in indices
# remove elements of the first video, to simplify testing # remove elements of the first video, to simplify testing
indices.remove(0) indices.remove(0)
indices.remove(1) indices.remove(1)
...@@ -72,7 +72,7 @@ class Tester(unittest.TestCase): ...@@ -72,7 +72,7 @@ class Tester(unittest.TestCase):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list: with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5) video_clips = VideoClips(video_list, 5, 5)
sampler = UniformClipSampler(video_clips, 3) sampler = UniformClipSampler(video_clips, 3)
self.assertEqual(len(sampler), 3 * 3) assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler))) indices = torch.tensor(list(iter(sampler)))
videos = torch.div(indices, 5, rounding_mode='floor') videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True) v_idxs, count = torch.unique(videos, return_counts=True)
...@@ -84,7 +84,7 @@ class Tester(unittest.TestCase): ...@@ -84,7 +84,7 @@ class Tester(unittest.TestCase):
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list: with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5) video_clips = VideoClips(video_list, 5, 5)
sampler = UniformClipSampler(video_clips, 3) sampler = UniformClipSampler(video_clips, 3)
self.assertEqual(len(sampler), 3 * 3) assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler))) indices = torch.tensor(list(iter(sampler)))
assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])) assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
...@@ -100,7 +100,7 @@ class Tester(unittest.TestCase): ...@@ -100,7 +100,7 @@ class Tester(unittest.TestCase):
group_size=3, group_size=3,
) )
indices = torch.tensor(list(iter(distributed_sampler_rank0))) indices = torch.tensor(list(iter(distributed_sampler_rank0)))
self.assertEqual(len(distributed_sampler_rank0), 6) assert len(distributed_sampler_rank0) == 6
assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14])) assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))
distributed_sampler_rank1 = DistributedSampler( distributed_sampler_rank1 = DistributedSampler(
...@@ -110,9 +110,9 @@ class Tester(unittest.TestCase): ...@@ -110,9 +110,9 @@ class Tester(unittest.TestCase):
group_size=3, group_size=3,
) )
indices = torch.tensor(list(iter(distributed_sampler_rank1))) indices = torch.tensor(list(iter(distributed_sampler_rank1)))
self.assertEqual(len(distributed_sampler_rank1), 6) assert len(distributed_sampler_rank1) == 6
assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4])) assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment