"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "58237364b1780223f48a80256f56408efe7b59a0"
Unverified Commit b6574c92 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port tests for transforms.ColorJitter (#7968)

parent 5fa8050d
...@@ -228,23 +228,6 @@ CONSISTENCY_CONFIGS = [ ...@@ -228,23 +228,6 @@ CONSISTENCY_CONFIGS = [
# Use default tolerances of `torch.testing.assert_close` # Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None), closeness_kwargs=dict(rtol=None, atol=None),
), ),
ConsistencyConfig(
v2_transforms.ColorJitter,
legacy_transforms.ColorJitter,
[
ArgsKwargs(),
ArgsKwargs(brightness=0.1),
ArgsKwargs(brightness=(0.2, 0.3)),
ArgsKwargs(contrast=0.4),
ArgsKwargs(contrast=(0.5, 0.6)),
ArgsKwargs(saturation=0.7),
ArgsKwargs(saturation=(0.8, 0.9)),
ArgsKwargs(hue=0.3),
ArgsKwargs(hue=(-0.1, 0.2)),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
],
closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
),
ConsistencyConfig( ConsistencyConfig(
v2_transforms.PILToTensor, v2_transforms.PILToTensor,
legacy_transforms.PILToTensor, legacy_transforms.PILToTensor,
...@@ -453,49 +436,6 @@ def test_call_consistency(config, args_kwargs): ...@@ -453,49 +436,6 @@ def test_call_consistency(config, args_kwargs):
) )
get_params_parametrization = pytest.mark.parametrize(
("config", "get_params_args_kwargs"),
[
pytest.param(
next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls),
get_params_args_kwargs,
id=transform_cls.__name__,
)
for transform_cls, get_params_args_kwargs in [
(v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(v2_transforms.AutoAugment, ArgsKwargs(5)),
]
],
)
@get_params_parametrization
def test_get_params_alias(config, get_params_args_kwargs):
assert config.prototype_cls.get_params is config.legacy_cls.get_params
if not config.args_kwargs:
return
args, kwargs = config.args_kwargs[0]
legacy_transform = config.legacy_cls(*args, **kwargs)
prototype_transform = config.prototype_cls(*args, **kwargs)
assert prototype_transform.get_params is legacy_transform.get_params
@get_params_parametrization
def test_get_params_jit(config, get_params_args_kwargs):
get_params_args, get_params_kwargs = get_params_args_kwargs
torch.jit.script(config.prototype_cls.get_params)(*get_params_args, **get_params_kwargs)
if not config.args_kwargs:
return
args, kwargs = config.args_kwargs[0]
transform = config.prototype_cls(*args, **kwargs)
torch.jit.script(transform.get_params)(*get_params_args, **get_params_kwargs)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("config", "args_kwargs"), ("config", "args_kwargs"),
[ [
......
...@@ -3881,3 +3881,67 @@ class TestPerspective: ...@@ -3881,3 +3881,67 @@ class TestPerspective:
) )
assert_close(actual, expected, rtol=0, atol=1) assert_close(actual, expected, rtol=0, atol=1)
class TestColorJitter:
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
)
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, dtype, device):
if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"):
pytest.skip(
"PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
"will degenerate to that anyway."
)
check_transform(
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25),
make_input(dtype=dtype, device=device),
)
def test_transform_noop(self):
input = make_image()
input_version = input._version
transform = transforms.ColorJitter()
output = transform(input)
assert output is input
assert output.data_ptr() == input.data_ptr()
assert output._version == input_version
def test_transform_error(self):
with pytest.raises(ValueError, match="must be non negative"):
transforms.ColorJitter(brightness=-1)
for brightness in [object(), [1, 2, 3]]:
with pytest.raises(TypeError, match="single number or a sequence with length 2"):
transforms.ColorJitter(brightness=brightness)
with pytest.raises(ValueError, match="values should be between"):
transforms.ColorJitter(brightness=(-1, 0.5))
with pytest.raises(ValueError, match="values should be between"):
transforms.ColorJitter(hue=1)
@pytest.mark.parametrize("brightness", [None, 0.1, (0.2, 0.3)])
@pytest.mark.parametrize("contrast", [None, 0.4, (0.5, 0.6)])
@pytest.mark.parametrize("saturation", [None, 0.7, (0.8, 0.9)])
@pytest.mark.parametrize("hue", [None, 0.3, (-0.1, 0.2)])
def test_transform_correctness(self, brightness, contrast, saturation, hue):
image = make_image(dtype=torch.uint8, device="cpu")
transform = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
with freeze_rng_state():
torch.manual_seed(0)
actual = transform(image)
torch.manual_seed(0)
expected = F.to_image(transform(F.to_pil_image(image)))
mae = (actual.float() - expected.float()).abs().mean()
assert mae < 2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment