Commit 5eee0117 authored by Max Lübbering, committed by Francisco Massa

Implemented integrity check (md5 hash) after dataset download (#1456)

* Removed unnecessary class variables.

* The integrity of dataset files is now checked right after the download finishes, making sure that a corrupt file is never extracted. In case of corruption, a RuntimeError is raised (the idea behind the check is sketched below the commit message).

* Added missing md5 hashes to the MNIST, FashionMNIST, KMNIST, EMNIST and QMNIST datasets.

* Removed printing of the error message when the integrity check fails.
  Reformulated the error message.

* Reformatted code to conform to lint rules.

* Fixed formatting in utils.py
parent 97b53f96
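For context, the check this commit adds boils down to hashing each downloaded archive and failing loudly on a mismatch. Below is a minimal standalone sketch of that idea, not the exact torchvision code; the helper name md5_matches and the chunk size are illustrative choices.

import hashlib

def md5_matches(fpath, expected_md5, chunk_size=1024 * 1024):
    """Return True if the MD5 of the file at `fpath` equals `expected_md5`."""
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        # Read in chunks so large dataset archives do not need to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest() == expected_md5

# After a download finishes (mirroring the behavior described above):
# if not md5_matches(fpath, md5):
#     raise RuntimeError("File not found or corrupted.")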
@@ -27,12 +27,14 @@ class MNIST(VisionDataset):
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
     """
-    urls = [
-        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
-        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
-        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
-        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
+    resources = [
+        ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
+        ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
+        ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
+        ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
     ]

     training_file = 'training.pt'
     test_file = 'test.pt'
     classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
@@ -130,9 +132,9 @@ class MNIST(VisionDataset):
         makedir_exist_ok(self.processed_folder)

         # download files
-        for url in self.urls:
+        for url, md5 in self.resources:
             filename = url.rpartition('/')[2]
-            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename)
+            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

         # process and save as torch files
         print('Processing...')
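From a user's perspective the API is unchanged; assuming a torchvision build that includes this change, constructing a dataset with download=True now verifies every archive against the hashes above and surfaces corruption as a RuntimeError. A hedged usage example:

from torchvision.datasets import MNIST

try:
    dataset = MNIST(root='./data', train=True, download=True)
except RuntimeError as err:
    # Raised when a downloaded archive does not match its expected MD5.
    print('Download failed the integrity check:', err)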
@@ -172,11 +174,15 @@ class FashionMNIST(MNIST):
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
     """
-    urls = [
-        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
-        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
-        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
-        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
+    resources = [
+        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
+         "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
+        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
+         "25c81989df183df01b3e8a0aad5dffbe"),
+        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
+         "bef4ecab320f06d8554ea6380940ec79"),
+        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
+         "bb300cfdad3c16e7a12a480ee83cd310")
     ]
     classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
                'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
@@ -198,11 +204,11 @@ class KMNIST(MNIST):
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
     """
-    urls = [
-        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz',
-        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz',
-        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz',
-        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz',
+    resources = [
+        ("http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
+        ("http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
+        ("http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
+        ("http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134")
     ]
     classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo']
@@ -231,6 +237,7 @@ class EMNIST(MNIST):
     # https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download
     # is (currently) unavailable
     url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
+    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
     splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

     def __init__(self, root, split, **kwargs):
@@ -260,7 +267,7 @@ class EMNIST(MNIST):
         # download files
         print('Downloading and extracting zip archive')
         download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
-                                     remove_finished=True)
+                                     remove_finished=True, md5=self.md5)
         gzip_folder = os.path.join(self.raw_folder, 'gzip')
         for gzip_file in os.listdir(gzip_folder):
             if gzip_file.endswith('.gz'):
@@ -319,16 +326,24 @@ class QMNIST(MNIST):
     subsets = {
         'train': 'train',
-        'test': 'test', 'test10k': 'test', 'test50k': 'test',
+        'test': 'test',
+        'test10k': 'test',
+        'test50k': 'test',
         'nist': 'nist'
     }
-    urls = {
-        'train': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
-                  'https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz'],
-        'test': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz',
-                 'https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz'],
-        'nist': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz',
-                 'https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz']
+    resources = {
+        'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
+                   'ed72d4157d28c017586c42bc6afe6370'),
+                  ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz',
+                   '0058f8dd561b90ffdd0f734c6a30e5e4')],
+        'test': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz',
+                  '1394631089c404de565df7b7aeaf9412'),
+                 ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz',
+                  '5b5b05890a5e13444e108efe57b788aa')],
+        'nist': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz',
+                  '7f124b3b8ab81486c9d8c2749c17f834'),
+                 ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz',
+                  '5ed0e788978e45d4a8bd4b7caec3d79d')]
     }
     classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
                '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
@@ -351,15 +366,15 @@ class QMNIST(MNIST):
             return
         makedir_exist_ok(self.raw_folder)
         makedir_exist_ok(self.processed_folder)
-        urls = self.urls[self.subsets[self.what]]
+        split = self.resources[self.subsets[self.what]]
         files = []

         # download data files if not already there
-        for url in urls:
+        for url, md5 in split:
             filename = url.rpartition('/')[2]
             file_path = os.path.join(self.raw_folder, filename)
             if not os.path.isfile(file_path):
-                download_url(url, root=self.raw_folder, filename=filename, md5=None)
+                download_url(url, root=self.raw_folder, filename=filename, md5=md5)
             files.append(file_path)

         # process and save as torch files
...
@@ -27,9 +27,6 @@ class SVHN(VisionDataset):
             downloaded again.
     """

-    url = ""
-    filename = ""
-    file_md5 = ""
     split_list = {
         'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
...
@@ -74,10 +74,10 @@ def download_url(url, root, filename=None, md5=None):
     makedir_exist_ok(root)

-    # downloads file
+    # check if file is already present locally
     if check_integrity(fpath, md5):
         print('Using downloaded and verified file: ' + fpath)
-    else:
+    else:  # download the file
         try:
             print('Downloading ' + url + ' to ' + fpath)
             urllib.request.urlretrieve(
@@ -95,6 +95,9 @@ def download_url(url, root, filename=None, md5=None):
                 )
             else:
                 raise e
+        # check integrity of downloaded file
+        if not check_integrity(fpath, md5):
+            raise RuntimeError("File not found or corrupted.")


 def list_dir(root, prefix=False):
...
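To illustrate the updated helper in isolation (assuming a torchvision version that carries this change), the call below reuses the MNIST training-images URL and hash from the resources list above. If the file is already present and verified it is not re-downloaded; if the downloaded file fails the hash check, the new RuntimeError is raised.

from torchvision.datasets.utils import download_url

download_url(
    'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
    root='./data/raw',
    filename='train-images-idx3-ubyte.gz',
    md5='f68b3c2dcbeaaa9fbdd348bbdeb94873',
)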