Unverified Commit d85aa6d3 authored by Muhammed Abdullah's avatar Muhammed Abdullah Committed by GitHub
Browse files

Added LFW Dataset (#4255)



* Added LFW Dataset

* Added dataset to list in __init__.py

* Updated lfw.py
* Created a common superclass for people and pairs type datatsets
* corrected the .download() method

* Added docstrings and updated datasets.rst

* Wrote tests for LFWPeople and LFWPairs

* Resolved mypy error: Need type annotation for "data"

* Updated inject_fake_data method for LFWPeople

* Updated tests for LFW

* Updated LFW tests and minor changes in lfw.py

* Updated LFW
* Added functionality for 10-fold validation view
* Optimized the code so to replace repeated lines by method in super
  class
* Updated LFWPeople to get classes from all lfw-names.txt rather than
  just the classes fron trainset

* Updated lfw.py and tests
* Updated inject_fake_data method to create 10fold fake data
* Minor changes in docstring and extra_repr

* resolved py lint errors

* Added checksums for annotation files

* Minor changes in test

* Updated docstrings, defaults and minor changes in test

* Removed 'os.path.exists' check
Co-authored-by: default avatarABD-01 <myac931@gmai.com>
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 4d928927
......@@ -147,6 +147,17 @@ KMNIST
.. autoclass:: KMNIST
LFW
~~~~~
.. autoclass:: LFWPeople
:members: __getitem__
:special-members:
.. autoclass:: LFWPairs
:members: __getitem__
:special-members:
LSUN
~~~~
......
......@@ -1801,5 +1801,87 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
assert item[6] == i // 3
class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
DATASET_CLASS = datasets.LFWPeople
FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=('10fold', 'train', 'test'),
image_set=('original', 'funneled', 'deepfunneled')
)
_IMAGES_DIR = {
"original": "lfw",
"funneled": "lfw_funneled",
"deepfunneled": "lfw-deepfunneled"
}
_file_id = {'10fold': '', 'train': 'DevTrain', 'test': 'DevTest'}
def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / "lfw-py"
os.makedirs(tmpdir, exist_ok=True)
return dict(
num_examples=self._create_images_dir(tmpdir, self._IMAGES_DIR[config["image_set"]], config["split"]),
split=config["split"]
)
def _create_images_dir(self, root, idir, split):
idir = os.path.join(root, idir)
os.makedirs(idir, exist_ok=True)
n, flines = (10, ["10\n"]) if split == "10fold" else (1, [])
num_examples = 0
names = []
for _ in range(n):
num_people = random.randint(2, 5)
flines.append(f"{num_people}\n")
for i in range(num_people):
name = self._create_random_id()
no = random.randint(1, 10)
flines.append(f"{name}\t{no}\n")
names.append(f"{name}\t{no}\n")
datasets_utils.create_image_folder(idir, name, lambda n: f"{name}_{n+1:04d}.jpg", no, 250)
num_examples += no
with open(pathlib.Path(root) / f"people{self._file_id[split]}.txt", "w") as f:
f.writelines(flines)
with open(pathlib.Path(root) / "lfw-names.txt", "w") as f:
f.writelines(sorted(names))
return num_examples
def _create_random_id(self):
part1 = datasets_utils.create_random_string(random.randint(5, 7))
part2 = datasets_utils.create_random_string(random.randint(4, 7))
return f"{part1}_{part2}"
class LFWPairsTestCase(LFWPeopleTestCase):
DATASET_CLASS = datasets.LFWPairs
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, int)
def _create_images_dir(self, root, idir, split):
idir = os.path.join(root, idir)
os.makedirs(idir, exist_ok=True)
num_pairs = 7 # effectively 7*2*n = 14*n
n, self.flines = (10, [f"10\t{num_pairs}"]) if split == "10fold" else (1, [str(num_pairs)])
for _ in range(n):
self._inject_pairs(idir, num_pairs, True)
self._inject_pairs(idir, num_pairs, False)
with open(pathlib.Path(root) / f"pairs{self._file_id[split]}.txt", "w") as f:
f.writelines(self.flines)
return num_pairs * 2 * n
def _inject_pairs(self, root, num_pairs, same):
for i in range(num_pairs):
name1 = self._create_random_id()
name2 = name1 if same else self._create_random_id()
no1, no2 = random.randint(1, 100), random.randint(1, 100)
if same:
self.flines.append(f"\n{name1}\t{no1}\t{no2}")
else:
self.flines.append(f"\n{name1}\t{no1}\t{name2}\t{no2}")
datasets_utils.create_image_folder(root, name1, lambda _: f"{name1}_{no1:04d}.jpg", 1, 250)
datasets_utils.create_image_folder(root, name2, lambda _: f"{name2}_{no2:04d}.jpg", 1, 250)
if __name__ == "__main__":
unittest.main()
......@@ -26,6 +26,7 @@ from .ucf101 import UCF101
from .places365 import Places365
from .kitti import Kitti
from .inaturalist import INaturalist
from .lfw import LFWPeople, LFWPairs
__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
......@@ -36,5 +37,5 @@ __all__ = ('LSUN', 'LSUNClass',
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101',
'Places365', 'Kitti', "INaturalist"
'Places365', 'Kitti', "INaturalist", "LFWPeople", "LFWPairs"
)
import os
from typing import Any, Callable, List, Optional, Tuple
from PIL import Image
from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
class _LFW(VisionDataset):
base_folder = 'lfw-py'
download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"
file_dict = {
'original': ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
'funneled': ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
'deepfunneled': ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201")
}
checksums = {
'pairs.txt': '9f1ba174e4e1c508ff7cdf10ac338a7d',
'pairsDevTest.txt': '5132f7440eb68cf58910c8a45a2ac10b',
'pairsDevTrain.txt': '4f27cbf15b2da4a85c1907eb4181ad21',
'people.txt': '450f0863dd89e85e73936a6d71a3474b',
'peopleDevTest.txt': 'e4bf5be0a43b5dcd9dc5ccfcb8fb19c5',
'peopleDevTrain.txt': '54eaac34beb6d042ed3a7d883e247a21',
'lfw-names.txt': 'a6d0a479bd074669f656265a6e693f6d'
}
annot_file = {'10fold': '', 'train': 'DevTrain', 'test': 'DevTest'}
names = "lfw-names.txt"
def __init__(
self,
root: str,
split: str,
image_set: str,
view: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
):
super(_LFW, self).__init__(os.path.join(root, self.base_folder),
transform=transform, target_transform=target_transform)
self.image_set = verify_str_arg(image_set.lower(), 'image_set', self.file_dict.keys())
images_dir, self.filename, self.md5 = self.file_dict[self.image_set]
self.view = verify_str_arg(view.lower(), 'view', ['people', 'pairs'])
self.split = verify_str_arg(split.lower(), 'split', ['10fold', 'train', 'test'])
self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
self.data: List[Any] = []
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
self.images_dir = os.path.join(self.root, images_dir)
def _loader(self, path: str) -> Image.Image:
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def _check_integrity(self):
st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
if not st1 or not st2:
return False
if self.view == "people":
return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
return True
def download(self):
if self._check_integrity():
print('Files already downloaded and verified')
return
url = f"{self.download_url_prefix}{self.filename}"
download_and_extract_archive(url, self.root, filename=self.filename, md5=self.md5)
download_url(f"{self.download_url_prefix}{self.labels_file}", self.root)
if self.view == "people":
download_url(f"{self.download_url_prefix}{self.names}", self.root)
def _get_path(self, identity, no):
return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")
def extra_repr(self) -> str:
return f"Alignment: {self.image_set}\nSplit: {self.split}"
def __len__(self):
return len(self.data)
class LFWPeople(_LFW):
"""`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``lfw-py`` exists or will be saved to if download is set to True.
split (string, optional): The image split to use. Can be one of ``train``, ``test``,
``10fold`` (default).
image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
``deepfunneled``. Defaults to ``funneled``.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomRotation``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
def __init__(
self,
root: str,
split: str = "10fold",
image_set: str = "funneled",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
):
super(LFWPeople, self).__init__(root, split, image_set, "people",
transform, target_transform, download)
self.class_to_idx = self._get_classes()
self.data, self.targets = self._get_people()
def _get_people(self):
data, targets = [], []
with open(os.path.join(self.root, self.labels_file), 'r') as f:
lines = f.readlines()
n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)
for fold in range(n_folds):
n_lines = int(lines[s])
people = [line.strip().split("\t") for line in lines[s + 1: s + n_lines + 1]]
s += n_lines + 1
for i, (identity, num_imgs) in enumerate(people):
for num in range(1, int(num_imgs) + 1):
img = self._get_path(identity, num)
data.append(img)
targets.append(self.class_to_idx[identity])
return data, targets
def _get_classes(self):
with open(os.path.join(self.root, self.names), 'r') as f:
lines = f.readlines()
names = [line.strip().split()[0] for line in lines]
class_to_idx = {name: i for i, name in enumerate(names)}
return class_to_idx
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: Tuple (image, target) where target is the identity of the person.
"""
img = self._loader(self.data[index])
target = self.targets[index]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def extra_repr(self) -> str:
return super().extra_repr() + "\nClasses (identities): {}".format(len(self.class_to_idx))
class LFWPairs(_LFW):
"""`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``lfw-py`` exists or will be saved to if download is set to True.
split (string, optional): The image split to use. Can be one of ``train``, ``test``,
``10fold``. Defaults to ``10fold``.
image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
``deepfunneled``. Defaults to ``funneled``.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomRotation``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
def __init__(
self,
root: str,
split: str = "10fold",
image_set: str = "funneled",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
):
super(LFWPairs, self).__init__(root, split, image_set, "pairs",
transform, target_transform, download)
self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
def _get_pairs(self, images_dir):
pair_names, data, targets = [], [], []
with open(os.path.join(self.root, self.labels_file), 'r') as f:
lines = f.readlines()
if self.split == "10fold":
n_folds, n_pairs = lines[0].split("\t")
n_folds, n_pairs = int(n_folds), int(n_pairs)
else:
n_folds, n_pairs = 1, int(lines[0])
s = 1
for fold in range(n_folds):
matched_pairs = [line.strip().split("\t") for line in lines[s: s + n_pairs]]
unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs: s + (2 * n_pairs)]]
s += (2 * n_pairs)
for pair in matched_pairs:
img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1
pair_names.append((pair[0], pair[0]))
data.append((img1, img2))
targets.append(same)
for pair in unmatched_pairs:
img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[2], pair[3]), 0
pair_names.append((pair[0], pair[2]))
data.append((img1, img2))
targets.append(same)
return pair_names, data, targets
def __getitem__(self, index: int) -> Tuple[Any, Any, int]:
"""
Args:
index (int): Index
Returns:
tuple: (image1, image2, target) where target is `0` for different indentities and `1` for same identities.
"""
img1, img2 = self.data[index]
img1, img2 = self._loader(img1), self._loader(img2)
target = self.targets[index]
if self.transform is not None:
img1, img2 = self.transform(img1), self.transform(img2)
if self.target_transform is not None:
target = self.target_transform(target)
return img1, img2, target
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment