Unverified Commit 0040fe7a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port tests for transforms.Lambda (#8011)

parent b6189a8d
...@@ -574,38 +574,3 @@ def test_sanitize_bounding_boxes_errors(): ...@@ -574,38 +574,3 @@ def test_sanitize_bounding_boxes_errors():
with pytest.raises(ValueError, match="Number of boxes"): with pytest.raises(ValueError, match="Number of boxes"):
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)} different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
transforms.SanitizeBoundingBoxes()(different_sizes) transforms.SanitizeBoundingBoxes()(different_sizes)
class TestLambda:
inputs = pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
@inputs
def test_default(self, input):
was_applied = False
def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input
transform = transforms.Lambda(was_applied_fn)
transform(input)
assert was_applied
@inputs
def test_with_types(self, input):
was_applied = False
def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input
types = (torch.Tensor, np.ndarray)
transform = transforms.Lambda(was_applied_fn, *types)
transform(input)
assert was_applied is isinstance(input, types)
...@@ -72,16 +72,6 @@ LINEAR_TRANSFORMATION_MEAN = torch.rand(36) ...@@ -72,16 +72,6 @@ LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2) LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
CONSISTENCY_CONFIGS = [ CONSISTENCY_CONFIGS = [
ConsistencyConfig(
v2_transforms.Lambda,
legacy_transforms.Lambda,
[
NotScriptableArgsKwargs(lambda image: image / 2),
],
# Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
ConsistencyConfig( ConsistencyConfig(
v2_transforms.Compose, v2_transforms.Compose,
legacy_transforms.Compose, legacy_transforms.Compose,
......
...@@ -5126,3 +5126,21 @@ class TestPILToTensor: ...@@ -5126,3 +5126,21 @@ class TestPILToTensor:
def test_functional_error(self): def test_functional_error(self):
with pytest.raises(TypeError, match="pic should be PIL Image"): with pytest.raises(TypeError, match="pic should be PIL Image"):
F.pil_to_tensor(object()) F.pil_to_tensor(object())
class TestLambda:
@pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
@pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
def test_transform(self, input, types):
was_applied = False
def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input
transform = transforms.Lambda(was_applied_fn, *types)
output = transform(input)
assert output is input
assert was_applied is (not types or isinstance(input, types))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment