Unverified Commit d60d5e71 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

[CHERRYPICK] allow sequence fill for v2 AA scripted (#7920)

parent f588fd1a
...@@ -755,10 +755,11 @@ class TestAATransforms: ...@@ -755,10 +755,11 @@ class TestAATransforms:
v2_transforms.InterpolationMode.BILINEAR, v2_transforms.InterpolationMode.BILINEAR,
], ],
) )
def test_randaug_jit(self, interpolation): @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_randaug_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1) t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1) t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
tt_ref = torch.jit.script(t_ref) tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t) tt = torch.jit.script(t)
...@@ -830,10 +831,11 @@ class TestAATransforms: ...@@ -830,10 +831,11 @@ class TestAATransforms:
v2_transforms.InterpolationMode.BILINEAR, v2_transforms.InterpolationMode.BILINEAR,
], ],
) )
def test_trivial_aug_jit(self, interpolation): @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_trivial_aug_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation) t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation) t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
tt_ref = torch.jit.script(t_ref) tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t) tt = torch.jit.script(t)
...@@ -906,11 +908,12 @@ class TestAATransforms: ...@@ -906,11 +908,12 @@ class TestAATransforms:
v2_transforms.InterpolationMode.BILINEAR, v2_transforms.InterpolationMode.BILINEAR,
], ],
) )
def test_augmix_jit(self, interpolation): @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_augmix_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
tt_ref = torch.jit.script(t_ref) tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t) tt = torch.jit.script(t)
......
...@@ -33,8 +33,8 @@ class _AutoAugmentBase(Transform): ...@@ -33,8 +33,8 @@ class _AutoAugmentBase(Transform):
def _extract_params_for_v1_transform(self) -> Dict[str, Any]: def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform() params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))): if isinstance(params["fill"], dict):
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.") raise ValueError(f"{type(self).__name__}() can not be scripted for when `fill` is a dictionary.")
return params return params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment