Unverified Commit 4cb83c2f authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fixed RandAug and all AA consistency tests (#6519)

* [proto] Fixed RandAug implementation

* Fixed randomness in tests for trivial aug

* Fixed all AA tests
parent cc9ceb54
......@@ -1650,3 +1650,205 @@ class TestLabelToOneHot:
assert isinstance(ohe_labels, features.OneHotLabel)
assert ohe_labels.shape == (4, 3)
assert ohe_labels.categories == labels.categories == categories
class TestAPIConsistency:
@pytest.mark.parametrize("antialias", [True, False])
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
def test_random_resized_crop(self, antialias, inpt):
from torchvision.transforms import transforms as ref_transforms
size = 224
t_ref = ref_transforms.RandomResizedCrop(size, antialias=antialias)
t = transforms.RandomResizedCrop(size, antialias=antialias)
torch.manual_seed(12)
expected_output = t_ref(inpt)
torch.manual_seed(12)
output = t(inpt)
if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)
torch.testing.assert_close(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_randaug(self, inpt, interpolation, mocker):
from torchvision.transforms import autoaugment as ref_transforms
t_ref = ref_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = transforms.RandAugment(interpolation=interpolation, num_ops=1)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
# Stable API, if signed there is another random call
if t._AUGMENTATION_SPACE[keys[i]][1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
randint_values = iter(randint_values)
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)
for i in range(le):
expected_output = t_ref(inpt)
output = t(inpt)
if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)
torch.testing.assert_close(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_trivial_aug(self, inpt, interpolation, mocker):
from torchvision.transforms import autoaugment as ref_transforms
t_ref = ref_transforms.TrivialAugmentWide(interpolation=interpolation)
t = transforms.TrivialAugmentWide(interpolation=interpolation)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)
randint_values = iter(randint_values)
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)
for _ in range(le):
expected_output = t_ref(inpt)
output = t(inpt)
if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)
torch.testing.assert_close(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_augmix(self, inpt, interpolation, mocker):
from torchvision.transforms import autoaugment as ref_transforms
t_ref = ref_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
t = transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t._sample_dirichlet = lambda t: t.softmax(dim=-1)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)
randint_values = iter(randint_values)
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)
expected_output = t_ref(inpt)
output = t(inpt)
if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)
torch.testing.assert_close(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_aa(self, inpt, interpolation):
from torchvision.transforms import autoaugment as ref_transforms
aa_policy = ref_transforms.AutoAugmentPolicy("imagenet")
t_ref = ref_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = transforms.AutoAugment(aa_policy, interpolation=interpolation)
torch.manual_seed(12)
expected_output = t_ref(inpt)
torch.manual_seed(12)
output = t(inpt)
if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)
torch.testing.assert_close(expected_output, output)
......@@ -69,27 +69,46 @@ class _AutoAugmentBase(Transform):
interpolation: InterpolationMode,
fill: Union[int, float, Sequence[int], Sequence[float]],
) -> Any:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we have to put fill as None if fill == 0
fill_: Optional[Union[int, float, Sequence[int], Sequence[float]]]
if isinstance(fill, int) and fill == 0:
fill_ = None
else:
fill_ = fill
if transform_id == "Identity":
return image
elif transform_id == "ShearX":
# magnitude should be arctan(magnitude)
# official autoaug: (1, level, 0, 0, 1, 0)
# https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
# compared to
# torchvision: (1, tan(level), 0, 0, 1, 0)
# https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
return F.affine(
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(magnitude), 0.0],
shear=[math.degrees(math.atan(magnitude)), 0.0],
interpolation=interpolation,
fill=fill,
fill=fill_,
center=[0, 0],
)
elif transform_id == "ShearY":
# magnitude should be arctan(magnitude)
# See above
return F.affine(
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(magnitude)],
shear=[0.0, math.degrees(math.atan(magnitude))],
interpolation=interpolation,
fill=fill,
fill=fill_,
center=[0, 0],
)
elif transform_id == "TranslateX":
return F.affine(
......@@ -99,7 +118,7 @@ class _AutoAugmentBase(Transform):
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
fill=fill_,
)
elif transform_id == "TranslateY":
return F.affine(
......@@ -109,10 +128,10 @@ class _AutoAugmentBase(Transform):
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
fill=fill_,
)
elif transform_id == "Rotate":
return F.rotate(image, angle=magnitude)
return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
elif transform_id == "Brightness":
return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
elif transform_id == "Color":
......@@ -340,19 +359,17 @@ class RandAugment(_AutoAugmentBase):
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
_, height, width = get_chw(image)
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
magnitude = float(magnitudes[self.magnitude])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
......@@ -397,7 +414,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
_, height, width = get_chw(image)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
......@@ -467,7 +484,7 @@ class AugMix(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, orig_image = self._extract_image(sample)
num_channels, height, width = get_chw(orig_image)
_, height, width = get_chw(orig_image)
if isinstance(orig_image, torch.Tensor):
image = orig_image
......
......@@ -379,8 +379,12 @@ def affine_segmentation_mask(
def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
# if fill is None:
# fill = 0
if fill is None:
fill = 0
return fill
# This cast does Sequence -> List[float] to please mypy and torch.jit.script
if not isinstance(fill, (int, float)):
......
......@@ -549,8 +549,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
img = torch.cat((img, dummy), dim=1)
mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
img = torch.cat((img, mask), dim=1)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment