"torchvision/vscode:/vscode.git/clone" did not exist on "8263c8a1dc0cd597a46a9c72e7da792e666e9890"
Unverified Commit 70745705 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

call dataset wrapper with idx and sample (#7235)

parent b030e936
......@@ -74,7 +74,7 @@ class VisionDatasetDatapointWrapper(Dataset):
# of this class
sample = self._dataset[idx]
sample = self._wrapper(sample)
sample = self._wrapper(idx, sample)
# Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
# or joint (`transforms`), we can access the full functionality through `transforms`
......@@ -125,7 +125,10 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
def classification_wrapper_factory(dataset):
return identity
def wrapper(idx, sample):
return sample
return wrapper
for dataset_cls in [
......@@ -143,7 +146,7 @@ for dataset_cls in [
def segmentation_wrapper_factory(dataset):
def wrapper(sample):
def wrapper(idx, sample):
image, mask = sample
return image, pil_image_to_mask(mask)
......@@ -163,7 +166,7 @@ def video_classification_wrapper_factory(dataset):
f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
)
def wrapper(sample):
def wrapper(idx, sample):
video, audio, label = sample
video = datapoints.Video(video)
......@@ -201,14 +204,17 @@ def coco_dectection_wrapper_factory(dataset):
)
return torch.from_numpy(mask.decode(segmentation))
def wrapper(sample):
def wrapper(idx, sample):
image_id = dataset.ids[idx]
image, target = sample
if not target:
return image, dict(image_id=image_id)
batched_target = list_of_dicts_to_dict_of_lists(target)
image_ids = batched_target.pop("image_id")
image_id = batched_target["image_id"] = image_ids.pop()
assert all(other_image_id == image_id for other_image_id in image_ids)
batched_target["image_id"] = image_id
spatial_size = tuple(F.get_spatial_size(image))
batched_target["boxes"] = datapoints.BoundingBox(
......@@ -259,7 +265,7 @@ VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC
@WRAPPER_FACTORIES.register(datasets.VOCDetection)
def voc_detection_wrapper_factory(dataset):
def wrapper(sample):
def wrapper(idx, sample):
image, target = sample
batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
......@@ -294,7 +300,7 @@ def celeba_wrapper_factory(dataset):
if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")
def wrapper(sample):
def wrapper(idx, sample):
image, target = sample
target = wrap_target_by_type(
......@@ -318,7 +324,7 @@ KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES)))
@WRAPPER_FACTORIES.register(datasets.Kitti)
def kitti_wrapper_factory(dataset):
def wrapper(sample):
def wrapper(idx, sample):
image, target = sample
if target is not None:
......@@ -336,7 +342,7 @@ def kitti_wrapper_factory(dataset):
@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
def oxford_iiit_pet_wrapper_factor(dataset):
def wrapper(sample):
def wrapper(idx, sample):
image, target = sample
if target is not None:
......@@ -371,7 +377,7 @@ def cityscapes_wrapper_factory(dataset):
labels.append(label)
return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))
def wrapper(sample):
def wrapper(idx, sample):
image, target = sample
target = wrap_target_by_type(
......@@ -390,7 +396,7 @@ def cityscapes_wrapper_factory(dataset):
@WRAPPER_FACTORIES.register(datasets.WIDERFace)
def widerface_wrapper(dataset):
def wrapper(sample):
def wrapper(idx, sample):
image, target = sample
if target is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment