Unverified Commit c59f0474 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

[WIP] Add tests for datasets (#966)

* WIP

* WIP: minor improvements

* Add tests

* Fix typo

* Use download_and_extract on caltech, cifar and omniglot

* Add a print message during extraction

* Remove EMNIST from test
parent 2b3a1b6d
import PIL
import shutil
import tempfile
import unittest
import torchvision
class Tester(unittest.TestCase):
    """Smoke tests for the MNIST-family torchvision datasets.

    Each test downloads a dataset into a fresh temporary directory
    (network access required) and checks the basic sample contract:
    indexing yields a (PIL image, int label) pair.
    """

    def _check_dataset(self, dataset_cls, expected_len=None):
        # Shared driver; the temp dir is removed even when an assertion
        # fails (the original leaked it on failure).
        tmp_dir = tempfile.mkdtemp()
        try:
            dataset = dataset_cls(tmp_dir, download=True)
            if expected_len is not None:
                self.assertEqual(len(dataset), expected_len)
            img, target = dataset[0]
            self.assertTrue(isinstance(img, PIL.Image.Image))
            self.assertTrue(isinstance(target, int))
        finally:
            shutil.rmtree(tmp_dir)

    def test_mnist(self):
        # MNIST's training split size is fixed, so pin it.
        self._check_dataset(torchvision.datasets.MNIST, expected_len=60000)

    def test_kmnist(self):
        self._check_dataset(torchvision.datasets.KMNIST)

    def test_fashionmnist(self):
        self._check_dataset(torchvision.datasets.FashionMNIST)
# Entry point: run the dataset smoke tests when this module is executed directly.
if __name__ == '__main__':
unittest.main()
......@@ -3,6 +3,9 @@ import shutil
import tempfile
import torchvision.datasets.utils as utils
import unittest
import zipfile
import tarfile
import gzip
# Bundled fixture image used by the download tests, resolved relative to
# this test file so the suite works from any working directory.
TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'grace_hopper_517x606.jpg')
......@@ -41,6 +44,47 @@ class Tester(unittest.TestCase):
assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.'
shutil.rmtree(temp_dir)
def test_extract_zip(self):
    """extract_file on a .zip archive restores the archived file's content."""
    temp_dir = tempfile.mkdtemp()
    try:
        with tempfile.NamedTemporaryFile(suffix='.zip') as f:
            with zipfile.ZipFile(f, 'w') as zf:
                zf.writestr('file.tst', 'this is the content')
            utils.extract_file(f.name, temp_dir)
            assert os.path.exists(os.path.join(temp_dir, 'file.tst'))
            with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
                data = nf.read()
            assert data == 'this is the content'
    finally:
        # Original removed the dir only on success, leaking it on failure.
        shutil.rmtree(temp_dir)
def test_extract_tar(self):
    """extract_file handles both plain .tar and gzip-compressed .tar.gz."""
    for ext, mode in zip(['.tar', '.tar.gz'], ['w', 'w:gz']):
        temp_dir = tempfile.mkdtemp()
        try:
            with tempfile.NamedTemporaryFile() as bf:
                bf.write("this is the content".encode())
                # tarfile re-opens bf by name, so the payload must be on disk.
                bf.flush()
                with tempfile.NamedTemporaryFile(suffix=ext) as f:
                    with tarfile.open(f.name, mode=mode) as zf:
                        zf.add(bf.name, arcname='file.tst')
                    utils.extract_file(f.name, temp_dir)
                    assert os.path.exists(os.path.join(temp_dir, 'file.tst'))
                    with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
                        data = nf.read()
                    assert data == 'this is the content', data
        finally:
            # Original removed the dir only on success, leaking it on failure.
            shutil.rmtree(temp_dir)
def test_extract_gzip(self):
    """extract_file inflates a bare .gz archive to <dir>/<archive stem>."""
    temp_dir = tempfile.mkdtemp()
    try:
        with tempfile.NamedTemporaryFile(suffix='.gz') as f:
            with gzip.GzipFile(f.name, 'wb') as zf:
                zf.write('this is the content'.encode())
            utils.extract_file(f.name, temp_dir)
            # extract_file names the output after the archive minus '.gz'.
            f_name = os.path.join(
                temp_dir, os.path.splitext(os.path.basename(f.name))[0])
            assert os.path.exists(f_name)
            with open(f_name, 'r') as nf:
                data = nf.read()
            assert data == 'this is the content', data
    finally:
        # Original removed the dir only on success, leaking it on failure.
        shutil.rmtree(temp_dir)
# Entry point: run the utils extraction tests when executed directly.
if __name__ == '__main__':
unittest.main()
......@@ -4,7 +4,7 @@ import os
import os.path
from .vision import VisionDataset
from .utils import download_url, makedir_exist_ok
from .utils import download_and_extract, makedir_exist_ok
class Caltech101(VisionDataset):
......@@ -109,28 +109,21 @@ class Caltech101(VisionDataset):
return len(self.index)
def download(self):
    """Download and extract the Caltech-101 images and annotations into self.root.

    Skips all work (with a notice) when _check_integrity() already passes.
    download_and_extract verifies each archive against its MD5 and unpacks it.
    """
    if self._check_integrity():
        print('Files already downloaded and verified')
        return
    # Category images archive.
    download_and_extract(
        "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
        self.root,
        "101_ObjectCategories.tar.gz",
        "b224c7392d521a49829488ab0f1120d9")
    # Outline annotations for the same categories.
    download_and_extract(
        "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
        self.root,
        "101_Annotations.tar",
        "6f83eeb1f24d99cab4eb377263132c91")
def extra_repr(self):
# Extra line appended to the dataset's repr(); substitutes the instance's
# target_type attribute via __dict__ expansion (must be an instance attribute).
return "Target type: {target_type}".format(**self.__dict__)
......@@ -204,17 +197,12 @@ class Caltech256(VisionDataset):
return len(self.index)
def download(self):
    """Download and extract the Caltech-256 archive into self.root.

    Skips all work (with a notice) when _check_integrity() already passes.
    download_and_extract verifies the archive against its MD5 and unpacks it.
    """
    if self._check_integrity():
        print('Files already downloaded and verified')
        return
    download_and_extract(
        "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
        self.root,
        "256_ObjectCategories.tar",
        "67b4f42ca05d46448c6bb8ecd2220f6d")
......@@ -11,7 +11,7 @@ else:
import pickle
from .vision import VisionDataset
from .utils import download_url, check_integrity
from .utils import check_integrity, download_and_extract
class CIFAR10(VisionDataset):
......@@ -144,17 +144,10 @@ class CIFAR10(VisionDataset):
return True
def download(self):
    """Download and extract the CIFAR archive into self.root.

    The URL, local filename and MD5 come from the class attributes
    (self.url, self.filename, self.tgz_md5). Skips all work (with a
    notice) when _check_integrity() already passes.
    """
    if self._check_integrity():
        print('Files already downloaded and verified')
        return
    download_and_extract(self.url, self.root, self.filename, self.tgz_md5)
def extra_repr(self):
# Extra line appended to the dataset's repr().
# NOTE(review): `self.train is True` requires the exact bool True, not
# merely a truthy value -- presumably train is always a bool; confirm.
return "Split: {}".format("Train" if self.train is True else "Test")
......
......@@ -4,11 +4,10 @@ import warnings
from PIL import Image
import os
import os.path
import gzip
import numpy as np
import torch
import codecs
from .utils import download_url, makedir_exist_ok
from .utils import download_and_extract, extract_file, makedir_exist_ok
class MNIST(VisionDataset):
......@@ -120,15 +119,6 @@ class MNIST(VisionDataset):
os.path.exists(os.path.join(self.processed_folder,
self.test_file)))
@staticmethod
def extract_gzip(gzip_path, remove_finished=False):
# Inflate a .gz file alongside the archive; optionally delete the archive.
# NOTE(review): the output name uses str.replace, so every '.gz'
# occurrence in the path is dropped, not just the suffix.
print('Extracting {}'.format(gzip_path))
with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(gzip_path) as zip_f:
out_f.write(zip_f.read())
# Remove the compressed source once inflated, if requested.
if remove_finished:
os.unlink(gzip_path)
def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
......@@ -141,9 +131,7 @@ class MNIST(VisionDataset):
# download files
for url in self.urls:
filename = url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename)
download_url(url, root=self.raw_folder, filename=filename, md5=None)
self.extract_gzip(gzip_path=file_path, remove_finished=True)
download_and_extract(url, root=self.raw_folder, filename=filename)
# process and save as torch files
print('Processing...')
......@@ -262,7 +250,6 @@ class EMNIST(MNIST):
def download(self):
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
import shutil
import zipfile
if self._check_exists():
return
......@@ -271,18 +258,12 @@ class EMNIST(MNIST):
makedir_exist_ok(self.processed_folder)
# download files
filename = self.url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename)
download_url(self.url, root=self.raw_folder, filename=filename, md5=None)
print('Extracting zip archive')
with zipfile.ZipFile(file_path) as zip_f:
zip_f.extractall(self.raw_folder)
os.unlink(file_path)
print('Downloading and extracting zip archive')
download_and_extract(self.url, root=self.raw_folder, filename="emnist.zip", remove_finished=True)
gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'):
self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file))
extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder)
# process and save as torch files
for split in self.splits:
......
......@@ -3,7 +3,7 @@ from PIL import Image
from os.path import join
import os
from .vision import VisionDataset
from .utils import download_url, check_integrity, list_dir, list_files
from .utils import download_and_extract, check_integrity, list_dir, list_files
class Omniglot(VisionDataset):
......@@ -81,8 +81,6 @@ class Omniglot(VisionDataset):
return True
def download(self):
    """Download and extract this split's zip archive into self.root.

    The split folder name comes from _get_target_folder(); the expected
    MD5 is looked up in self.zips_md5. Skips all work (with a notice)
    when _check_integrity() already passes.
    """
    if self._check_integrity():
        print('Files already downloaded and verified')
        return
    filename = self._get_target_folder()
    zip_filename = filename + '.zip'
    url = self.download_url_prefix + '/' + zip_filename
    download_and_extract(url, self.root, zip_filename, self.zips_md5[filename])
def _get_target_folder(self):
    """Name of the image folder for the configured split."""
    if self.background:
        return 'images_background'
    return 'images_evaluation'
import os
import os.path
import hashlib
import gzip
import errno
import tarfile
import zipfile
from torch.utils.model_zoo import tqdm
......@@ -189,3 +193,46 @@ def _save_response_content(response, destination, chunk_size=32768):
progress += len(chunk)
pbar.update(progress - pbar.n)
pbar.close()
def _is_tar(filename):
return filename.endswith(".tar")
def _is_targz(filename):
return filename.endswith(".tar.gz")
def _is_gzip(filename):
return filename.endswith(".gz") and not filename.endswith(".tar.gz")
def _is_zip(filename):
return filename.endswith(".zip")
def extract_file(from_path, to_path, remove_finished=False):
    """Extract the archive at *from_path* into the directory *to_path*.

    Supported formats (chosen by filename suffix): .tar, .tar.gz, .zip,
    and a standalone .gz (inflated to to_path/<name without extension>).
    When remove_finished is True the archive is deleted after extraction.

    Raises ValueError for an unrecognised suffix.
    """
    # Suffix checks are inlined here; .tar.gz must be tested before .gz.
    if from_path.endswith(".tar.gz"):
        with tarfile.open(from_path, 'r:gz') as archive:
            archive.extractall(path=to_path)
    elif from_path.endswith(".tar"):
        with tarfile.open(from_path, 'r:') as archive:
            archive.extractall(path=to_path)
    elif from_path.endswith(".zip"):
        with zipfile.ZipFile(from_path, 'r') as archive:
            archive.extractall(to_path)
    elif from_path.endswith(".gz"):
        # Single gzipped file: inflate next to the other extractions.
        target = os.path.join(
            to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(target, "wb") as out_file, gzip.GzipFile(from_path) as compressed:
            out_file.write(compressed.read())
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))
    if remove_finished:
        os.unlink(from_path)
def download_and_extract(url, root, filename, md5=None, remove_finished=False):
    """Download *url* into *root* as *filename*, then extract it in place.

    md5: optional checksum forwarded to download_url for verification.
    remove_finished: forwarded to extract_file; deletes the archive
    once extraction succeeds.
    """
    download_url(url, root, filename, md5)
    archive = os.path.join(root, filename)
    print("Extracting {} to {}".format(archive, root))
    extract_file(archive, root, remove_finished)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment