Commit 7716aba5 authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

[WIP] Add test for ImageNet (#976)

* added fake data

* fixed fake data

* renamed extract and download methods and added functionality

* added raw fake data

* refactored imagenet and added test

* flake8

* added fake devkit and mocked download_url

* reversed uncommenting

* added mock to CI

* fixed tests for imagefolder

* flake8
parent 3d561039
......@@ -7,7 +7,7 @@ import os.path
import numpy as np
import torch
import codecs
from .utils import download_and_extract, extract_file, makedir_exist_ok
from .utils import download_and_extract_archive, extract_archive, makedir_exist_ok
class MNIST(VisionDataset):
......@@ -131,7 +131,7 @@ class MNIST(VisionDataset):
# download files
for url in self.urls:
filename = url.rpartition('/')[2]
download_and_extract(url, root=self.raw_folder, filename=filename)
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename)
# process and save as torch files
print('Processing...')
......@@ -259,11 +259,12 @@ class EMNIST(MNIST):
# download files
print('Downloading and extracting zip archive')
download_and_extract(self.url, root=self.raw_folder, filename="emnist.zip", remove_finished=True)
download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
remove_finished=True)
gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'):
extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder)
extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder)
# process and save as torch files
for split in self.splits:
......
......@@ -3,7 +3,7 @@ from PIL import Image
from os.path import join
import os
from .vision import VisionDataset
from .utils import download_and_extract, check_integrity, list_dir, list_files
from .utils import download_and_extract_archive, check_integrity, list_dir, list_files
class Omniglot(VisionDataset):
......@@ -88,7 +88,7 @@ class Omniglot(VisionDataset):
filename = self._get_target_folder()
zip_filename = filename + '.zip'
url = self.download_url_prefix + '/' + zip_filename
download_and_extract(url, self.root, zip_filename, self.zips_md5[filename])
download_and_extract_archive(url, self.root, zip_filename, self.zips_md5[filename])
def _get_target_folder(self):
return 'images_background' if self.background else 'images_evaluation'
......@@ -5,7 +5,7 @@ import os.path
import numpy as np
from .vision import VisionDataset
from .utils import check_integrity, download_and_extract
from .utils import check_integrity, download_and_extract_archive
class STL10(VisionDataset):
......@@ -152,7 +152,7 @@ class STL10(VisionDataset):
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract(self.url, self.root, self.filename, self.tgz_md5)
download_and_extract_archive(self.url, self.root, self.filename, self.tgz_md5)
def extra_repr(self):
return "Split: {split}".format(**self.__dict__)
......
......@@ -38,8 +38,7 @@ def check_integrity(fpath, md5=None):
return False
if md5 is None:
return True
else:
return check_md5(fpath, md5)
return check_md5(fpath, md5)
def makedir_exist_ok(dirpath):
......@@ -74,7 +73,7 @@ def download_url(url, root, filename=None, md5=None):
makedir_exist_ok(root)
# downloads file
if os.path.isfile(fpath) and check_integrity(fpath, md5):
if check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else:
try:
......@@ -211,9 +210,12 @@ def _is_zip(filename):
return filename.endswith(".zip")
def extract_file(from_path, to_path, remove_finished=False):
def extract_archive(from_path, to_path=None, remove_finished=False):
if to_path is None:
to_path = os.path.dirname(from_path)
if _is_tar(from_path):
with tarfile.open(from_path, 'r:') as tar:
with tarfile.open(from_path, 'r') as tar:
tar.extractall(path=to_path)
elif _is_targz(from_path):
with tarfile.open(from_path, 'r:gz') as tar:
......@@ -229,10 +231,19 @@ def extract_file(from_path, to_path, remove_finished=False):
raise ValueError("Extraction of {} not supported".format(from_path))
if remove_finished:
os.unlink(from_path)
os.remove(from_path)
def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
md5=None, remove_finished=False):
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)
download_url(url, download_root, filename, md5)
def download_and_extract(url, root, filename, md5=None, remove_finished=False):
download_url(url, root, filename, md5)
print("Extracting {} to {}".format(os.path.join(root, filename), root))
extract_file(os.path.join(root, filename), root, remove_finished)
archive = os.path.join(download_root, filename)
print("Extracting {} to {}".format(archive, extract_root))
extract_archive(archive, extract_root, remove_finished)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment