Commit 1fb0ccf7 authored by Vishwak Srinivasan's avatar Vishwak Srinivasan Committed by Soumith Chintala
Browse files

Add progress bar based downloading to MNIST (#535)

parent 0bbb1aa3
......@@ -7,6 +7,7 @@ import errno
import numpy as np
import torch
import codecs
from .utils import download_url
class MNIST(data.Dataset):
......@@ -120,12 +121,10 @@ class MNIST(data.Dataset):
raise
for url in self.urls:
print('Downloading ' + url)
data = urllib.request.urlopen(url)
filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, self.raw_folder, filename)
with open(file_path, 'wb') as f:
f.write(data.read())
download_url(url, root=os.path.join(self.root, self.raw_folder),
filename=filename, md5=None)
with open(file_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(file_path) as zip_f:
out_f.write(zip_f.read())
......@@ -247,13 +246,10 @@ class EMNIST(MNIST):
else:
raise
print('Downloading ' + self.url)
data = urllib.request.urlopen(self.url)
filename = self.url.rpartition('/')[2]
raw_folder = os.path.join(self.root, self.raw_folder)
file_path = os.path.join(raw_folder, filename)
with open(file_path, 'wb') as f:
f.write(data.read())
download_url(self.url, root=file_path, filename=filename, md5=None)
print('Extracting zip archive')
with zipfile.ZipFile(file_path) as zip_f:
......
......@@ -15,7 +15,9 @@ def gen_bar_updater(pbar):
return bar_update
def check_integrity(fpath, md5):
def check_integrity(fpath, md5=None):
if md5 is None:
return True
if not os.path.isfile(fpath):
return False
md5o = hashlib.md5()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment