Commit 32460f52, authored by soumith and committed by Soumith Chintala
Browse files

refactor svhn to use common utils

parent 00ce2d0f
......@@ -11,7 +11,7 @@ if sys.version_info[0] == 2:
else:
import pickle
import .utils as utils
from .utils import download_url, check_integrity
class CIFAR10(data.Dataset):
......@@ -114,7 +114,7 @@ class CIFAR10(data.Dataset):
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not utils.check_integrity(fpath, md5):
if not check_integrity(fpath, md5):
return False
return True
......@@ -126,9 +126,7 @@ class CIFAR10(data.Dataset):
return
root = self.root
# download
utils.download(self.url, root, self.filename, self.tgz_md5)
download_url(self.url, root, self.filename, self.tgz_md5)
# extract file
cwd = os.getcwd()
......
......@@ -6,7 +6,7 @@ import os.path
import errno
import numpy as np
import sys
from .utils import download_url, check_integrity
class SVHN(data.Dataset):
url = ""
......@@ -21,14 +21,16 @@ class SVHN(data.Dataset):
'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
"extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
def __init__(self, root, split='train',
transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.split = split # training set or test set or extra set
if self.split not in self.split_list:
raise ValueError('Wrong split entered! Please use split="train" or split="extra" or split="test"')
raise ValueError('Wrong split entered! Please use split="train" '
'or split="extra" or split="test"')
self.url = self.split_list[split][0]
self.filename = self.split_list[split][1]
......@@ -71,41 +73,11 @@ class SVHN(data.Dataset):
return len(self.data)
def _check_integrity(self):
    """Verify the file of the selected split against its expected MD5.

    The file path is ``self.root`` joined with ``self.filename``; the
    expected digest is the third entry of ``self.split_list[self.split]``.
    The comparison is delegated to the shared ``check_integrity`` helper
    instead of re-implementing the hashing inline.

    Returns:
        Whatever ``check_integrity(fpath, md5)`` returns.
    """
    root = self.root
    md5 = self.split_list[self.split][2]
    fpath = os.path.join(root, self.filename)
    return check_integrity(fpath, md5)
def download(self):
    """Download the file for the selected split via ``download_url``.

    The expected MD5 digest is the third entry of
    ``self.split_list[self.split]`` and is passed along with the URL,
    root directory and file name.
    """
    # NOTE(review): directory creation and the "already downloaded and
    # verified" short-circuit used to be done inline here; they are assumed
    # to be handled inside download_url now -- confirm against that helper.
    md5 = self.split_list[self.split][2]
    download_url(self.url, self.root, self.filename, md5)
import os
import hashlib
import errno
def check_integrity(fpath, md5):
    """Return True if ``fpath`` is an existing file whose MD5 equals ``md5``.

    Args:
        fpath: path of the file to verify.
        md5: expected MD5 digest as a hexadecimal string.

    Returns:
        False when ``fpath`` is not an existing regular file or when its
        digest differs from ``md5``; True otherwise.
    """
    import hashlib
    if not os.path.isfile(fpath):
        return False
    md5o = hashlib.md5()
    with open(fpath, 'rb') as f:
        # Read in 1 MiB chunks so a large archive is never held in memory
        # all at once (the digest is updated incrementally).
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            md5o.update(chunk)
    return md5o.hexdigest() == md5
def download(url, root, filename, md5=None):
def download_url(url, root, filename, md5=None):
from six.moves import urllib
fpath = os.path.join(root, filename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment