Commit 1fb0ccf7 authored by Vishwak Srinivasan's avatar Vishwak Srinivasan Committed by Soumith Chintala
Browse files

Add progress bar based downloading to MNIST (#535)

parent 0bbb1aa3
...@@ -7,6 +7,7 @@ import errno ...@@ -7,6 +7,7 @@ import errno
import numpy as np import numpy as np
import torch import torch
import codecs import codecs
from .utils import download_url
class MNIST(data.Dataset): class MNIST(data.Dataset):
...@@ -120,12 +121,10 @@ class MNIST(data.Dataset): ...@@ -120,12 +121,10 @@ class MNIST(data.Dataset):
raise raise
for url in self.urls: for url in self.urls:
print('Downloading ' + url)
data = urllib.request.urlopen(url)
filename = url.rpartition('/')[2] filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, self.raw_folder, filename) file_path = os.path.join(self.root, self.raw_folder, filename)
with open(file_path, 'wb') as f: download_url(url, root=os.path.join(self.root, self.raw_folder),
f.write(data.read()) filename=filename, md5=None)
with open(file_path.replace('.gz', ''), 'wb') as out_f, \ with open(file_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(file_path) as zip_f: gzip.GzipFile(file_path) as zip_f:
out_f.write(zip_f.read()) out_f.write(zip_f.read())
...@@ -247,13 +246,10 @@ class EMNIST(MNIST): ...@@ -247,13 +246,10 @@ class EMNIST(MNIST):
else: else:
raise raise
print('Downloading ' + self.url)
data = urllib.request.urlopen(self.url)
filename = self.url.rpartition('/')[2] filename = self.url.rpartition('/')[2]
raw_folder = os.path.join(self.root, self.raw_folder) raw_folder = os.path.join(self.root, self.raw_folder)
file_path = os.path.join(raw_folder, filename) file_path = os.path.join(raw_folder, filename)
with open(file_path, 'wb') as f: download_url(self.url, root=file_path, filename=filename, md5=None)
f.write(data.read())
print('Extracting zip archive') print('Extracting zip archive')
with zipfile.ZipFile(file_path) as zip_f: with zipfile.ZipFile(file_path) as zip_f:
......
...@@ -15,7 +15,9 @@ def gen_bar_updater(pbar): ...@@ -15,7 +15,9 @@ def gen_bar_updater(pbar):
return bar_update return bar_update
def check_integrity(fpath, md5): def check_integrity(fpath, md5=None):
if md5 is None:
return True
if not os.path.isfile(fpath): if not os.path.isfile(fpath):
return False return False
md5o = hashlib.md5() md5o = hashlib.md5()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment