Commit 76c04d68 authored by Dmitry Nikulin's avatar Dmitry Nikulin Committed by Francisco Massa
Browse files

Support empty target_type for CelebA dataset (#1351)

* Support empty target_type for CelebA dataset

* Return (X, None) for interface consistency

* Document behavior for target_type=[]

* Simplify diff

* Raise exception on meaningless parameters
parent ef67fd92
......@@ -21,7 +21,7 @@ class CelebA(VisionDataset):
``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
Defaults to ``attr``.
Defaults to ``attr``. If empty, ``None`` will be returned as target.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
......@@ -59,6 +59,9 @@ class CelebA(VisionDataset):
else:
self.target_type = [target_type]
if not self.target_type and self.target_transform is not None:
raise RuntimeError('target_transform is specified but target_type is empty')
if download:
self.download()
......@@ -133,13 +136,17 @@ class CelebA(VisionDataset):
else:
# TODO: refactor with utils.verify_str_arg
raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0]
if self.transform is not None:
X = self.transform(X)
if self.target_transform is not None:
target = self.target_transform(target)
if target:
target = tuple(target) if len(target) > 1 else target[0]
if self.target_transform is not None:
target = self.target_transform(target)
else:
target = None
return X, target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment