Commit 7716aba5 authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

[WIP] Add test for ImageNet (#976)

* added fake data

* fixed fake data

* renamed extract and download methods and added functionality

* added raw fake data

* refactored imagenet and added test

* flake8

* added fake devkit and mocked download_url

* reversed uncommenting

* added mock to CI

* fixed tests for imagefolder

* flake8
parent 3d561039
...@@ -7,7 +7,7 @@ import os.path ...@@ -7,7 +7,7 @@ import os.path
import numpy as np import numpy as np
import torch import torch
import codecs import codecs
from .utils import download_and_extract, extract_file, makedir_exist_ok from .utils import download_and_extract_archive, extract_archive, makedir_exist_ok
class MNIST(VisionDataset): class MNIST(VisionDataset):
...@@ -131,7 +131,7 @@ class MNIST(VisionDataset): ...@@ -131,7 +131,7 @@ class MNIST(VisionDataset):
# download files # download files
for url in self.urls: for url in self.urls:
filename = url.rpartition('/')[2] filename = url.rpartition('/')[2]
download_and_extract(url, root=self.raw_folder, filename=filename) download_and_extract_archive(url, download_root=self.raw_folder, filename=filename)
# process and save as torch files # process and save as torch files
print('Processing...') print('Processing...')
...@@ -259,11 +259,12 @@ class EMNIST(MNIST): ...@@ -259,11 +259,12 @@ class EMNIST(MNIST):
# download files # download files
print('Downloading and extracting zip archive') print('Downloading and extracting zip archive')
download_and_extract(self.url, root=self.raw_folder, filename="emnist.zip", remove_finished=True) download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
remove_finished=True)
gzip_folder = os.path.join(self.raw_folder, 'gzip') gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder): for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'): if gzip_file.endswith('.gz'):
extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder) extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder)
# process and save as torch files # process and save as torch files
for split in self.splits: for split in self.splits:
......
...@@ -3,7 +3,7 @@ from PIL import Image ...@@ -3,7 +3,7 @@ from PIL import Image
from os.path import join from os.path import join
import os import os
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_and_extract, check_integrity, list_dir, list_files from .utils import download_and_extract_archive, check_integrity, list_dir, list_files
class Omniglot(VisionDataset): class Omniglot(VisionDataset):
...@@ -88,7 +88,7 @@ class Omniglot(VisionDataset): ...@@ -88,7 +88,7 @@ class Omniglot(VisionDataset):
filename = self._get_target_folder() filename = self._get_target_folder()
zip_filename = filename + '.zip' zip_filename = filename + '.zip'
url = self.download_url_prefix + '/' + zip_filename url = self.download_url_prefix + '/' + zip_filename
download_and_extract(url, self.root, zip_filename, self.zips_md5[filename]) download_and_extract_archive(url, self.root, zip_filename, self.zips_md5[filename])
def _get_target_folder(self): def _get_target_folder(self):
return 'images_background' if self.background else 'images_evaluation' return 'images_background' if self.background else 'images_evaluation'
...@@ -5,7 +5,7 @@ import os.path ...@@ -5,7 +5,7 @@ import os.path
import numpy as np import numpy as np
from .vision import VisionDataset from .vision import VisionDataset
from .utils import check_integrity, download_and_extract from .utils import check_integrity, download_and_extract_archive
class STL10(VisionDataset): class STL10(VisionDataset):
...@@ -152,7 +152,7 @@ class STL10(VisionDataset): ...@@ -152,7 +152,7 @@ class STL10(VisionDataset):
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
download_and_extract(self.url, self.root, self.filename, self.tgz_md5) download_and_extract_archive(self.url, self.root, self.filename, self.tgz_md5)
def extra_repr(self): def extra_repr(self):
return "Split: {split}".format(**self.__dict__) return "Split: {split}".format(**self.__dict__)
......
...@@ -38,8 +38,7 @@ def check_integrity(fpath, md5=None): ...@@ -38,8 +38,7 @@ def check_integrity(fpath, md5=None):
return False return False
if md5 is None: if md5 is None:
return True return True
else: return check_md5(fpath, md5)
return check_md5(fpath, md5)
def makedir_exist_ok(dirpath): def makedir_exist_ok(dirpath):
...@@ -74,7 +73,7 @@ def download_url(url, root, filename=None, md5=None): ...@@ -74,7 +73,7 @@ def download_url(url, root, filename=None, md5=None):
makedir_exist_ok(root) makedir_exist_ok(root)
# downloads file # downloads file
if os.path.isfile(fpath) and check_integrity(fpath, md5): if check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath) print('Using downloaded and verified file: ' + fpath)
else: else:
try: try:
...@@ -211,9 +210,12 @@ def _is_zip(filename): ...@@ -211,9 +210,12 @@ def _is_zip(filename):
return filename.endswith(".zip") return filename.endswith(".zip")
def extract_file(from_path, to_path, remove_finished=False): def extract_archive(from_path, to_path=None, remove_finished=False):
if to_path is None:
to_path = os.path.dirname(from_path)
if _is_tar(from_path): if _is_tar(from_path):
with tarfile.open(from_path, 'r:') as tar: with tarfile.open(from_path, 'r') as tar:
tar.extractall(path=to_path) tar.extractall(path=to_path)
elif _is_targz(from_path): elif _is_targz(from_path):
with tarfile.open(from_path, 'r:gz') as tar: with tarfile.open(from_path, 'r:gz') as tar:
...@@ -229,10 +231,19 @@ def extract_file(from_path, to_path, remove_finished=False): ...@@ -229,10 +231,19 @@ def extract_file(from_path, to_path, remove_finished=False):
raise ValueError("Extraction of {} not supported".format(from_path)) raise ValueError("Extraction of {} not supported".format(from_path))
if remove_finished: if remove_finished:
os.unlink(from_path) os.remove(from_path)
def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
md5=None, remove_finished=False):
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)
download_url(url, download_root, filename, md5)
def download_and_extract(url, root, filename, md5=None, remove_finished=False): archive = os.path.join(download_root, filename)
download_url(url, root, filename, md5) print("Extracting {} to {}".format(archive, extract_root))
print("Extracting {} to {}".format(os.path.join(root, filename), root)) extract_archive(archive, extract_root, remove_finished)
extract_file(os.path.join(root, filename), root, remove_finished)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment