Unverified Commit 3f9b2d9c authored by vfdev's avatar vfdev Committed by GitHub
Browse files

WIP on adding gray images support for adjust_contrast (#4477)

parent cdb6fba5
......@@ -128,7 +128,12 @@ def needs_cuda(test_func):
def _create_data(height=3, width=3, channels=3, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
data = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
mode = "RGB"
if channels == 1:
mode = "L"
data = data[..., 0]
pil_img = Image.fromarray(data, mode=mode)
return tensor, pil_img
......
......@@ -641,12 +641,14 @@ def test_interpolate_antialias_backward(device, dt, size, interpolation):
assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
def check_functional_vs_PIL_vs_scripted(
fn, fn_pil, fn_t, config, device, dtype, channels=3, tol=2.0 + 1e-10, agg_method="max"
):
script_fn = torch.jit.script(fn)
torch.manual_seed(15)
tensor, pil_img = _create_data(26, 34, device=device)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
tensor, pil_img = _create_data(26, 34, channels=channels, device=device)
batch_tensors = _create_data_batch(16, 18, num_samples=4, channels=channels, device=device)
if dtype is not None:
tensor = F.convert_image_dtype(tensor, dtype)
......@@ -798,14 +800,16 @@ def test_equalize(device):
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
def test_adjust_contrast(device, dtype, config):
@pytest.mark.parametrize('channels', [1, 3])
def test_adjust_contrast(device, dtype, config, channels):
check_functional_vs_PIL_vs_scripted(
F.adjust_contrast,
F_pil.adjust_contrast,
F_t.adjust_contrast,
config,
device,
dtype
dtype,
channels=channels
)
......
......@@ -169,10 +169,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
_assert_image_tensor(img)
_assert_channels(img, [3])
_assert_channels(img, [3, 1])
c = get_image_num_channels(img)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
if c == 3:
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
else:
mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)
return _blend(img, mean, contrast_factor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment