Unverified Commit 3f9b2d9c authored by vfdev's avatar vfdev Committed by GitHub
Browse files

WIP on adding gray images support for adjust_contrast (#4477)

parent cdb6fba5
...@@ -128,7 +128,12 @@ def needs_cuda(test_func): ...@@ -128,7 +128,12 @@ def needs_cuda(test_func):
def _create_data(height=3, width=3, channels=3, device="cpu"):
    """Create a random uint8 test image as a matching (tensor, PIL image) pair.

    Args:
        height, width: spatial size of the image.
        channels: 3 for RGB, 1 for grayscale.
        device: device the tensor is created on; the PIL image is always built
            from a CPU copy.

    Returns:
        (tensor, pil_img): a CHW ``torch.uint8`` tensor on ``device`` and the
        equivalent PIL image ("RGB" for 3 channels, "L" for 1 channel).
    """
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
    data = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
    mode = "RGB"
    if channels == 1:
        mode = "L"
        # PIL "L" images are 2-D: drop the trailing singleton channel axis.
        data = data[..., 0]
    pil_img = Image.fromarray(data, mode=mode)
    return tensor, pil_img
......
...@@ -641,12 +641,14 @@ def test_interpolate_antialias_backward(device, dt, size, interpolation): ...@@ -641,12 +641,14 @@ def test_interpolate_antialias_backward(device, dt, size, interpolation):
assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False) assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"): def check_functional_vs_PIL_vs_scripted(
fn, fn_pil, fn_t, config, device, dtype, channels=3, tol=2.0 + 1e-10, agg_method="max"
):
script_fn = torch.jit.script(fn) script_fn = torch.jit.script(fn)
torch.manual_seed(15) torch.manual_seed(15)
tensor, pil_img = _create_data(26, 34, device=device) tensor, pil_img = _create_data(26, 34, channels=channels, device=device)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device) batch_tensors = _create_data_batch(16, 18, num_samples=4, channels=channels, device=device)
if dtype is not None: if dtype is not None:
tensor = F.convert_image_dtype(tensor, dtype) tensor = F.convert_image_dtype(tensor, dtype)
...@@ -798,14 +800,16 @@ def test_equalize(device): ...@@ -798,14 +800,16 @@ def test_equalize(device):
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize('channels', [1, 3])
def test_adjust_contrast(device, dtype, config, channels):
    """Compare adjust_contrast across the functional, PIL and scripted paths.

    Parametrized over device, dtype, contrast factor, and — since gray-image
    support was added — over 1-channel (grayscale) and 3-channel (RGB) inputs.
    """
    check_functional_vs_PIL_vs_scripted(
        F.adjust_contrast,
        F_pil.adjust_contrast,
        F_t.adjust_contrast,
        config,
        device,
        dtype,
        channels=channels,
    )
......
...@@ -169,10 +169,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: ...@@ -169,10 +169,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
_assert_image_tensor(img) _assert_image_tensor(img)
_assert_channels(img, [3]) _assert_channels(img, [3, 1])
c = get_image_num_channels(img)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32 dtype = img.dtype if torch.is_floating_point(img) else torch.float32
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) if c == 3:
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
else:
mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)
return _blend(img, mean, contrast_factor) return _blend(img, mean, contrast_factor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment