Unverified Commit 48f8473e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port tests for type conversion transforms (#8003)

parent ee28bb3c
......@@ -122,61 +122,6 @@ class TestTransform:
t(inpt)
class TestToImage:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch(
"torchvision.transforms.v2.functional.to_image",
return_value=torch.rand(1, 3, 8, 8),
)
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImage()
transform(inpt)
if inpt_type in (tv_tensors.BoundingBoxes, tv_tensors.Image, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt)
class TestToPILImage:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToPILImage()
transform(inpt)
if inpt_type in (PIL.Image.Image, tv_tensors.BoundingBoxes, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt, mode=transform.mode)
class TestToTensor:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_tensor")
inpt = mocker.MagicMock(spec=inpt_type)
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToTensor()
transform(inpt)
if inpt_type in (tv_tensors.Image, torch.Tensor, tv_tensors.BoundingBoxes, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt)
class TestContainers:
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
def test_assertions(self, transform_cls):
......
......@@ -72,21 +72,6 @@ LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
CONSISTENCY_CONFIGS = [
ConsistencyConfig(
v2_transforms.ToPILImage,
legacy_transforms.ToPILImage,
[NotScriptableArgsKwargs()],
make_images_kwargs=dict(
color_spaces=[
"GRAY",
"GRAY_ALPHA",
"RGB",
"RGBA",
],
extra_dims=[()],
),
supports_pil=False,
),
ConsistencyConfig(
v2_transforms.Lambda,
legacy_transforms.Lambda,
......@@ -97,14 +82,6 @@ CONSISTENCY_CONFIGS = [
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
ConsistencyConfig(
v2_transforms.PILToTensor,
legacy_transforms.PILToTensor,
),
ConsistencyConfig(
v2_transforms.ToTensor,
legacy_transforms.ToTensor,
),
ConsistencyConfig(
v2_transforms.Compose,
legacy_transforms.Compose,
......
......@@ -5047,3 +5047,82 @@ class TestLinearTransform:
ValueError, match="Input tensor should be on the same device as transformation matrix and mean vector"
):
transform(input)
def make_image_numpy(*args, **kwargs):
image = make_image_tensor(*args, **kwargs)
return image.permute((1, 2, 0)).numpy()
class TestToImage:
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_image_numpy])
@pytest.mark.parametrize("fn", [F.to_image, transform_cls_to_functional(transforms.ToImage)])
def test_functional_and_transform(self, make_input, fn):
input = make_input()
output = fn(input)
assert isinstance(output, tv_tensors.Image)
input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input)
assert F.get_size(output) == input_size
if isinstance(input, torch.Tensor):
assert output.data_ptr() == input.data_ptr()
def test_functional_error(self):
with pytest.raises(TypeError, match="Input can either be a pure Tensor, a numpy array, or a PIL image"):
F.to_image(object())
class TestToPILImage:
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_numpy])
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("fn", [F.to_pil_image, transform_cls_to_functional(transforms.ToPILImage)])
def test_functional_and_transform(self, make_input, color_space, fn):
input = make_input(color_space=color_space)
output = fn(input)
assert isinstance(output, PIL.Image.Image)
input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input)
assert F.get_size(output) == input_size
def test_functional_error(self):
with pytest.raises(TypeError, match="pic should be Tensor or ndarray"):
F.to_pil_image(object())
for ndim in [1, 4]:
with pytest.raises(ValueError, match="pic should be 2/3 dimensional"):
F.to_pil_image(torch.empty(*[1] * ndim))
with pytest.raises(ValueError, match="pic should not have > 4 channels"):
num_channels = 5
F.to_pil_image(torch.empty(num_channels, 1, 1))
class TestToTensor:
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_image_numpy])
def test_smoke(self, make_input):
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToTensor()
input = make_input()
output = transform(input)
input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input)
assert F.get_size(output) == input_size
class TestPILToTensor:
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("fn", [F.pil_to_tensor, transform_cls_to_functional(transforms.PILToTensor)])
def test_functional_and_transform(self, color_space, fn):
input = make_image_pil(color_space=color_space)
output = fn(input)
assert isinstance(output, torch.Tensor) and not isinstance(output, tv_tensors.TVTensor)
assert F.get_size(output) == F.get_size(input)
def test_functional_error(self):
with pytest.raises(TypeError, match="pic should be PIL Image"):
F.pil_to_tensor(object())
......@@ -17,7 +17,9 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso
elif isinstance(inpt, torch.Tensor):
output = inpt
else:
raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
raise TypeError(
f"Input can either be a pure Tensor, a numpy array, or a PIL image, but got {type(inpt)} instead."
)
return tv_tensors.Image(output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment