"git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "5ea40abf613e47bb56a0c06f695644d55671f585"
Unverified Commit 0c36735d authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

[RFC] Add support for joint transformations in VisionDataset (#872)

* [RFC] Add support for joint transformations in VisionDataset

* Add joints transforms for VOC and SBD

Breaking change in SBD: the `xy_transform` argument has been renamed to `transforms`. I think this is fine, given that we have not yet released a version of torchvision that contains it.
parent b1fb79f9
......@@ -43,10 +43,8 @@ class CocoCaptions(VisionDataset):
"""
def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
    """Initialize the COCO captions dataset.

    Args:
        root (string): Root directory where the images are stored.
        annFile (string): Path to the COCO annotation json file.
        transform (callable, optional): Transform applied to the image only
            (kept for backwards compatibility).
        target_transform (callable, optional): Transform applied to the target only.
        transforms (callable, optional): Joint transform applied to
            ``(image, target)`` together. Mutually exclusive with
            ``transform``/``target_transform``; the base class raises
            ``ValueError`` if both are given.
    """
    # NOTE: the diff residue showed two __init__ signatures; this is the
    # post-change one — the base class now owns all transform attributes.
    super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
    # Imported lazily so importing this module does not require pycocotools.
    from pycocotools.coco import COCO
    self.coco = COCO(annFile)
    # Sorted image ids give a deterministic dataset ordering.
    self.ids = list(sorted(self.coco.imgs.keys()))
......@@ -68,11 +66,9 @@ class CocoCaptions(VisionDataset):
path = coco.loadImgs(img_id)[0]['file_name']
img = Image.open(os.path.join(self.root, path)).convert('RGB')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
......@@ -92,10 +88,8 @@ class CocoDetection(VisionDataset):
target and transforms it.
"""
def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
    """Initialize the COCO detection dataset.

    Args:
        root (string): Root directory where the images are stored.
        annFile (string): Path to the COCO annotation json file.
        transform (callable, optional): Transform applied to the image only
            (kept for backwards compatibility).
        target_transform (callable, optional): Transform applied to the target only.
        transforms (callable, optional): Joint transform applied to
            ``(image, target)`` together. Mutually exclusive with
            ``transform``/``target_transform``; the base class raises
            ``ValueError`` if both are given.
    """
    # NOTE: the diff residue showed two __init__ signatures; this is the
    # post-change one — the base class now owns all transform attributes.
    super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
    # Imported lazily so importing this module does not require pycocotools.
    from pycocotools.coco import COCO
    self.coco = COCO(annFile)
    # Sorted image ids give a deterministic dataset ordering.
    self.ids = list(sorted(self.coco.imgs.keys()))
......@@ -116,11 +110,8 @@ class CocoDetection(VisionDataset):
path = coco.loadImgs(img_id)[0]['file_name']
img = Image.open(os.path.join(self.root, path)).convert('RGB')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
......
......@@ -54,7 +54,7 @@ class SBDataset(VisionDataset):
image_set='train',
mode='boundaries',
download=False,
xy_transform=None, **kwargs):
transforms=None):
try:
from scipy.io import loadmat
......@@ -63,12 +63,11 @@ class SBDataset(VisionDataset):
raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
"pip install scipy")
super(SBDataset, self).__init__(root)
super(SBDataset, self).__init__(root, transforms)
if mode not in ("segmentation", "boundaries"):
raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")
self.xy_transform = xy_transform
self.image_set = image_set
self.mode = mode
self.num_classes = 20
......@@ -120,8 +119,8 @@ class SBDataset(VisionDataset):
img = Image.open(self.images[index]).convert('RGB')
target = self._get_target(self.masks[index])
if self.xy_transform is not None:
img, target = self.xy_transform(img, target)
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
......
......@@ -6,11 +6,25 @@ import torch.utils.data as data
class VisionDataset(data.Dataset):
_repr_indent = 4
def __init__(self, root, transforms=None, transform=None, target_transform=None):
    """Initialize the common vision-dataset state.

    Args:
        root: Root directory of the dataset; ``~`` is expanded when the
            value is a string.
        transforms (callable, optional): Joint transform taking
            ``(input, target)`` and returning both.
        transform (callable, optional): Input-only transform
            (backwards-compatible API).
        target_transform (callable, optional): Target-only transform
            (backwards-compatible API).

    Raises:
        ValueError: If ``transforms`` is combined with ``transform`` or
            ``target_transform`` — the two styles are mutually exclusive.
    """
    # NOTE(review): torch._six was removed in modern torch releases —
    # isinstance(root, str) would be the portable check; confirm target version.
    if isinstance(root, torch._six.string_classes):
        root = os.path.expanduser(root)
    self.root = root

    has_transforms = transforms is not None
    has_separate_transform = transform is not None or target_transform is not None
    if has_transforms and has_separate_transform:
        raise ValueError("Only transforms or transform/target_transform can "
                         "be passed as argument")

    # for backwards-compatibility: keep the separate transforms reachable
    # even though they are funneled through self.transforms below.
    self.transform = transform
    self.target_transform = target_transform

    # Wrap the legacy pair so __getitem__ implementations only ever need
    # to call self.transforms(img, target).
    if has_separate_transform:
        transforms = StandardTransform(transform, target_transform)
    self.transforms = transforms
def __getitem__(self, index):
    # Abstract hook: subclasses must return the sample (and target, where
    # applicable) at `index`, applying self.transforms themselves.
    raise NotImplementedError
......@@ -23,12 +37,8 @@ class VisionDataset(data.Dataset):
if self.root is not None:
body.append("Root location: {}".format(self.root))
body += self.extra_repr().splitlines()
if hasattr(self, 'transform') and self.transform is not None:
body += self._format_transform_repr(self.transform,
"Transforms: ")
if hasattr(self, 'target_transform') and self.target_transform is not None:
body += self._format_transform_repr(self.target_transform,
"Target transforms: ")
if self.transforms is not None:
body += [repr(self.transforms)]
lines = [head] + [" " * self._repr_indent + line for line in body]
return '\n'.join(lines)
......@@ -39,3 +49,32 @@ class VisionDataset(data.Dataset):
def extra_repr(self):
    """Return extra lines to append to ``__repr__``; subclasses may override."""
    return ""
class StandardTransform(object):
    """Adapt a separate (transform, target_transform) pair to the joint
    transforms interface: calling the object applies each piece to its
    respective argument and returns the pair.
    """

    def __init__(self, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input, target):
        """Apply the input/target transforms (if set) and return both."""
        transformed_input = input if self.transform is None else self.transform(input)
        transformed_target = target if self.target_transform is None else self.target_transform(target)
        return transformed_input, transformed_target

    def _format_transform_repr(self, transform, head):
        # First line gets the heading; continuation lines are padded so the
        # multi-line repr of the transform stays visually aligned.
        first, *rest = repr(transform).splitlines()
        padding = " " * len(head)
        return ["{}{}".format(head, first)] + [padding + line for line in rest]

    def __repr__(self):
        body = [self.__class__.__name__]
        for part, heading in ((self.transform, "Transform: "),
                              (self.target_transform, "Target transform: ")):
            if part is not None:
                body.extend(self._format_transform_repr(part, heading))
        return '\n'.join(body)
......@@ -74,10 +74,9 @@ class VOCSegmentation(VisionDataset):
image_set='train',
download=False,
transform=None,
target_transform=None):
super(VOCSegmentation, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
target_transform=None,
transforms=None):
super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
self.year = year
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
......@@ -122,11 +121,8 @@ class VOCSegmentation(VisionDataset):
img = Image.open(self.images[index]).convert('RGB')
target = Image.open(self.masks[index])
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
......@@ -157,10 +153,9 @@ class VOCDetection(VisionDataset):
image_set='train',
download=False,
transform=None,
target_transform=None):
super(VOCDetection, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
target_transform=None,
transforms=None):
super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
self.year = year
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
......@@ -208,11 +203,8 @@ class VOCDetection(VisionDataset):
target = self.parse_voc_xml(
ET.parse(self.annotations[index]).getroot())
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.