Unverified Commit 512ea299 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Ported SVHN dataset to new test framework (#3661)

* Ported SVHN dataset to new test framework

* Fixed flake8 error and added REQUIRED_PACKAGES=scipy
parent 44460c9c
......@@ -210,23 +210,6 @@ def widerface_root():
yield root
@contextlib.contextmanager
def svhn_root():
import scipy.io as sio
def _make_mat(file):
images = np.zeros((32, 32, 3, 2), dtype=np.uint8)
targets = np.zeros((2,), dtype=np.uint8)
sio.savemat(file, {'X': images, 'y': targets})
with get_tmp_dir() as root:
_make_mat(os.path.join(root, "train_32x32.mat"))
_make_mat(os.path.join(root, "test_32x32.mat"))
_make_mat(os.path.join(root, "extra_32x32.mat"))
yield root
@contextlib.contextmanager
def places365_root(split="train-standard", small=False):
VARIANTS = {
......
......@@ -2,7 +2,6 @@ import contextlib
import sys
import os
import unittest
from unittest import mock
import numpy as np
import PIL
from PIL import Image
......@@ -10,7 +9,7 @@ from torch._utils_internal import get_file_path_2
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import svhn_root, places365_root, widerface_root, stl10_root
from fakedata_generation import places365_root, widerface_root, stl10_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
......@@ -57,20 +56,6 @@ class DatasetTestcase(unittest.TestCase):
class Tester(DatasetTestcase):
@mock.patch('torchvision.datasets.SVHN._check_integrity')
@unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
def test_svhn(self, mock_check):
mock_check.return_value = True
with svhn_root() as root:
dataset = torchvision.datasets.SVHN(root, split="train")
self.generic_classification_dataset_test(dataset, num_images=2)
dataset = torchvision.datasets.SVHN(root, split="test")
self.generic_classification_dataset_test(dataset, num_images=2)
dataset = torchvision.datasets.SVHN(root, split="extra")
self.generic_classification_dataset_test(dataset, num_images=2)
def test_places365(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365:
......@@ -1737,5 +1722,27 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase):
return split_to_num_examples[config["train"]]
class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SVHN
REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "extra"))
def inject_fake_data(self, tmpdir, config):
import scipy.io as sio
split = config["split"]
num_examples = {
"train": 2,
"test": 3,
"extra": 4,
}.get(split)
file = f"{split}_32x32.mat"
images = np.zeros((32, 32, 3, num_examples), dtype=np.uint8)
targets = np.zeros((num_examples,), dtype=np.uint8)
sio.savemat(os.path.join(tmpdir, file), {'X': images, 'y': targets})
return num_examples
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment