"tests/vscode:/vscode.git/clone" did not exist on "57a22e7306db24140cb2133aa12a613cbf971c4c"
Unverified commit cb4413a3, authored by Philip Meier and committed by GitHub
Browse files

Fix hardcoded 255 (#6830)

* fix prototype kernels

* fix stable kernels

* fix tests

* make test more robust

* improve invert for signed integers

* improve invert

* fix posterize

* Revert "assume that integer images are [0, 255] in equalize (#6859)"

This reverts commit 436ff9a5.

* fix solarize in AA

* fix resize

* Revert "fix resize"

This reverts commit 5f33f4aa73e098237650b794157ec9335d964152.

* add comment to float max value
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 4d085f2e
......@@ -790,32 +790,40 @@ def test_solarize2(device, dtype, config, channels):
)
@pytest.mark.parametrize(
("dtype", "threshold"),
[
*[
(dtype, threshold)
for dtype, threshold in itertools.product(
[torch.float32, torch.float16],
[0.0, 0.25, 0.5, 0.75, 1.0],
)
],
*[(torch.uint8, threshold) for threshold in [0, 64, 128, 192, 255]],
*[(torch.int64, threshold) for threshold in [0, 2**32, 2**63 - 1]],
],
)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [0.0, 0.25, 0.5, 0.75, 1.0])
def test_solarize_threshold1_bound(threshold, device):
img = torch.rand((3, 12, 23)).to(device)
F_t.solarize(img, threshold)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [1.5])
def test_solarize_threshold1_upper_bound(threshold, device):
img = torch.rand((3, 12, 23)).to(device)
with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
F_t.solarize(img, threshold)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [0, 64, 128, 192, 255])
def test_solarize_threshold2_bound(threshold, device):
img = torch.randint(0, 256, (3, 12, 23)).to(device)
def test_solarize_threshold_within_bound(threshold, dtype, device):
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
img = make_img((3, 12, 23), dtype=dtype, device=device)
F_t.solarize(img, threshold)
@pytest.mark.parametrize(
("dtype", "threshold"),
[
(torch.float32, 1.5),
(torch.float16, 1.5),
(torch.uint8, 260),
(torch.int64, 2**64),
],
)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [260])
def test_solarize_threshold2_upper_bound(threshold, device):
img = torch.randint(0, 256, (3, 12, 23)).to(device)
def test_solarize_threshold_above_bound(threshold, dtype, device):
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
img = make_img((3, 12, 23), dtype=dtype, device=device)
with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
F_t.solarize(img, threshold)
......
......@@ -8,6 +8,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision.prototype import features
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_spatial_size
from torchvision.transforms import functional_tensor as _FT
from ._utils import _isinstance, _setup_fill_arg
......@@ -137,7 +138,7 @@ class _AutoAugmentBase(Transform):
elif transform_id == "Posterize":
return F.posterize(image, bits=int(magnitude))
elif transform_id == "Solarize":
bound = 1.0 if isinstance(image, torch.Tensor) and image.is_floating_point() else 255.0
bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
return F.solarize(image, threshold=bound * magnitude)
elif transform_id == "AutoContrast":
return F.autocontrast(image)
......
......@@ -2,13 +2,13 @@ import torch
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from ._meta import _rgb_to_gray, convert_dtype_image_tensor
from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
bound = 1.0 if fp else 255.0
bound = _FT._max_value(image1.dtype)
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
return output if fp else output.to(image1.dtype)
......@@ -20,7 +20,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
_FT._assert_channels(image, [1, 3])
fp = image.is_floating_point()
bound = 1.0 if fp else 255.0
bound = _FT._max_value(image.dtype)
output = image.mul(brightness_factor).clamp_(0, bound)
return output if fp else output.to(image.dtype)
......@@ -222,8 +222,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
return image
orig_dtype = image.dtype
if image.dtype == torch.uint8:
image = image / 255.0
image = convert_dtype_image_tensor(image, torch.float32)
image = _rgb_to_hsv(image)
h, s, v = image.unbind(dim=-3)
......@@ -231,10 +230,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
image = torch.stack((h, s, v), dim=-3)
image_hue_adj = _hsv_to_rgb(image)
if orig_dtype == torch.uint8:
image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)
return image_hue_adj
return convert_dtype_image_tensor(image_hue_adj, orig_dtype)
adjust_hue_image_pil = _FP.adjust_hue
......@@ -289,14 +285,15 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
if bits > 8:
return image
if image.is_floating_point():
levels = 1 << bits
return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
else:
mask = ((1 << bits) - 1) << (8 - bits)
num_value_bits = _num_value_bits(image.dtype)
if bits >= num_value_bits:
return image
mask = ((1 << bits) - 1) << (num_value_bits - bits)
return image & mask
......@@ -317,8 +314,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
bound = 1 if image.is_floating_point() else 255
if threshold > bound:
if threshold > _FT._max_value(image.dtype):
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
return torch.where(image >= threshold, invert_image_tensor(image), image)
......@@ -349,7 +345,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
# exit earlier on empty images
return image
bound = 1.0 if image.is_floating_point() else 255.0
bound = _FT._max_value(image.dtype)
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype)
......@@ -383,13 +379,17 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0:
return image
output_dtype = image.dtype
if image.is_floating_point():
# Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
# 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
# would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
# `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely
# unfeasible for `torch.int64`.
# 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it
# slower and more complicated to implement than a simple conversion and a fast histogram implementation for
# integers.
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
# and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
# Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
# by far the most common, we choose it as base.
output_dtype = image.dtype
image = convert_dtype_image_tensor(image, torch.uint8)
# The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
......@@ -461,10 +461,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.dtype == torch.uint8:
if image.is_floating_point():
return 1.0 - image # type: ignore[no-any-return]
elif image.dtype == torch.uint8:
return image.bitwise_not()
else:
return (1 if image.is_floating_point() else 255) - image # type: ignore[no-any-return]
else: # signed integer dtypes
# We can't use `Tensor.bitwise_not` here, since we want to retain the leading zero bit that encodes the sign
return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)
invert_image_pil = _FP.invert
......
......@@ -15,12 +15,6 @@ def _assert_image_tensor(img: Tensor) -> None:
raise TypeError("Tensor is not a torch image.")
def _assert_threshold(img: Tensor, threshold: float) -> None:
bound = 1 if img.is_floating_point() else 255
if threshold > bound:
raise TypeError("Threshold should be less than bound of img.")
def get_dimensions(img: Tensor) -> List[int]:
_assert_image_tensor(img)
channels = 1 if img.ndim == 2 else img.shape[-3]
......@@ -56,6 +50,8 @@ def _max_value(dtype: torch.dtype) -> int:
elif dtype == torch.int64:
return 9223372036854775807
else:
# This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
# easy.
return 1
......@@ -212,8 +208,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
return img
orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0
img = convert_image_dtype(img, torch.float32)
img = _rgb2hsv(img)
h, s, v = img.unbind(dim=-3)
......@@ -221,10 +216,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
img = torch.stack((h, s, v), dim=-3)
img_hue_adj = _hsv2rgb(img)
if orig_dtype == torch.uint8:
img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)
return img_hue_adj
return convert_image_dtype(img_hue_adj, orig_dtype)
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
......@@ -263,7 +255,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
ratio = float(ratio)
bound = 1.0 if img1.is_floating_point() else 255.0
bound = _max_value(img1.dtype)
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
......@@ -775,8 +767,7 @@ def invert(img: Tensor) -> Tensor:
_assert_channels(img, [1, 3])
bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
return bound - img
return _max_value(img.dtype) - img
def posterize(img: Tensor, bits: int) -> Tensor:
......@@ -802,7 +793,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
_assert_channels(img, [1, 3])
_assert_threshold(img, threshold)
if threshold > _max_value(img.dtype):
raise TypeError("Threshold should be less than bound of img.")
inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)
......@@ -849,7 +841,7 @@ def autocontrast(img: Tensor) -> Tensor:
_assert_channels(img, [1, 3])
bound = 1.0 if img.is_floating_point() else 255.0
bound = _max_value(img.dtype)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment