Commit 32460f52 authored by soumith, committed by Soumith Chintala

refactor svhn to use common utils

parent 00ce2d0f
@@ -11,7 +11,7 @@ if sys.version_info[0] == 2:
 else:
     import pickle
 
-import .utils as utils
+from .utils import download_url, check_integrity
 
 
 class CIFAR10(data.Dataset):
@@ -114,7 +114,7 @@ class CIFAR10(data.Dataset):
         for fentry in (self.train_list + self.test_list):
             filename, md5 = fentry[0], fentry[1]
             fpath = os.path.join(root, self.base_folder, filename)
-            if not utils.check_integrity(fpath, md5):
+            if not check_integrity(fpath, md5):
                 return False
         return True
@@ -126,9 +126,7 @@ class CIFAR10(data.Dataset):
             return
 
         root = self.root
-        # download
-        utils.download(self.url, root, self.filename, self.tgz_md5)
+        download_url(self.url, root, self.filename, self.tgz_md5)
 
         # extract file
         cwd = os.getcwd()
...
@@ -6,7 +6,7 @@ import os.path
 import errno
 import numpy as np
 import sys
+from .utils import download_url, check_integrity
 
 
 class SVHN(data.Dataset):
     url = ""
@@ -21,14 +21,16 @@ class SVHN(data.Dataset):
         'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
                   "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
 
-    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
+    def __init__(self, root, split='train',
+                 transform=None, target_transform=None, download=False):
         self.root = root
         self.transform = transform
         self.target_transform = target_transform
         self.split = split  # training set or test set or extra set
 
         if self.split not in self.split_list:
-            raise ValueError('Wrong split entered! Please use split="train" or split="extra" or split="test"')
+            raise ValueError('Wrong split entered! Please use split="train" '
+                             'or split="extra" or split="test"')
 
         self.url = self.split_list[split][0]
         self.filename = self.split_list[split][1]
@@ -71,41 +73,11 @@ class SVHN(data.Dataset):
         return len(self.data)
 
     def _check_integrity(self):
-        import hashlib
         root = self.root
         md5 = self.split_list[self.split][2]
         fpath = os.path.join(root, self.filename)
-        if not os.path.isfile(fpath):
-            return False
-        md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
-        if md5c != md5:
-            return False
-        return True
+        return check_integrity(fpath, md5)
 
     def download(self):
-        from six.moves import urllib
-        import tarfile
-        import hashlib
-        root = self.root
-        fpath = os.path.join(root, self.filename)
-
-        try:
-            os.makedirs(root)
-        except OSError as e:
-            if e.errno == errno.EEXIST:
-                pass
-            else:
-                raise
-
-        if self._check_integrity():
-            print('Files already downloaded and verified')
-            return
-
-        # downloads file
-        if os.path.isfile(fpath):
-            print('Using downloaded file: ' + fpath)
-        else:
-            print('Downloading ' + self.url + ' to ' + fpath)
-            urllib.request.urlretrieve(self.url, fpath)
-            print ('Downloaded!')
+        md5 = self.split_list[self.split][2]
+        download_url(self.url, self.root, self.filename, md5)
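With this change SVHN follows the same pattern as CIFAR10: checksum verification and downloading are delegated to the shared helpers instead of being reimplemented per dataset. The sketch below is a minimal illustration of that pattern, not code from the commit; `MyDataset` and its url/filename/md5 values are made-up placeholders, and it assumes the helpers are importable as `torchvision.datasets.utils` (the diff itself only shows the relative `.utils` import).

```python
import os

import torch.utils.data as data
from torchvision.datasets.utils import download_url, check_integrity


class MyDataset(data.Dataset):
    # Placeholder file description -- not a real dataset.
    url = "http://example.com/data/my_dataset.bin"
    filename = "my_dataset.bin"
    file_md5 = "d41d8cd98f00b204e9800998ecf8427e"

    def __init__(self, root, download=False):
        self.root = root
        if download:
            self.download()
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted. '
                               'You can use download=True to download it')
        # (loading the actual data and __getitem__/__len__ are omitted;
        #  this sketch only shows the download/integrity pattern)

    def _check_integrity(self):
        # One helper call replaces the hand-rolled hashlib code removed above.
        fpath = os.path.join(self.root, self.filename)
        return check_integrity(fpath, self.file_md5)

    def download(self):
        # download_url is expected to handle directory creation and to skip
        # files that already pass the MD5 check, mirroring the logic that was
        # deleted from SVHN.download above.
        download_url(self.url, self.root, self.filename, self.file_md5)
```

This mirrors the one-line bodies that `_check_integrity` and `download` now have in SVHN.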
 import os
 import hashlib
 import errno
 def check_integrity(fpath, md5):
-    import hashlib
     if not os.path.isfile(fpath):
         return False
-    md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
+    md5o = hashlib.md5()
+    with open(fpath,'rb') as f:
+        # read in 1MB chunks
+        for chunk in iter(lambda: f.read(1024 * 1024 * 1024), b''):
+            md5o.update(chunk)
+    md5c = md5o.hexdigest()
     if md5c != md5:
         return False
     return True
 
 
-def download(url, root, filename, md5=None):
+def download_url(url, root, filename, md5=None):
     from six.moves import urllib
     fpath = os.path.join(root, filename)
...
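The renamed `download_url` and the chunked `check_integrity` can also be called directly. A minimal usage sketch, again assuming the helpers are importable as `torchvision.datasets.utils`; the URL and MD5 are the SVHN 'extra' split values that appear in the diff above, and `./data` is an arbitrary target directory.

```python
import os

from torchvision.datasets.utils import download_url, check_integrity

# Values taken from the SVHN 'extra' split entry shown above.
url = "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat"
filename = "extra_32x32.mat"
md5 = "a93ce644f1a588dc4d68dda5feec44a7"
root = "./data"  # arbitrary download directory

# Fetch the file into root and let the helper verify it against the MD5.
download_url(url, root, filename, md5)

# check_integrity hashes the file in chunks rather than reading it into
# memory in one go, so verification stays cheap even for large .mat files.
fpath = os.path.join(root, filename)
print("checksum OK" if check_integrity(fpath, md5) else "checksum failed", fpath)
```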