Unverified Commit 23295fbb authored by Edgar Andrés Margffoy Tuay's avatar Edgar Andrés Margffoy Tuay Committed by GitHub
Browse files

PR: Add UCF101 dataset tests (#2548)

* Add fake data generator for UCF101

* Minor error correction

* Reduce total number of categories

* Fix naming

* Increase length

* Store in uint8

* Close fds

* Add assertGreater

* Add dimension tests

* Use numel instead of size

* Iterate over folds and splits
parent c2bbefc2
...@@ -7,6 +7,9 @@ import PIL ...@@ -7,6 +7,9 @@ import PIL
import torch import torch
from common_utils import get_tmp_dir from common_utils import get_tmp_dir
import pickle import pickle
import random
from itertools import cycle
from torchvision.io.video import write_video
@contextlib.contextmanager @contextlib.contextmanager
...@@ -265,3 +268,47 @@ def voc_root(): ...@@ -265,3 +268,47 @@ def voc_root():
f.write('test') f.write('test')
yield tmp_dir yield tmp_dir
@contextlib.contextmanager
def ucf101_root():
    """Yield ``(video_dir, annotations_dir)`` for a fake UCF101 layout.

    Generates 2 classes x 3 groups x 4 clips of random uint8 video at
    25 fps, plus the six annotation files (train/test x folds 1-3).
    Clips are distributed round-robin over the annotation files so that
    every fold/split combination is non-empty.
    """
    with get_tmp_dir() as tmp_dir:
        ucf_dir = os.path.join(tmp_dir, 'UCF-101')
        video_dir = os.path.join(ucf_dir, 'video')
        annotations = os.path.join(ucf_dir, 'annotations')
        os.makedirs(ucf_dir)
        os.makedirs(video_dir)
        os.makedirs(annotations)

        # Iterate a tuple, not a set: set order is nondeterministic under
        # hash randomization, and the order of fold_files decides which
        # annotation file each clip lands in via the round-robin below.
        fold_files = []
        for split in ('train', 'test'):
            for fold in range(1, 4):
                fold_file = '{:s}list{:02d}.txt'.format(split, fold)
                fold_files.append(os.path.join(annotations, fold_file))

        file_handles = [open(x, 'w') for x in fold_files]
        try:
            file_iter = cycle(file_handles)
            for i in range(2):
                current_class = 'class_{0}'.format(i + 1)
                class_dir = os.path.join(video_dir, current_class)
                os.makedirs(class_dir)
                for group in range(3):
                    for clip in range(4):
                        # Save sample file with a random length of
                        # 10-20 seconds at 25 fps.
                        clip_name = 'v_{0}_g{1}_c{2}.avi'.format(
                            current_class, group, clip)
                        clip_path = os.path.join(class_dir, clip_name)
                        length = random.randrange(10, 21)
                        this_clip = torch.randint(
                            0, 256, (length * 25, 320, 240, 3),
                            dtype=torch.uint8)
                        write_video(clip_path, this_clip, 25)
                        # Add to annotations (round-robin over the six
                        # fold files).
                        ann_file = next(file_iter)
                        ann_file.write('{0}\n'.format(
                            os.path.join(current_class, clip_name)))
        finally:
            # Close all file descriptors even if video writing failed.
            for f in file_handles:
                f.close()
        yield (video_dir, annotations)
...@@ -9,7 +9,7 @@ from torch._utils_internal import get_file_path_2 ...@@ -9,7 +9,7 @@ from torch._utils_internal import get_file_path_2
import torchvision import torchvision
from common_utils import get_tmp_dir from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \ from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
cityscapes_root, svhn_root, voc_root cityscapes_root, svhn_root, voc_root, ucf101_root
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
...@@ -19,6 +19,12 @@ try: ...@@ -19,6 +19,12 @@ try:
except ImportError: except ImportError:
HAS_SCIPY = False HAS_SCIPY = False
# PyAV is an optional dependency required for video decoding; record its
# availability so video dataset tests can be skipped when it is missing.
try:
    import av
    HAS_PYAV = True
except ImportError:
    HAS_PYAV = False
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
def generic_classification_dataset_test(self, dataset, num_images=1): def generic_classification_dataset_test(self, dataset, num_images=1):
...@@ -254,6 +260,26 @@ class Tester(unittest.TestCase): ...@@ -254,6 +260,26 @@ class Tester(unittest.TestCase):
}] }]
}}) }})
@unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
def test_ucf101(self):
    """Smoke-test UCF101 over every train/test split, fold, and clip length."""
    with ucf101_root() as (root, ann_root):
        for split in (True, False):
            for fold in range(1, 4):
                for length in (10, 15, 20):
                    dataset = torchvision.datasets.UCF101(
                        root, ann_root, length, fold=fold, train=split)
                    self.assertGreater(len(dataset), 0)
                    # The first and last samples belong to the first and
                    # last fake classes, respectively.
                    checks = ((0, 0), (len(dataset) - 1, 1))
                    for index, expected_label in checks:
                        video, audio, label = dataset[index]
                        self.assertEqual(
                            video.size(), (length, 320, 240, 3))
                        self.assertEqual(audio.numel(), 0)
                        self.assertEqual(label, expected_label)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment