"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "bedd2a2088f7551d71d7dfb5788ed043790e9723"
Unverified Commit 997384cf authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port tests for RandomPhotometricDistort (#7973)

parent ace92213
...@@ -120,7 +120,6 @@ class TestSmoke: ...@@ -120,7 +120,6 @@ class TestSmoke:
(transforms.RandomEqualize(p=1.0), None), (transforms.RandomEqualize(p=1.0), None),
(transforms.RandomInvert(p=1.0), None), (transforms.RandomInvert(p=1.0), None),
(transforms.RandomChannelPermutation(), None), (transforms.RandomChannelPermutation(), None),
(transforms.RandomPhotometricDistort(p=1.0), None),
(transforms.RandomPosterize(bits=4, p=1.0), None), (transforms.RandomPosterize(bits=4, p=1.0), None),
(transforms.RandomSolarize(threshold=0.5, p=1.0), None), (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
(transforms.CenterCrop([16, 16]), None), (transforms.CenterCrop([16, 16]), None),
......
...@@ -4040,3 +4040,28 @@ class TestRandomZoomOut: ...@@ -4040,3 +4040,28 @@ class TestRandomZoomOut:
assert 0 <= padding[1] <= (side_range[1] - 1) * height assert 0 <= padding[1] <= (side_range[1] - 1) * height
assert 0 <= padding[2] <= (side_range[1] - 1) * width assert 0 <= padding[2] <= (side_range[1] - 1) * width
assert 0 <= padding[3] <= (side_range[1] - 1) * height assert 0 <= padding[3] <= (side_range[1] - 1) * height
class TestRandomPhotometricDistort:
# Tests are light because this largely relies on the already tested
# `adjust_{brightness,contrast,saturation,hue}` and `permute_channels` kernels.
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
)
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, dtype, device):
if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"):
pytest.skip(
"PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
"will degenerate to that anyway."
)
check_transform(
transforms.RandomPhotometricDistort(
brightness=(0.3, 0.4), contrast=(0.5, 0.6), saturation=(0.7, 0.8), hue=(-0.1, 0.2), p=1
),
make_input(dtype=dtype, device=device),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment