Unverified Commit 70745705 authored by Philip Meier, committed by GitHub

call dataset wrapper with idx and sample (#7235)

parent b030e936
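
This changes the calling convention of the dataset-specific wrapper callables from `wrapper(sample)` to `wrapper(idx, sample)`, so a wrapper can consult per-index state on the underlying dataset. The immediate motivation is `CocoDetection`: a sample without annotations has an empty target and thus no `image_id` to recover from the annotations, whereas `dataset.ids[idx]` always has it. A minimal sketch of the new protocol (names illustrative, not the verbatim module code):

```python
# Illustrative sketch of the new wrapper protocol, not the verbatim
# torchvision code. A factory takes the dataset and returns a callable
# that now receives the sample index alongside the raw sample.
def example_wrapper_factory(dataset):
    def wrapper(idx, sample):
        image, target = sample
        # `idx` allows per-index lookups on the dataset itself,
        # e.g. dataset.ids[idx] for COCO-style datasets.
        return image, target

    return wrapper
```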
@@ -74,7 +74,7 @@ class VisionDatasetDatapointWrapper(Dataset):
         # of this class
         sample = self._dataset[idx]
 
-        sample = self._wrapper(sample)
+        sample = self._wrapper(idx, sample)
 
         # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
         # or joint (`transforms`), we can access the full functionality through `transforms`
@@ -125,7 +125,10 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
 def classification_wrapper_factory(dataset):
-    return identity
+    def wrapper(idx, sample):
+        return sample
+
+    return wrapper
 
 
 for dataset_cls in [
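
Because the callables are now binary, the former one-argument `identity` shortcut no longer fits the signature; the factory instead returns a closure that simply ignores `idx`. An equivalent, purely hypothetical adapter (not part of torchvision) would look like:

```python
# Hypothetical adapter, not in torchvision: lift a sample-only wrapper
# to the new (idx, sample) convention by discarding the index.
def ignore_idx(sample_only_wrapper):
    def wrapper(idx, sample):
        return sample_only_wrapper(sample)

    return wrapper
```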
@@ -143,7 +146,7 @@ for dataset_cls in [
 def segmentation_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, mask = sample
         return image, pil_image_to_mask(mask)
@@ -163,7 +166,7 @@ def video_classification_wrapper_factory(dataset):
             f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
         )
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         video, audio, label = sample
 
         video = datapoints.Video(video)
@@ -201,14 +204,17 @@ def coco_dectection_wrapper_factory(dataset):
         )
         return torch.from_numpy(mask.decode(segmentation))
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
+        image_id = dataset.ids[idx]
+
         image, target = sample
 
+        if not target:
+            return image, dict(image_id=image_id)
+
         batched_target = list_of_dicts_to_dict_of_lists(target)
-        image_ids = batched_target.pop("image_id")
-        image_id = batched_target["image_id"] = image_ids.pop()
-        assert all(other_image_id == image_id for other_image_id in image_ids)
+        batched_target["image_id"] = image_id
 
         spatial_size = tuple(F.get_spatial_size(image))
         batched_target["boxes"] = datapoints.BoundingBox(
@@ -259,7 +265,7 @@ VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC
 @WRAPPER_FACTORIES.register(datasets.VOCDetection)
 def voc_detection_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
@@ -294,7 +300,7 @@ def celeba_wrapper_factory(dataset):
     if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
         raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         target = wrap_target_by_type(
@@ -318,7 +324,7 @@ KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES)))
 @WRAPPER_FACTORIES.register(datasets.Kitti)
 def kitti_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
@@ -336,7 +342,7 @@ def kitti_wrapper_factory(dataset):
 @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
 def oxford_iiit_pet_wrapper_factor(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
@@ -371,7 +377,7 @@ def cityscapes_wrapper_factory(dataset):
             labels.append(label)
         return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         target = wrap_target_by_type(
@@ -390,7 +396,7 @@ def cityscapes_wrapper_factory(dataset):
 @WRAPPER_FACTORIES.register(datasets.WIDERFace)
 def widerface_wrapper(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
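For context, an end-to-end usage sketch; it assumes the v2 entry point is `torchvision.datapoints.wrap_dataset_for_transforms_v2` (the module layout at the time of this commit, to the best of my knowledge) and uses placeholder paths:

```python
# Usage sketch; assumes torchvision.datapoints.wrap_dataset_for_transforms_v2
# as the public entry point and placeholder dataset paths.
from torchvision import datasets
from torchvision.datapoints import wrap_dataset_for_transforms_v2

dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
dataset = wrap_dataset_for_transforms_v2(dataset)

# With this change, even samples without annotations carry their image_id.
image, target = dataset[0]
print(target["image_id"])
```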