Unverified Commit 0aed8329 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add tests for transform presets, and various fixes (#7223)

parent c73411a4
import itertools
import re
from collections import defaultdict
import numpy as np
...@@ -1988,3 +1989,154 @@ class TestUniformTemporalSubsample: ...@@ -1988,3 +1989,154 @@ class TestUniformTemporalSubsample:
assert type(output) is type(inpt)
assert output.shape[-4] == num_samples
assert output.dtype == inpt.dtype
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
    """Smoke-test a classification-style augmentation pipeline.

    Exercises every combination of image container (PIL / plain tensor /
    datapoints.Image), label container (tensor / int), sample container
    (dict / tuple) and to-tensor conversion, and checks that the output
    preserves the sample structure, produces a 224x224 image, and leaves
    the label untouched.
    """
    img = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
    if image_type is PIL.Image:
        img = to_pil_image(img[0])
    elif image_type is torch.Tensor:
        img = img.as_subclass(torch.Tensor)
        assert is_simple_tensor(img)

    label = torch.tensor([1]) if label_type is torch.Tensor else 1

    if dataset_return_type is dict:
        sample = {"image": img, "label": label}
    else:
        sample = img, label

    pipeline = transforms.Compose(
        [
            transforms.RandomResizedCrop((224, 224)),
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandAugment(),
            transforms.TrivialAugmentWide(),
            transforms.AugMix(),
            transforms.AutoAugment(),
            to_tensor(),
            # TODO: ConvertImageDtype is a pass-through on PIL images, is that
            # intended? This results in a failure if we convert to tensor after
            # it, because the image would still be uint8 which make Normalize
            # fail.
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
            transforms.RandomErasing(p=1),
        ]
    )

    out = pipeline(sample)

    # The container type (dict vs tuple) must survive the pipeline.
    assert type(out) == type(sample)

    if dataset_return_type is dict:
        assert out.keys() == sample.keys()
        out_image, out_label = out.values()
    else:
        out_image, out_label = out

    assert out_image.shape[-2:] == (224, 224)
    assert out_label == label
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, list))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
    """Smoke-test the detection reference-training augmentation presets.

    Builds a synthetic detection sample (image, labels, boxes, masks) and
    runs it through one of the named preset pipelines, then checks the
    output container types and that the number of boxes/masks/labels is
    preserved.
    """
    if data_augmentation == "hflip":
        t = [
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "lsj":
        t = [
            transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
            # Note: replaced FixedSizeCrop with RandomCrop, because we're
            # leaving FixedSizeCrop in prototype for now, and it expects Label
            # classes which we won't release yet.
            # transforms.FixedSizeCrop(
            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
            # ),
            transforms.RandomCrop((1024, 1024), pad_if_needed=True),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "multiscale":
        t = [
            transforms.RandomShortestSize(
                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
            ),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "ssd":
        t = [
            transforms.RandomPhotometricDistort(p=1),
            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
            # TODO: put back IoUCrop once we remove its hard requirement for Labels
            # transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "ssdlite":
        t = [
            # TODO: put back IoUCrop once we remove its hard requirement for Labels
            # transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    t = transforms.Compose(t)

    num_boxes = 5
    H = W = 250

    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
    if image_type is PIL.Image:
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_simple_tensor(image)

    label = torch.randint(0, 10, size=(num_boxes,))
    if label_type is list:
        label = label.tolist()

    # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
    # Generate valid XYXY boxes: random top-left corner, then offset the
    # bottom-right by the top-left so x2 >= x1 and y2 >= y1, clamped to the image.
    boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
    boxes[:, 2:] += boxes[:, :2]
    boxes = boxes.clamp(min=0, max=min(H, W))
    boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W))

    masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))

    sample = {
        "image": image,
        "label": label,
        "boxes": boxes,
        "masks": masks,
    }

    out = t(sample)

    # ToTensor converts a simple tensor / PIL image into a plain tensor, while
    # ToImageTensor (and a datapoints.Image input) yields a datapoints.Image.
    if to_tensor is transforms.ToTensor and image_type is not datapoints.Image:
        assert is_simple_tensor(out["image"])
    else:
        assert isinstance(out["image"], datapoints.Image)
    assert isinstance(out["label"], type(sample["label"]))

    # Normalize label to a tensor so the shape comparison below works for both types.
    out["label"] = torch.tensor(out["label"])
    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
...@@ -37,10 +37,11 @@ class _AutoAugmentBase(Transform): ...@@ -37,10 +37,11 @@ class _AutoAugmentBase(Transform):
unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask), unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]: ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
needs_transform_list = self._needs_transform_list(flat_inputs)
image_or_videos = [] image_or_videos = []
for idx, inpt in enumerate(flat_inputs): for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
if check_type( if needs_transform and check_type(
inpt, inpt,
( (
datapoints.Image, datapoints.Image,
......
...@@ -169,7 +169,8 @@ class RandomPhotometricDistort(Transform): ...@@ -169,7 +169,8 @@ class RandomPhotometricDistort(Transform):
if isinstance(orig_inpt, PIL.Image.Image): if isinstance(orig_inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt) inpt = F.pil_to_tensor(inpt)
output = inpt[..., permutation, :, :] # TODO: Find a better fix than as_subclass???
output = inpt[..., permutation, :, :].as_subclass(type(inpt))
if isinstance(orig_inpt, PIL.Image.Image): if isinstance(orig_inpt, PIL.Image.Image):
output = F.to_image_pil(output) output = F.to_image_pil(output)
......
...@@ -36,8 +36,19 @@ class Transform(nn.Module): ...@@ -36,8 +36,19 @@ class Transform(nn.Module):
self._check_inputs(flat_inputs) self._check_inputs(flat_inputs)
params = self._get_params(flat_inputs) needs_transform_list = self._needs_transform_list(flat_inputs)
params = self._get_params(
[inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
)
flat_outputs = [
self._transform(inpt, params) if needs_transform else inpt
for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
]
return tree_unflatten(flat_outputs, spec)
def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
# Below is a heuristic on how to deal with simple tensor inputs: # Below is a heuristic on how to deal with simple tensor inputs:
# 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample. # (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
...@@ -53,7 +64,8 @@ class Transform(nn.Module): ...@@ -53,7 +64,8 @@ class Transform(nn.Module):
# The heuristic should work well for most people in practice. The only case where it doesn't is if someone # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
# tries to transform multiple simple tensors at the same time, expecting them all to be treated as images. # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
# However, this case wasn't supported by transforms v1 either, so there is no BC concern. # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
flat_outputs = []
needs_transform_list = []
transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image) transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
for inpt in flat_inputs: for inpt in flat_inputs:
needs_transform = True needs_transform = True
...@@ -65,10 +77,8 @@ class Transform(nn.Module): ...@@ -65,10 +77,8 @@ class Transform(nn.Module):
transform_simple_tensor = False transform_simple_tensor = False
else: else:
needs_transform = False needs_transform = False
needs_transform_list.append(needs_transform)
flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt) return needs_transform_list
return tree_unflatten(flat_outputs, spec)
def extra_repr(self) -> str: def extra_repr(self) -> str:
extra = [] extra = []
...@@ -159,10 +169,14 @@ class _RandomApplyTransform(Transform): ...@@ -159,10 +169,14 @@ class _RandomApplyTransform(Transform):
if torch.rand(1) >= self.p: if torch.rand(1) >= self.p:
return inputs return inputs
params = self._get_params(flat_inputs) needs_transform_list = self._needs_transform_list(flat_inputs)
params = self._get_params(
[inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
)
flat_outputs = [ flat_outputs = [
self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs self._transform(inpt, params) if needs_transform else inpt
for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
] ]
return tree_unflatten(flat_outputs, spec) return tree_unflatten(flat_outputs, spec)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment