"git@developer.sourcefind.cn:norm/vllm.git" did not exist on "6aef2278f457e564c17a84465b8b4e74986ccd3c"
Unverified Commit c59f0474 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

[WIP] Add tests for datasets (#966)

* WIP

* WIP: minor improvements

* Add tests

* Fix typo

* Use download_and_extract on caltech, cifar and omniglot

* Add a print message during extraction

* Remove EMNIST from test
parent 2b3a1b6d
import PIL
import shutil
import tempfile
import unittest
import torchvision
class Tester(unittest.TestCase):
    """Download smoke tests for the MNIST-family datasets.

    Each test downloads a dataset into a fresh temporary directory, checks
    that the first sample has the expected types, and removes the directory
    afterwards — even when an assertion fails.
    """

    def _check_mnist_like(self, dataset_cls, expected_len=None):
        # Shared body for the three tests below: download, inspect the first
        # sample, and always clean up.  The original code called
        # shutil.rmtree as the last statement, which leaked the temporary
        # directory whenever an earlier assertion failed; try/finally fixes
        # that.
        tmp_dir = tempfile.mkdtemp()
        try:
            dataset = dataset_cls(tmp_dir, download=True)
            if expected_len is not None:
                self.assertEqual(len(dataset), expected_len)
            img, target = dataset[0]
            self.assertTrue(isinstance(img, PIL.Image.Image))
            self.assertTrue(isinstance(target, int))
        finally:
            shutil.rmtree(tmp_dir)

    def test_mnist(self):
        # MNIST's training split has a documented size of 60000 images.
        self._check_mnist_like(torchvision.datasets.MNIST, expected_len=60000)

    def test_kmnist(self):
        self._check_mnist_like(torchvision.datasets.KMNIST)

    def test_fashionmnist(self):
        self._check_mnist_like(torchvision.datasets.FashionMNIST)
# Allow running this test module directly (e.g. `python test_datasets.py`).
if __name__ == '__main__':
    unittest.main()
...@@ -3,6 +3,9 @@ import shutil ...@@ -3,6 +3,9 @@ import shutil
import tempfile import tempfile
import torchvision.datasets.utils as utils import torchvision.datasets.utils as utils
import unittest import unittest
import zipfile
import tarfile
import gzip
TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'grace_hopper_517x606.jpg') 'assets', 'grace_hopper_517x606.jpg')
...@@ -41,6 +44,47 @@ class Tester(unittest.TestCase): ...@@ -41,6 +44,47 @@ class Tester(unittest.TestCase):
assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.' assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.'
shutil.rmtree(temp_dir) shutil.rmtree(temp_dir)
def test_extract_zip(self):
    """utils.extract_file unpacks a .zip archive into the target directory."""
    payload = 'this is the content'
    out_dir = tempfile.mkdtemp()
    with tempfile.NamedTemporaryFile(suffix='.zip') as src:
        # Build a one-entry zip archive in the temporary file.
        with zipfile.ZipFile(src, 'w') as archive:
            archive.writestr('file.tst', payload)
        utils.extract_file(src.name, out_dir)
        extracted = os.path.join(out_dir, 'file.tst')
        assert os.path.exists(extracted)
        with open(extracted, 'r') as fh:
            content = fh.read()
        assert content == payload
    shutil.rmtree(out_dir)
def test_extract_tar(self):
    """utils.extract_file unpacks both .tar and .tar.gz archives."""
    payload = 'this is the content'
    for suffix, write_mode in (('.tar', 'w'), ('.tar.gz', 'w:gz')):
        out_dir = tempfile.mkdtemp()
        with tempfile.NamedTemporaryFile() as member:
            # The file that will be stored inside the archive.
            member.write(payload.encode())
            member.seek(0)
            with tempfile.NamedTemporaryFile(suffix=suffix) as src:
                with tarfile.open(src.name, mode=write_mode) as archive:
                    archive.add(member.name, arcname='file.tst')
                utils.extract_file(src.name, out_dir)
                extracted = os.path.join(out_dir, 'file.tst')
                assert os.path.exists(extracted)
                with open(extracted, 'r') as fh:
                    content = fh.read()
                assert content == payload, content
        shutil.rmtree(out_dir)
def test_extract_gzip(self):
    """utils.extract_file on a single .gz file writes the decompressed payload."""
    payload = 'this is the content'
    out_dir = tempfile.mkdtemp()
    with tempfile.NamedTemporaryFile(suffix='.gz') as src:
        with gzip.GzipFile(src.name, 'wb') as gz:
            gz.write(payload.encode())
        utils.extract_file(src.name, out_dir)
        # For a bare .gz the output name is the archive name minus ".gz".
        stem = os.path.splitext(os.path.basename(src.name))[0]
        extracted = os.path.join(out_dir, stem)
        assert os.path.exists(extracted)
        with open(extracted, 'r') as fh:
            content = fh.read()
        assert content == payload, content
    shutil.rmtree(out_dir)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
import os.path import os.path
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_url, makedir_exist_ok from .utils import download_and_extract, makedir_exist_ok
class Caltech101(VisionDataset): class Caltech101(VisionDataset):
...@@ -109,28 +109,21 @@ class Caltech101(VisionDataset): ...@@ -109,28 +109,21 @@ class Caltech101(VisionDataset):
return len(self.index) return len(self.index)
def download(self): def download(self):
import tarfile
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", download_and_extract(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
self.root, self.root,
"101_ObjectCategories.tar.gz", "101_ObjectCategories.tar.gz",
"b224c7392d521a49829488ab0f1120d9") "b224c7392d521a49829488ab0f1120d9")
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", download_and_extract(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
self.root, self.root,
"101_Annotations.tar", "101_Annotations.tar",
"6f83eeb1f24d99cab4eb377263132c91") "6f83eeb1f24d99cab4eb377263132c91")
# extract file
with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar:
tar.extractall(path=self.root)
with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
tar.extractall(path=self.root)
def extra_repr(self): def extra_repr(self):
return "Target type: {target_type}".format(**self.__dict__) return "Target type: {target_type}".format(**self.__dict__)
...@@ -204,17 +197,12 @@ class Caltech256(VisionDataset): ...@@ -204,17 +197,12 @@ class Caltech256(VisionDataset):
return len(self.index) return len(self.index)
def download(self): def download(self):
import tarfile
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", download_and_extract(
"http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
self.root, self.root,
"256_ObjectCategories.tar", "256_ObjectCategories.tar",
"67b4f42ca05d46448c6bb8ecd2220f6d") "67b4f42ca05d46448c6bb8ecd2220f6d")
# extract file
with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
tar.extractall(path=self.root)
...@@ -11,7 +11,7 @@ else: ...@@ -11,7 +11,7 @@ else:
import pickle import pickle
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_url, check_integrity from .utils import check_integrity, download_and_extract
class CIFAR10(VisionDataset): class CIFAR10(VisionDataset):
...@@ -144,17 +144,10 @@ class CIFAR10(VisionDataset): ...@@ -144,17 +144,10 @@ class CIFAR10(VisionDataset):
return True return True
def download(self): def download(self):
import tarfile
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
download_and_extract(self.url, self.root, self.filename, self.tgz_md5)
download_url(self.url, self.root, self.filename, self.tgz_md5)
# extract file
with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
tar.extractall(path=self.root)
def extra_repr(self): def extra_repr(self):
return "Split: {}".format("Train" if self.train is True else "Test") return "Split: {}".format("Train" if self.train is True else "Test")
......
...@@ -4,11 +4,10 @@ import warnings ...@@ -4,11 +4,10 @@ import warnings
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
import gzip
import numpy as np import numpy as np
import torch import torch
import codecs import codecs
from .utils import download_url, makedir_exist_ok from .utils import download_and_extract, extract_file, makedir_exist_ok
class MNIST(VisionDataset): class MNIST(VisionDataset):
...@@ -120,15 +119,6 @@ class MNIST(VisionDataset): ...@@ -120,15 +119,6 @@ class MNIST(VisionDataset):
os.path.exists(os.path.join(self.processed_folder, os.path.exists(os.path.join(self.processed_folder,
self.test_file))) self.test_file)))
@staticmethod
def extract_gzip(gzip_path, remove_finished=False):
print('Extracting {}'.format(gzip_path))
with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(gzip_path) as zip_f:
out_f.write(zip_f.read())
if remove_finished:
os.unlink(gzip_path)
def download(self): def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already.""" """Download the MNIST data if it doesn't exist in processed_folder already."""
...@@ -141,9 +131,7 @@ class MNIST(VisionDataset): ...@@ -141,9 +131,7 @@ class MNIST(VisionDataset):
# download files # download files
for url in self.urls: for url in self.urls:
filename = url.rpartition('/')[2] filename = url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename) download_and_extract(url, root=self.raw_folder, filename=filename)
download_url(url, root=self.raw_folder, filename=filename, md5=None)
self.extract_gzip(gzip_path=file_path, remove_finished=True)
# process and save as torch files # process and save as torch files
print('Processing...') print('Processing...')
...@@ -262,7 +250,6 @@ class EMNIST(MNIST): ...@@ -262,7 +250,6 @@ class EMNIST(MNIST):
def download(self): def download(self):
"""Download the EMNIST data if it doesn't exist in processed_folder already.""" """Download the EMNIST data if it doesn't exist in processed_folder already."""
import shutil import shutil
import zipfile
if self._check_exists(): if self._check_exists():
return return
...@@ -271,18 +258,12 @@ class EMNIST(MNIST): ...@@ -271,18 +258,12 @@ class EMNIST(MNIST):
makedir_exist_ok(self.processed_folder) makedir_exist_ok(self.processed_folder)
# download files # download files
filename = self.url.rpartition('/')[2] print('Downloading and extracting zip archive')
file_path = os.path.join(self.raw_folder, filename) download_and_extract(self.url, root=self.raw_folder, filename="emnist.zip", remove_finished=True)
download_url(self.url, root=self.raw_folder, filename=filename, md5=None)
print('Extracting zip archive')
with zipfile.ZipFile(file_path) as zip_f:
zip_f.extractall(self.raw_folder)
os.unlink(file_path)
gzip_folder = os.path.join(self.raw_folder, 'gzip') gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder): for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'): if gzip_file.endswith('.gz'):
self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file)) extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder)
# process and save as torch files # process and save as torch files
for split in self.splits: for split in self.splits:
......
...@@ -3,7 +3,7 @@ from PIL import Image ...@@ -3,7 +3,7 @@ from PIL import Image
from os.path import join from os.path import join
import os import os
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_url, check_integrity, list_dir, list_files from .utils import download_and_extract, check_integrity, list_dir, list_files
class Omniglot(VisionDataset): class Omniglot(VisionDataset):
...@@ -81,8 +81,6 @@ class Omniglot(VisionDataset): ...@@ -81,8 +81,6 @@ class Omniglot(VisionDataset):
return True return True
def download(self): def download(self):
import zipfile
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
...@@ -90,10 +88,7 @@ class Omniglot(VisionDataset): ...@@ -90,10 +88,7 @@ class Omniglot(VisionDataset):
filename = self._get_target_folder() filename = self._get_target_folder()
zip_filename = filename + '.zip' zip_filename = filename + '.zip'
url = self.download_url_prefix + '/' + zip_filename url = self.download_url_prefix + '/' + zip_filename
download_url(url, self.root, zip_filename, self.zips_md5[filename]) download_and_extract(url, self.root, zip_filename, self.zips_md5[filename])
print('Extracting downloaded file: ' + join(self.root, zip_filename))
with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file:
zip_file.extractall(self.root)
def _get_target_folder(self): def _get_target_folder(self):
return 'images_background' if self.background else 'images_evaluation' return 'images_background' if self.background else 'images_evaluation'
import os import os
import os.path import os.path
import hashlib import hashlib
import gzip
import errno import errno
import tarfile
import zipfile
from torch.utils.model_zoo import tqdm from torch.utils.model_zoo import tqdm
...@@ -189,3 +193,46 @@ def _save_response_content(response, destination, chunk_size=32768): ...@@ -189,3 +193,46 @@ def _save_response_content(response, destination, chunk_size=32768):
progress += len(chunk) progress += len(chunk)
pbar.update(progress - pbar.n) pbar.update(progress - pbar.n)
pbar.close() pbar.close()
def _is_tar(filename):
return filename.endswith(".tar")
def _is_targz(filename):
return filename.endswith(".tar.gz")
def _is_gzip(filename):
return filename.endswith(".gz") and not filename.endswith(".tar.gz")
def _is_zip(filename):
return filename.endswith(".zip")
def extract_file(from_path, to_path, remove_finished=False):
if _is_tar(from_path):
with tarfile.open(from_path, 'r:') as tar:
tar.extractall(path=to_path)
elif _is_targz(from_path):
with tarfile.open(from_path, 'r:gz') as tar:
tar.extractall(path=to_path)
elif _is_gzip(from_path):
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
out_f.write(zip_f.read())
elif _is_zip(from_path):
with zipfile.ZipFile(from_path, 'r') as z:
z.extractall(to_path)
else:
raise ValueError("Extraction of {} not supported".format(from_path))
if remove_finished:
os.unlink(from_path)
def download_and_extract(url, root, filename, md5=None, remove_finished=False):
    """Download ``url`` into ``root/filename`` and extract it into ``root``.

    Args:
        url (str): URL to download.
        root (str): directory that receives the download and the
            extracted content.
        filename (str): name to save the downloaded archive under.
        md5 (str, optional): expected MD5 checksum, forwarded to
            ``download_url`` for verification.
        remove_finished (bool): if True, delete the archive after extraction.
    """
    archive = os.path.join(root, filename)
    download_url(url, root, filename, md5)
    print("Extracting {} to {}".format(archive, root))
    extract_file(archive, root, remove_finished)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment