Unverified Commit c3aee873 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port tests for rgb_to_grayscale functional and transforms (#7967)

parent 9ba21583
...@@ -116,11 +116,9 @@ class TestSmoke: ...@@ -116,11 +116,9 @@ class TestSmoke:
(transforms.RandAugment(), auto_augment_adapter), (transforms.RandAugment(), auto_augment_adapter),
(transforms.TrivialAugmentWide(), auto_augment_adapter), (transforms.TrivialAugmentWide(), auto_augment_adapter),
(transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None), (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
(transforms.Grayscale(), None),
(transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None), (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
(transforms.RandomAutocontrast(p=1.0), None), (transforms.RandomAutocontrast(p=1.0), None),
(transforms.RandomEqualize(p=1.0), None), (transforms.RandomEqualize(p=1.0), None),
(transforms.RandomGrayscale(p=1.0), None),
(transforms.RandomInvert(p=1.0), None), (transforms.RandomInvert(p=1.0), None),
(transforms.RandomChannelPermutation(), None), (transforms.RandomChannelPermutation(), None),
(transforms.RandomPhotometricDistort(p=1.0), None), (transforms.RandomPhotometricDistort(p=1.0), None),
......
...@@ -122,17 +122,6 @@ CONSISTENCY_CONFIGS = [ ...@@ -122,17 +122,6 @@ CONSISTENCY_CONFIGS = [
(torch.float32, torch.float64), (torch.float32, torch.float64),
] ]
], ],
ConsistencyConfig(
v2_transforms.Grayscale,
legacy_transforms.Grayscale,
[
ArgsKwargs(num_output_channels=1),
ArgsKwargs(num_output_channels=3),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig( ConsistencyConfig(
v2_transforms.ToPILImage, v2_transforms.ToPILImage,
legacy_transforms.ToPILImage, legacy_transforms.ToPILImage,
...@@ -217,17 +206,6 @@ CONSISTENCY_CONFIGS = [ ...@@ -217,17 +206,6 @@ CONSISTENCY_CONFIGS = [
], ],
closeness_kwargs={"atol": 1e-6, "rtol": 1e-6}, closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
), ),
ConsistencyConfig(
v2_transforms.RandomGrayscale,
legacy_transforms.RandomGrayscale,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig( ConsistencyConfig(
v2_transforms.PILToTensor, v2_transforms.PILToTensor,
legacy_transforms.PILToTensor, legacy_transforms.PILToTensor,
......
...@@ -3945,3 +3945,58 @@ class TestColorJitter: ...@@ -3945,3 +3945,58 @@ class TestColorJitter:
mae = (actual.float() - expected.float()).abs().mean() mae = (actual.float() - expected.float()).abs().mean()
assert mae < 2 assert mae < 2
class TestRgbToGrayscale:
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image(self, dtype, device):
check_kernel(F.rgb_to_grayscale_image, make_image(dtype=dtype, device=device))
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
def test_functional(self, make_input):
check_functional(F.rgb_to_grayscale, make_input())
@pytest.mark.parametrize(
("kernel", "input_type"),
[
(F.rgb_to_grayscale_image, torch.Tensor),
(F._rgb_to_grayscale_image_pil, PIL.Image.Image),
(F.rgb_to_grayscale_image, tv_tensors.Image),
],
)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.rgb_to_grayscale, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize("transform", [transforms.Grayscale(), transforms.RandomGrayscale(p=1)])
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
def test_transform(self, transform, make_input):
check_transform(transform, make_input())
@pytest.mark.parametrize("num_output_channels", [1, 3])
@pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
def test_image_correctness(self, num_output_channels, fn):
image = make_image(dtype=torch.uint8, device="cpu")
actual = fn(image, num_output_channels=num_output_channels)
expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))
assert_equal(actual, expected, rtol=0, atol=1)
@pytest.mark.parametrize("num_input_channels", [1, 3])
def test_random_transform_correctness(self, num_input_channels):
image = make_image(
color_space={
1: "GRAY",
3: "RGB",
}[num_input_channels],
dtype=torch.uint8,
device="cpu",
)
transform = transforms.RandomGrayscale(p=1)
actual = transform(image)
expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_input_channels))
assert_equal(actual, expected, rtol=0, atol=1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment