Unverified Commit 9c4f7389 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Fixed issue with jitted AA transforms in v2 and added tests (#7839)

parent 37081ee6
...@@ -927,6 +927,29 @@ class TestAATransforms: ...@@ -927,6 +927,29 @@ class TestAATransforms:
assert_close(expected_output, output, atol=1, rtol=0.1) assert_close(expected_output, output, atol=1, rtol=0.1)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
def test_randaug_jit(self, interpolation):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)
torch.manual_seed(12)
expected_output = tt_ref(inpt)
torch.manual_seed(12)
scripted_output = tt(inpt)
assert_equal(scripted_output, expected_output)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inpt", "inpt",
[ [
...@@ -979,6 +1002,29 @@ class TestAATransforms: ...@@ -979,6 +1002,29 @@ class TestAATransforms:
assert_close(expected_output, output, atol=1, rtol=0.1) assert_close(expected_output, output, atol=1, rtol=0.1)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
def test_trivial_aug_jit(self, interpolation):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)
torch.manual_seed(12)
expected_output = tt_ref(inpt)
torch.manual_seed(12)
scripted_output = tt(inpt)
assert_equal(scripted_output, expected_output)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inpt", "inpt",
[ [
...@@ -1032,6 +1078,30 @@ class TestAATransforms: ...@@ -1032,6 +1078,30 @@ class TestAATransforms:
assert_equal(expected_output, output) assert_equal(expected_output, output)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
def test_augmix_jit(self, interpolation):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)
torch.manual_seed(12)
expected_output = tt_ref(inpt)
torch.manual_seed(12)
scripted_output = tt(inpt)
assert_equal(scripted_output, expected_output)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inpt", "inpt",
[ [
...@@ -1061,6 +1131,30 @@ class TestAATransforms: ...@@ -1061,6 +1131,30 @@ class TestAATransforms:
assert_equal(expected_output, output) assert_equal(expected_output, output)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
def test_aa_jit(self, interpolation):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)
tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)
torch.manual_seed(12)
expected_output = tt_ref(inpt)
torch.manual_seed(12)
scripted_output = tt(inpt)
assert_equal(scripted_output, expected_output)
def import_transforms_from_references(reference): def import_transforms_from_references(reference):
HERE = Path(__file__).parent HERE = Path(__file__).parent
......
...@@ -28,7 +28,16 @@ class _AutoAugmentBase(Transform): ...@@ -28,7 +28,16 @@ class _AutoAugmentBase(Transform):
) -> None: ) -> None:
super().__init__() super().__init__()
self.interpolation = _check_interpolation(interpolation) self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill) self.fill = fill
self._fill = _setup_fill_arg(fill)
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
return params
def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]: def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
keys = tuple(dct.keys()) keys = tuple(dct.keys())
...@@ -335,7 +344,7 @@ class AutoAugment(_AutoAugmentBase): ...@@ -335,7 +344,7 @@ class AutoAugment(_AutoAugmentBase):
magnitude = 0.0 magnitude = 0.0
image_or_video = self._apply_image_or_video_transform( image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
) )
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video) return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
...@@ -419,7 +428,7 @@ class RandAugment(_AutoAugmentBase): ...@@ -419,7 +428,7 @@ class RandAugment(_AutoAugmentBase):
else: else:
magnitude = 0.0 magnitude = 0.0
image_or_video = self._apply_image_or_video_transform( image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
) )
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video) return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
...@@ -491,7 +500,7 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -491,7 +500,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
magnitude = 0.0 magnitude = 0.0
image_or_video = self._apply_image_or_video_transform( image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
) )
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video) return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
...@@ -614,7 +623,7 @@ class AugMix(_AutoAugmentBase): ...@@ -614,7 +623,7 @@ class AugMix(_AutoAugmentBase):
magnitude = 0.0 magnitude = 0.0
aug = self._apply_image_or_video_transform( aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
) )
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug) mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype) mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment