Unverified commit cb4413a3, authored by Philip Meier, committed by GitHub

Fix hardcoded 255 (#6830)

* fix prototype kernels

* fix stable kernels

* fix tests

* make test more robust

* improve invert for signed integers

* improve invert

* fix posterize

* Revert "assume that integer images are [0, 255] in equalize (#6859)"

This reverts commit 436ff9a5.

* fix solarize in AA

* fix resize

* Revert "fix resize"

This reverts commit 5f33f4aa73e098237650b794157ec9335d964152.

* add comment to float max value
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 4d085f2e
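The common thread of the hunks below: a hardcoded bound of 255 is only correct for `torch.uint8` images, so the kernels now derive the bound from the image dtype. A minimal, self-contained sketch of the difference (using `torch.iinfo`, not torchvision's private helpers):

```python
import torch

def dtype_bound(dtype: torch.dtype) -> float:
    # Assumed convention: float images live in [0.0, 1.0], integer images in [0, dtype max].
    return 1.0 if dtype.is_floating_point else float(torch.iinfo(dtype).max)

img_int16 = torch.randint(0, 30000, (3, 4, 4), dtype=torch.int16)

# A hardcoded 255 silently crushes everything brighter than 255 in an int16 image ...
wrong = img_int16.clamp(0, 255)
# ... while the dtype-aware bound leaves the full value range intact.
right = img_int16.clamp(0, int(dtype_bound(img_int16.dtype)))

print(dtype_bound(torch.uint8), dtype_bound(torch.int16), dtype_bound(torch.float32))  # 255.0 32767.0 1.0
print(wrong.max().item() <= 255, right.max().item() > 255)  # True True (with overwhelming probability)
```

Every hunk below replaces some variant of `1.0 if floating_point else 255.0` with such a dtype-aware lookup.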
@@ -790,32 +790,40 @@ def test_solarize2(device, dtype, config, channels):
 )


+@pytest.mark.parametrize(
+    ("dtype", "threshold"),
+    [
+        *[
+            (dtype, threshold)
+            for dtype, threshold in itertools.product(
+                [torch.float32, torch.float16],
+                [0.0, 0.25, 0.5, 0.75, 1.0],
+            )
+        ],
+        *[(torch.uint8, threshold) for threshold in [0, 64, 128, 192, 255]],
+        *[(torch.int64, threshold) for threshold in [0, 2**32, 2**63 - 1]],
+    ],
+)
 @pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("threshold", [0.0, 0.25, 0.5, 0.75, 1.0])
-def test_solarize_threshold1_bound(threshold, device):
-    img = torch.rand((3, 12, 23)).to(device)
-    F_t.solarize(img, threshold)
-
-
-@pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("threshold", [1.5])
-def test_solarize_threshold1_upper_bound(threshold, device):
-    img = torch.rand((3, 12, 23)).to(device)
-    with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
-        F_t.solarize(img, threshold)
-
-
-@pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("threshold", [0, 64, 128, 192, 255])
-def test_solarize_threshold2_bound(threshold, device):
-    img = torch.randint(0, 256, (3, 12, 23)).to(device)
+def test_solarize_threshold_within_bound(threshold, dtype, device):
+    make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
+    img = make_img((3, 12, 23), dtype=dtype, device=device)
     F_t.solarize(img, threshold)


+@pytest.mark.parametrize(
+    ("dtype", "threshold"),
+    [
+        (torch.float32, 1.5),
+        (torch.float16, 1.5),
+        (torch.uint8, 260),
+        (torch.int64, 2**64),
+    ],
+)
 @pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("threshold", [260])
-def test_solarize_threshold2_upper_bound(threshold, device):
-    img = torch.randint(0, 256, (3, 12, 23)).to(device)
+def test_solarize_threshold_above_bound(threshold, dtype, device):
+    make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
+    img = make_img((3, 12, 23), dtype=dtype, device=device)
     with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
        F_t.solarize(img, threshold)
...
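The rewritten tests build one image per parametrized dtype through a small factory: `torch.rand` for floating point dtypes, `torch.randint` over the dtype's full range otherwise. A standalone sketch of that pattern (not the test file itself):

```python
import torch

def make_image(shape, dtype, device="cpu"):
    # Mirrors the tests' factory: torch.rand for floating point dtypes,
    # torch.randint over the full value range for integer dtypes.
    if dtype.is_floating_point:
        return torch.rand(shape, dtype=dtype, device=device)
    return torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype, device=device)

for dtype in [torch.float32, torch.float16, torch.uint8, torch.int64]:
    img = make_image((3, 12, 23), dtype)
    print(dtype, img.min().item(), img.max().item())
```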
@@ -8,6 +8,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
 from torchvision.prototype import features
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.prototype.transforms.functional._meta import get_spatial_size
+from torchvision.transforms import functional_tensor as _FT

 from ._utils import _isinstance, _setup_fill_arg
@@ -137,7 +138,7 @@ class _AutoAugmentBase(Transform):
         elif transform_id == "Posterize":
             return F.posterize(image, bits=int(magnitude))
         elif transform_id == "Solarize":
-            bound = 1.0 if isinstance(image, torch.Tensor) and image.is_floating_point() else 255.0
+            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
             return F.solarize(image, threshold=bound * magnitude)
         elif transform_id == "AutoContrast":
             return F.autocontrast(image)
...
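For the AutoAugment change: `magnitude` is a fraction of the image's value range, so the solarize threshold has to be scaled by a dtype-dependent bound instead of a fixed 255. Roughly (with `torch.iinfo` standing in for the private `_FT._max_value`):

```python
import torch

def solarize_threshold(dtype: torch.dtype, magnitude: float) -> float:
    # magnitude is in [0, 1]; scale it by the maximum value of the image dtype.
    bound = 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max
    return bound * magnitude

print(solarize_threshold(torch.float32, 0.5))  # 0.5
print(solarize_threshold(torch.uint8, 0.5))    # 127.5
print(solarize_threshold(torch.int16, 0.5))    # 16383.5
```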
@@ -2,13 +2,13 @@ import torch
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

-from ._meta import _rgb_to_gray, convert_dtype_image_tensor
+from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor


 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
-    bound = 1.0 if fp else 255.0
+    bound = _FT._max_value(image1.dtype)
     output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
     return output if fp else output.to(image1.dtype)
@@ -20,7 +20,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
     _FT._assert_channels(image, [1, 3])

     fp = image.is_floating_point()
-    bound = 1.0 if fp else 255.0
+    bound = _FT._max_value(image.dtype)
     output = image.mul(brightness_factor).clamp_(0, bound)
     return output if fp else output.to(image.dtype)
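Both `_blend` and `adjust_brightness_image_tensor` keep their logic and only swap the clamp bound for the dtype maximum. A simplified, self-contained version of the blend under that convention:

```python
import torch

def blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
    fp = image1.is_floating_point()
    # Upper bound of valid values: 1.0 for floats, the dtype maximum for integers.
    bound = 1.0 if fp else torch.iinfo(image1.dtype).max
    # Multiplying an integer tensor by a Python float promotes to float, so the math is safe.
    output = (image1 * ratio + image2 * (1.0 - ratio)).clamp_(0, bound)
    return output if fp else output.to(image1.dtype)

a = torch.full((1, 2, 2), 200, dtype=torch.uint8)
b = torch.full((1, 2, 2), 100, dtype=torch.uint8)
print(blend(a, b, 0.75))  # 0.75 * 200 + 0.25 * 100 = 175, still within the uint8 range
```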
@@ -222,8 +222,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
         return image

     orig_dtype = image.dtype
-    if image.dtype == torch.uint8:
-        image = image / 255.0
+    image = convert_dtype_image_tensor(image, torch.float32)

     image = _rgb_to_hsv(image)
     h, s, v = image.unbind(dim=-3)
@@ -231,10 +230,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
     image = torch.stack((h, s, v), dim=-3)
     image_hue_adj = _hsv_to_rgb(image)

-    if orig_dtype == torch.uint8:
-        image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)
-
-    return image_hue_adj
+    return convert_dtype_image_tensor(image_hue_adj, orig_dtype)


 adjust_hue_image_pil = _FP.adjust_hue
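The hue kernel no longer special-cases `torch.uint8`: it converts any input to `float32`, works in [0, 1], and converts back, which generalizes to the other integer dtypes. A quick round-trip check using the public `convert_image_dtype` (the prototype code uses its own `convert_dtype_image_tensor`):

```python
import torch
from torchvision.transforms.functional import convert_image_dtype

img = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)

orig_dtype = img.dtype
# Convert to float32 in [0, 1], do the (here: identity) hue manipulation, convert back.
float_img = convert_image_dtype(img, torch.float32)
out = convert_image_dtype(float_img, orig_dtype)

print(float_img.dtype, bool(float_img.max() <= 1.0))  # torch.float32 True
print(torch.equal(out, img))                          # expected True: the uint8 round trip preserves values
```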
@@ -289,14 +285,15 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->
 def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
-    if bits > 8:
-        return image
-
     if image.is_floating_point():
         levels = 1 << bits
         return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
     else:
-        mask = ((1 << bits) - 1) << (8 - bits)
+        num_value_bits = _num_value_bits(image.dtype)
+        if bits >= num_value_bits:
+            return image
+
+        mask = ((1 << bits) - 1) << (num_value_bits - bits)
         return image & mask
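The posterize mask now keeps the top `bits` *value* bits of the dtype instead of assuming 8: for `bits=3`, uint8 (8 value bits) gives mask `0b11100000 = 224`, while int16 (15 value bits, one bit being the sign) would give `0b111 << 12 = 28672`. A sketch with a hypothetical `num_value_bits` helper standing in for the private `_num_value_bits`:

```python
import torch

def num_value_bits(dtype: torch.dtype) -> int:
    # Hypothetical helper: number of bits that carry pixel values (the sign bit does not).
    info = torch.iinfo(dtype)
    return info.bits - (1 if info.min < 0 else 0)

def posterize_mask(dtype: torch.dtype, bits: int) -> int:
    value_bits = num_value_bits(dtype)
    return ((1 << bits) - 1) << (value_bits - bits)

print(bin(posterize_mask(torch.uint8, 3)))   # 0b11100000 (224)
print(posterize_mask(torch.int16, 3))        # 28672 == 0b111 << 12

img = torch.tensor([0, 37, 130, 255], dtype=torch.uint8)
print(img & posterize_mask(torch.uint8, 3))  # tensor([  0,  32, 128, 224], dtype=torch.uint8)
```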
@@ -317,8 +314,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
 def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
-    bound = 1 if image.is_floating_point() else 255
-    if threshold > bound:
+    if threshold > _FT._max_value(image.dtype):
         raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")

     return torch.where(image >= threshold, invert_image_tensor(image), image)
@@ -349,7 +345,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
         # exit earlier on empty images
         return image

-    bound = 1.0 if image.is_floating_point() else 255.0
+    bound = _FT._max_value(image.dtype)
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32

     minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype)
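Autocontrast stretches each image so its minimum maps to 0 and its maximum to the value bound; only that bound changed here. A stripped-down single-image version under the same convention (not the batched kernel):

```python
import torch

def autocontrast_2d(image: torch.Tensor) -> torch.Tensor:
    # Bound of the value range: 1.0 for floats, the dtype maximum for integers.
    bound = 1.0 if image.is_floating_point() else float(torch.iinfo(image.dtype).max)
    x = image.to(torch.float32)
    minimum, maximum = x.min(), x.max()
    if maximum == minimum:
        return image  # constant image: nothing to stretch
    scale = bound / (maximum - minimum)
    out = (x - minimum) * scale
    return out.clamp_(0, bound).to(image.dtype)

img = torch.tensor([[10, 20], [30, 40]], dtype=torch.uint8)
print(autocontrast_2d(img))  # tensor([[  0,  85], [170, 255]], dtype=torch.uint8)
```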
@@ -383,13 +379,17 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image

-    output_dtype = image.dtype
-    if image.is_floating_point():
-        # Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
-        # could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
-        # to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it
-        # slower and more complicated to implement than a simple conversion and a fast histogram implementation for
-        # integers.
-        image = convert_dtype_image_tensor(image, torch.uint8)
+    # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
+    #    would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
+    #    `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32`, and completely
+    #    unfeasible for `torch.int64`.
+    # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
+    #    could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
+    #    to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
+    #    and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
+    #    Since we need to convert in most cases anyway and, out of the acceptable dtypes mentioned in 1., `torch.uint8`
+    #    is by far the most common, we choose it as base.
+    output_dtype = image.dtype
+    image = convert_dtype_image_tensor(image, torch.uint8)

     # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
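The numbered comment added above is the reason equalize converts to `torch.uint8` first: the histogram has one bin per representable value, so larger integer dtypes would need impractically many bins. A tiny illustration of that binning step (using `torch.bincount`; the kernel builds its histogram differently, but the bin count is the point):

```python
import torch

img = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)

# One bin per representable value: 256 bins for uint8, which is why converting an
# int32/int64 image first matters (billions of bins would be needed otherwise).
hist = torch.bincount(img.flatten().to(torch.int64), minlength=256)
print(hist.shape)         # torch.Size([256])
print(hist.sum().item())  # 1024 == 32 * 32
```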
@@ -461,10 +461,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
 def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    if image.dtype == torch.uint8:
+    if image.is_floating_point():
+        return 1.0 - image  # type: ignore[no-any-return]
+    elif image.dtype == torch.uint8:
         return image.bitwise_not()
-    else:
-        return (1 if image.is_floating_point() else 255) - image  # type: ignore[no-any-return]
+    else:  # signed integer dtypes
+        # We can't use `Tensor.bitwise_not` here, since we want to retain the leading zero bit that encodes the sign
+        return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)


 invert_image_pil = _FP.invert
...
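The new signed-integer branch of `invert_image_tensor` XORs with a mask covering only the value bits, so the sign bit stays zero and every non-negative value maps to `max - value`. Checking that identity on a few `int16` values (with the same hypothetical `num_value_bits` helper as in the posterize sketch):

```python
import torch

def num_value_bits(dtype: torch.dtype) -> int:
    # Hypothetical helper: number of bits that carry pixel values (the sign bit does not).
    info = torch.iinfo(dtype)
    return info.bits - (1 if info.min < 0 else 0)

img = torch.tensor([0, 1, 1000, 32767], dtype=torch.int16)
mask = (1 << num_value_bits(torch.int16)) - 1  # 0x7FFF == 32767

inverted = img.bitwise_xor(mask)
print(inverted)                            # tensor([32767, 32766, 31767,     0], dtype=torch.int16)
print(torch.equal(inverted, 32767 - img))  # True: identical to subtracting from the dtype maximum
```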
@@ -15,12 +15,6 @@ def _assert_image_tensor(img: Tensor) -> None:
         raise TypeError("Tensor is not a torch image.")


-def _assert_threshold(img: Tensor, threshold: float) -> None:
-    bound = 1 if img.is_floating_point() else 255
-    if threshold > bound:
-        raise TypeError("Threshold should be less than bound of img.")
-
-
 def get_dimensions(img: Tensor) -> List[int]:
     _assert_image_tensor(img)
     channels = 1 if img.ndim == 2 else img.shape[-3]
@@ -56,6 +50,8 @@ def _max_value(dtype: torch.dtype) -> int:
     elif dtype == torch.int64:
         return 9223372036854775807
     else:
+        # This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
+        # easy.
         return 1
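For reference, `_max_value` hardcodes the per-dtype maxima (the hunk shows only the `int64` branch); an equivalent mapping, not the actual implementation, can be expressed with `torch.iinfo`, with the same fallback to 1 for floating point and other dtypes:

```python
import torch

def max_value(dtype: torch.dtype) -> int:
    # Sketch of the dtype-to-maximum mapping; the real helper hardcodes these numbers.
    try:
        return torch.iinfo(dtype).max
    except TypeError:  # floating point (and any other) dtypes fall back to 1
        return 1

for dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.float32]:
    print(dtype, max_value(dtype))
# torch.uint8 255, torch.int8 127, torch.int16 32767,
# torch.int32 2147483647, torch.int64 9223372036854775807, torch.float32 1
```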
@@ -212,8 +208,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
         return img

     orig_dtype = img.dtype
-    if img.dtype == torch.uint8:
-        img = img.to(dtype=torch.float32) / 255.0
+    img = convert_image_dtype(img, torch.float32)

     img = _rgb2hsv(img)
     h, s, v = img.unbind(dim=-3)
@@ -221,10 +216,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     img = torch.stack((h, s, v), dim=-3)
     img_hue_adj = _hsv2rgb(img)

-    if orig_dtype == torch.uint8:
-        img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)
-
-    return img_hue_adj
+    return convert_image_dtype(img_hue_adj, orig_dtype)


 def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
@@ -263,7 +255,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
 def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
     ratio = float(ratio)
-    bound = 1.0 if img1.is_floating_point() else 255.0
+    bound = _max_value(img1.dtype)
     return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
@@ -775,8 +767,7 @@ def invert(img: Tensor) -> Tensor:
     _assert_channels(img, [1, 3])

-    bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
-    return bound - img
+    return _max_value(img.dtype) - img


 def posterize(img: Tensor, bits: int) -> Tensor:
@@ -802,7 +793,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
     _assert_channels(img, [1, 3])

-    _assert_threshold(img, threshold)
+    if threshold > _max_value(img.dtype):
+        raise TypeError("Threshold should be less than bound of img.")

     inverted_img = invert(img)
     return torch.where(img >= threshold, inverted_img, img)
@@ -849,7 +841,7 @@ def autocontrast(img: Tensor) -> Tensor:
     _assert_channels(img, [1, 3])

-    bound = 1.0 if img.is_floating_point() else 255.0
+    bound = _max_value(img.dtype)
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32

     minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
...