Commit 00ce2d0f authored by soumith's avatar soumith Committed by Soumith Chintala
Browse files

refactor download and md5-checking utilities

parent c7a39ba9
...@@ -11,6 +11,8 @@ if sys.version_info[0] == 2: ...@@ -11,6 +11,8 @@ if sys.version_info[0] == 2:
else: else:
import pickle import pickle
import .utils as utils
class CIFAR10(data.Dataset): class CIFAR10(data.Dataset):
base_folder = 'cifar-10-batches-py' base_folder = 'cifar-10-batches-py'
...@@ -29,7 +31,9 @@ class CIFAR10(data.Dataset): ...@@ -29,7 +31,9 @@ class CIFAR10(data.Dataset):
['test_batch', '40351d587109b95175f43aff81a1287e'], ['test_batch', '40351d587109b95175f43aff81a1287e'],
] ]
def __init__(self, root, train=True, transform=None, target_transform=None, download=False): def __init__(self, root, train=True,
transform=None, target_transform=None,
download=False):
self.root = root self.root = root
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
...@@ -106,55 +110,33 @@ class CIFAR10(data.Dataset): ...@@ -106,55 +110,33 @@ class CIFAR10(data.Dataset):
return 10000 return 10000
def _check_integrity(self): def _check_integrity(self):
import hashlib
root = self.root root = self.root
for fentry in (self.train_list + self.test_list): for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1] filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename) fpath = os.path.join(root, self.base_folder, filename)
if not os.path.isfile(fpath): if not utils.check_integrity(fpath, md5):
return False
md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
if md5c != md5:
return False return False
return True return True
def download(self): def download(self):
from six.moves import urllib
import tarfile import tarfile
import hashlib
root = self.root
fpath = os.path.join(root, self.filename)
try:
os.makedirs(root)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
# downloads file root = self.root
if os.path.isfile(fpath) and \
hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.tgz_md5: # download
print('Using downloaded file: ' + fpath) utils.download(self.url, root, self.filename, self.tgz_md5)
else:
print('Downloading ' + self.url + ' to ' + fpath)
urllib.request.urlretrieve(self.url, fpath)
# extract file # extract file
cwd = os.getcwd() cwd = os.getcwd()
print('Extracting tar file') tar = tarfile.open(os.path.join(root, self.filename), "r:gz")
tar = tarfile.open(fpath, "r:gz")
os.chdir(root) os.chdir(root)
tar.extractall() tar.extractall()
tar.close() tar.close()
os.chdir(cwd) os.chdir(cwd)
print('Done!')
class CIFAR100(CIFAR10): class CIFAR100(CIFAR10):
......
def check_integrity(fpath, md5):
import hashlib
if not os.path.isfile(fpath):
return False
md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
if md5c != md5:
return False
return True
def download(url, root, filename, md5=None):
from six.moves import urllib
fpath = os.path.join(root, filename)
try:
os.makedirs(root)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
# downloads file
if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else:
print('Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(url, fpath)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment