Unverified Commit ddfee23d authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port tests for container transforms (#8012)

parent 0040fe7a
...@@ -122,35 +122,6 @@ class TestTransform: ...@@ -122,35 +122,6 @@ class TestTransform:
t(inpt) t(inpt)
class TestContainers:
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
def test_assertions(self, transform_cls):
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
transform_cls(transforms.RandomCrop(28))
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
@pytest.mark.parametrize(
"trfms",
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
[transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
],
)
def test_ctor(self, transform_cls, trfms):
c = transform_cls(trfms)
inpt = torch.rand(1, 3, 32, 32)
output = c(inpt)
assert isinstance(output, torch.Tensor)
assert output.ndim == 4
class TestRandomChoice:
def test_assertions(self):
with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1])
class TestRandomIoUCrop: class TestRandomIoUCrop:
@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]]) @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
......
...@@ -11,9 +11,7 @@ import pytest ...@@ -11,9 +11,7 @@ import pytest
import torch import torch
import torchvision.transforms.v2 as v2_transforms import torchvision.transforms.v2 as v2_transforms
from common_utils import assert_close, assert_equal, set_rng_seed from common_utils import assert_close, assert_equal, set_rng_seed
from torch import nn
from torchvision import transforms as legacy_transforms, tv_tensors from torchvision import transforms as legacy_transforms, tv_tensors
from torchvision._utils import sequence_to_str
from torchvision.transforms import functional as legacy_F from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F from torchvision.transforms.v2 import functional as prototype_F
...@@ -71,63 +69,7 @@ class ConsistencyConfig: ...@@ -71,63 +69,7 @@ class ConsistencyConfig:
LINEAR_TRANSFORMATION_MEAN = torch.rand(36) LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2) LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
CONSISTENCY_CONFIGS = [ CONSISTENCY_CONFIGS = []
ConsistencyConfig(
v2_transforms.Compose,
legacy_transforms.Compose,
),
ConsistencyConfig(
v2_transforms.RandomApply,
legacy_transforms.RandomApply,
),
ConsistencyConfig(
v2_transforms.RandomChoice,
legacy_transforms.RandomChoice,
),
ConsistencyConfig(
v2_transforms.RandomOrder,
legacy_transforms.RandomOrder,
),
]
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
prototype_params = dict(inspect.signature(config.prototype_cls).parameters)
for param in config.removed_params:
legacy_params.pop(param, None)
missing = legacy_params.keys() - prototype_params.keys()
if missing:
raise AssertionError(
f"The prototype transform does not support the parameters "
f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
f"the `ConsistencyConfig`."
)
extra = prototype_params.keys() - legacy_params.keys()
extra_without_default = {
param
for param in extra
if prototype_params[param].default is inspect.Parameter.empty
and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
}
if extra_without_default:
raise AssertionError(
f"The prototype transform requires the parameters "
f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
f"not. Please add a default value."
)
legacy_signature = list(legacy_params.keys())
# Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
# to the same number of parameters as the legacy one
prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]
assert prototype_signature == legacy_signature
def check_call_consistency( def check_call_consistency(
...@@ -288,84 +230,6 @@ def test_jit_consistency(config, args_kwargs): ...@@ -288,84 +230,6 @@ def test_jit_consistency(config, args_kwargs):
assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs) assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)
class TestContainerTransforms:
"""
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
consistency automatically tests the wrapped transforms consistency.
Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
that were already tested for consistency above.
"""
def test_compose(self):
prototype_transform = v2_transforms.Compose(
[
v2_transforms.Resize(256),
v2_transforms.CenterCrop(224),
]
)
legacy_transform = legacy_transforms.Compose(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
]
)
# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
prototype_transform = v2_transforms.RandomApply(
sequence_type(
[
v2_transforms.Resize(256),
v2_transforms.CenterCrop(224),
]
),
p=p,
)
legacy_transform = legacy_transforms.RandomApply(
sequence_type(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
]
),
p=p,
)
# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
if sequence_type is nn.ModuleList:
# quick and dirty test that it is jit-scriptable
scripted = torch.jit.script(prototype_transform)
scripted(torch.rand(1, 3, 300, 300))
# We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
def test_random_choice(self, probabilities):
prototype_transform = v2_transforms.RandomChoice(
[
v2_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=probabilities,
)
legacy_transform = legacy_transforms.RandomChoice(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=probabilities,
)
# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
class TestToTensorTransforms: class TestToTensorTransforms:
def test_pil_to_tensor(self): def test_pil_to_tensor(self):
prototype_transform = v2_transforms.PILToTensor() prototype_transform = v2_transforms.PILToTensor()
......
...@@ -396,6 +396,8 @@ def check_transform(transform, input, check_v1_compatibility=True, check_sample_ ...@@ -396,6 +396,8 @@ def check_transform(transform, input, check_v1_compatibility=True, check_sample_
if check_v1_compatibility: if check_v1_compatibility:
_check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility)) _check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))
return output
def transform_cls_to_functional(transform_cls, **transform_specific_kwargs): def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
def wrapper(input, *args, **kwargs): def wrapper(input, *args, **kwargs):
...@@ -1773,7 +1775,7 @@ class TestRotate: ...@@ -1773,7 +1775,7 @@ class TestRotate:
transforms.RandomAffine(degrees=0, fill="fill") transforms.RandomAffine(degrees=0, fill="fill")
class TestCompose: class TestContainerTransforms:
class BuiltinTransform(transforms.Transform): class BuiltinTransform(transforms.Transform):
def _transform(self, inpt, params): def _transform(self, inpt, params):
return inpt return inpt
...@@ -1788,7 +1790,10 @@ class TestCompose: ...@@ -1788,7 +1790,10 @@ class TestCompose:
return image, label return image, label
@pytest.mark.parametrize( @pytest.mark.parametrize(
"transform_clss", "transform_cls", [transforms.Compose, functools.partial(transforms.RandomApply, p=1), transforms.RandomOrder]
)
@pytest.mark.parametrize(
"wrapped_transform_clss",
[ [
[BuiltinTransform], [BuiltinTransform],
[PackedInputTransform], [PackedInputTransform],
...@@ -1803,12 +1808,12 @@ class TestCompose: ...@@ -1803,12 +1808,12 @@ class TestCompose:
], ],
) )
@pytest.mark.parametrize("unpack", [True, False]) @pytest.mark.parametrize("unpack", [True, False])
def test_packed_unpacked(self, transform_clss, unpack): def test_packed_unpacked(self, transform_cls, wrapped_transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss) needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in wrapped_transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss) needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in wrapped_transform_clss)
assert not (needs_packed_inputs and needs_unpacked_inputs) assert not (needs_packed_inputs and needs_unpacked_inputs)
transform = transforms.Compose([cls() for cls in transform_clss]) transform = transform_cls([cls() for cls in wrapped_transform_clss])
image = make_image() image = make_image()
label = 3 label = 3
...@@ -1833,6 +1838,97 @@ class TestCompose: ...@@ -1833,6 +1838,97 @@ class TestCompose:
assert output[0] is image assert output[0] is image
assert output[1] is label assert output[1] is label
def test_compose(self):
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
)
input = make_image()
actual = check_transform(transform, input)
expected = F.vertical_flip(F.horizontal_flip(input))
assert_equal(actual, expected)
@pytest.mark.parametrize("p", [0.0, 1.0])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
transform = transforms.RandomApply(
sequence_type(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
),
p=p,
)
# This needs to be a pure tensor (or a PIL image), because otherwise check_transforms skips the v1 compatibility
# check
input = make_image_tensor()
output = check_transform(transform, input, check_v1_compatibility=issubclass(sequence_type, nn.ModuleList))
if p == 1:
assert_equal(output, F.vertical_flip(F.horizontal_flip(input)))
else:
assert output is input
@pytest.mark.parametrize("p", [(0, 1), (1, 0)])
def test_random_choice(self, p):
transform = transforms.RandomChoice(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
],
p=p,
)
input = make_image()
output = check_transform(transform, input)
p_horz, p_vert = p
if p_horz:
assert_equal(output, F.horizontal_flip(input))
else:
assert_equal(output, F.vertical_flip(input))
def test_random_order(self):
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(p=1),
transforms.RandomVerticalFlip(p=1),
]
)
input = make_image()
actual = check_transform(transform, input)
# We can't really check whether the transforms are actually applied in random order. However, horizontal and
# vertical flip are commutative. Meaning, even under the assumption that the transform applies them in random
# order, we can use a fixed order to compute the expected value.
expected = F.vertical_flip(F.horizontal_flip(input))
assert_equal(actual, expected)
def test_errors(self):
for cls in [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]:
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
cls(lambda x: x)
with pytest.raises(ValueError, match="at least one transform"):
transforms.Compose([])
for p in [-1, 2]:
with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")):
transforms.RandomApply([lambda x: x], p=p)
for transforms_, p in [([lambda x: x], []), ([], [1.0])]:
with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
transforms.RandomChoice(transforms_, p=p)
class TestToDtype: class TestToDtype:
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -100,14 +100,15 @@ class RandomApply(Transform): ...@@ -100,14 +100,15 @@ class RandomApply(Transform):
return {"transforms": self.transforms, "p": self.p} return {"transforms": self.transforms, "p": self.p}
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] needs_unpacking = len(inputs) > 1
if torch.rand(1) >= self.p: if torch.rand(1) >= self.p:
return sample return inputs if needs_unpacking else inputs[0]
for transform in self.transforms: for transform in self.transforms:
sample = transform(sample) outputs = transform(*inputs)
return sample inputs = outputs if needs_unpacking else (outputs,)
return outputs
def extra_repr(self) -> str: def extra_repr(self) -> str:
format_string = [] format_string = []
...@@ -173,8 +174,9 @@ class RandomOrder(Transform): ...@@ -173,8 +174,9 @@ class RandomOrder(Transform):
self.transforms = transforms self.transforms = transforms
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] needs_unpacking = len(inputs) > 1
for idx in torch.randperm(len(self.transforms)): for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx] transform = self.transforms[idx]
sample = transform(sample) outputs = transform(*inputs)
return sample inputs = outputs if needs_unpacking else (outputs,)
return outputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment