Commit dac9efae authored by Sanyam Kapoor's avatar Sanyam Kapoor Committed by Francisco Massa
Browse files

Omniglot Dataset (#323)

* Add basic Omniglot dataset loader

* Remove unused import

* Add Omniglot random pair to sample pair of characters

* Precompute random set of pairs, deterministic after object instantiation

* Export OmniglotRandomPair via the datasets module interface

* Fix naming convention, use sum instead of reduce

* Fix downloading to not download everything, fix Python2 syntax

* Fix end line lint

* Add random_seed, syntax fixes

* Remove randomized pair, take up as a separate generic wrapper

* Fix master conflict
parent 70440492
...@@ -5,4 +5,4 @@ torchvision.egg-info/ ...@@ -5,4 +5,4 @@ torchvision.egg-info/
*/**/*.pyc */**/*.pyc
*/**/*~ */**/*~
*~ *~
docs/build docs/build
\ No newline at end of file
...@@ -8,9 +8,11 @@ from .svhn import SVHN ...@@ -8,9 +8,11 @@ from .svhn import SVHN
from .phototour import PhotoTour from .phototour import PhotoTour
from .fakedata import FakeData from .fakedata import FakeData
from .semeion import SEMEION from .semeion import SEMEION
from .omniglot import Omniglot
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'FakeData', 'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection', 'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION') 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot')
...@@ -2,7 +2,6 @@ from __future__ import print_function ...@@ -2,7 +2,6 @@ from __future__ import print_function
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
import errno
import numpy as np import numpy as np
import sys import sys
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
......
from __future__ import print_function
from PIL import Image
from os.path import join
import os
import torch.utils.data as data
from .utils import download_url, check_integrity, list_dir, list_files
class Omniglot(data.Dataset):
    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset zip files from the internet and
            puts it in root directory. If the zip files are already downloaded, they are not
            downloaded again.
    """
    # Sub-directory of ``root`` that holds the zip archives and extracted images.
    folder = 'omniglot-py'
    download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python'
    # Expected MD5 checksum of each archive, keyed by archive name without '.zip'.
    zips_md5 = {
        'images_background': '68d2efa1b9178cc56df9314c21c6e718',
        'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
    }

    def __init__(self, root, background=True,
                 transform=None, target_transform=None,
                 download=False):
        """Index every character image of the selected split.

        Raises:
            RuntimeError: If the split's zip archive is missing or fails
                its MD5 check (after the optional download).
        """
        self.root = join(os.path.expanduser(root), self.folder)
        self.background = background
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.target_folder = join(self.root, self._get_target_folder())

        # Directory listings come back in arbitrary, platform-dependent order.
        # Sort at every level so that class indices and sample order are
        # reproducible from one machine/run to the next.
        self._alphabets = sorted(list_dir(self.target_folder))

        # Each character is identified by its '<alphabet>/<character>' path
        # relative to the target folder; its position in this list is its class index.
        self._characters = [
            join(alphabet, character)
            for alphabet in self._alphabets
            for character in sorted(list_dir(join(self.target_folder, alphabet)))
        ]

        # One list per character holding (image file name, class index) pairs.
        self._character_images = [
            [(image, idx)
             for image in sorted(list_files(join(self.target_folder, character), '.png'))]
            for idx, character in enumerate(self._characters)
        ]

        # Flatten in linear time (``sum(lists, [])`` re-copies the accumulator
        # for every character and is quadratic).
        self._flat_character_images = [
            pair for images in self._character_images for pair in images
        ]

    def __len__(self):
        """Return the total number of images in the selected split."""
        return len(self._flat_character_images)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)
        # Images are converted to single-channel ('L') mode.
        image = Image.open(image_path, mode='r').convert('L')

        # Compare against None explicitly: a callable that happens to be
        # falsy (e.g. an empty container-like transform) must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self):
        """Return True if the selected split's zip archive exists in
        ``self.root`` and matches its expected MD5 checksum."""
        zip_filename = self._get_target_folder()
        if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]):
            return False
        return True

    def download(self):
        """Download the selected split's zip archive and extract it into ``self.root``.

        Does nothing (besides printing a message) when the archive is already
        present and passes the integrity check.
        """
        import zipfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        filename = self._get_target_folder()
        zip_filename = filename + '.zip'
        url = self.download_url_prefix + '/' + zip_filename
        download_url(url, self.root, zip_filename, self.zips_md5[filename])
        print('Extracting downloaded file: ' + join(self.root, zip_filename))
        with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file:
            zip_file.extractall(self.root)

    def _get_target_folder(self):
        """Return the archive/folder name of the selected split."""
        return 'images_background' if self.background else 'images_evaluation'
...@@ -45,3 +45,49 @@ def download_url(url, root, filename, md5): ...@@ -45,3 +45,49 @@ def download_url(url, root, filename, md5):
print('Failed download. Trying https -> http instead.' print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath) ' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(url, fpath) urllib.request.urlretrieve(url, fpath)
def list_dir(root, prefix=False):
    """List all directories at a given root

    The result is sorted by name so that callers get the same order on
    every platform (``os.listdir`` itself returns entries in arbitrary order).

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found

    Returns:
        list: Sorted names (or paths, when ``prefix`` is True) of the
        sub-directories found directly under ``root``.
    """
    root = os.path.expanduser(root)
    directories = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )

    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]

    return directories
def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root

    The result is sorted by name so that callers get the same order on
    every platform (``os.listdir`` itself returns entries in arbitrary order).

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found

    Returns:
        list: Sorted names (or paths, when ``prefix`` is True) of the regular
        files directly under ``root`` whose name ends with ``suffix``.
    """
    root = os.path.expanduser(root)
    files = sorted(
        name for name in os.listdir(root)
        if os.path.isfile(os.path.join(root, name)) and name.endswith(suffix)
    )

    if prefix is True:
        files = [os.path.join(root, f) for f in files]

    return files
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment