Unverified commit b522aca2, authored by Philip Meier, committed by GitHub

extend equalize to all integer and floating dtypes (#6851)

* extend equalize to all integer and floating dtypes

* address nits
parent 52b80c48
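
At the API level, the change is simple: `equalize_image_tensor` previously raised a `TypeError` for anything but `torch.uint8`, and now accepts all integer and floating point image dtypes by converting to `torch.uint8` internally and back on the way out. A minimal usage sketch (the prototype import path is an assumption based on the functions touched below):

```python
import torch

# Hypothetical import path; at the time of this commit the kernel lives in the
# prototype namespace, which may differ in your torchvision version.
from torchvision.prototype.transforms.functional import equalize_image_tensor

image_u8 = torch.randint(64, 192, (3, 256, 256), dtype=torch.uint8)
image_f32 = torch.rand(3, 256, 256)  # float32 values in [0, 1]

# Before this commit, the float32 call raised:
#   TypeError: Only torch.uint8 image tensors are supported, but found torch.float32
# Now both calls succeed, and the output dtype matches the input dtype.
assert equalize_image_tensor(image_u8).dtype == torch.uint8
assert equalize_image_tensor(image_f32).dtype == torch.float32
```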
@@ -1322,7 +1322,7 @@ KERNEL_INFOS.extend(
 def sample_inputs_equalize_image_tensor():
     for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
+        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
     ):
         yield ArgsKwargs(image_loader)
@@ -1331,27 +1331,41 @@ def reference_inputs_equalize_image_tensor():
     # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
     # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
     # the information gain is low if we already provide something really close to the expected value.
+    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+        if dtype.is_floating_point:
+            low = low_factor
+            high = high_factor
+        else:
+            max_value = torch.iinfo(dtype).max
+            low = int(low_factor * max_value)
+            high = int(high_factor * max_value)
+        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+
+    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+        image = torch.distributions.Beta(alpha, beta).sample(shape)
+        if not dtype.is_floating_point:
+            image.mul_(torch.iinfo(dtype).max).round_()
+        return image.to(dtype=dtype, device=device)
+
     spatial_size = (256, 256)
-    for fn, color_space in itertools.product(
+    for dtype, color_space, fn in itertools.product(
+        [torch.uint8, torch.float32],
+        [features.ColorSpace.GRAY, features.ColorSpace.RGB],
         [
+            lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
+            lambda shape, dtype, device: torch.full(
+                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            ),
             *[
-                lambda shape, dtype, device, low=low, high=high: torch.randint(
-                    low, high, shape, dtype=dtype, device=device
-                )
-                for low, high in [
-                    (0, 1),
-                    (255, 256),
-                    (0, 64),
-                    (64, 192),
-                    (192, 256),
+                functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
+                for low_factor, high_factor in [
+                    (0.0, 0.25),
+                    (0.25, 0.75),
+                    (0.75, 1.0),
                 ]
             ],
             *[
-                lambda shape, dtype, device, alpha=alpha, beta=beta: torch.distributions.Beta(alpha, beta)
-                .sample(shape)
-                .mul_(255)
-                .round_()
-                .to(dtype=dtype, device=device)
+                functools.partial(make_beta_distributed_image, alpha=alpha, beta=beta)
                 for alpha, beta in [
                     (0.5, 0.5),
                     (2, 2),
@@ -1360,10 +1374,9 @@ def reference_inputs_equalize_image_tensor():
                 ]
             ],
         ],
-        [features.ColorSpace.GRAY, features.ColorSpace.RGB],
     ):
         image_loader = ImageLoader(
-            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=torch.uint8, color_space=color_space
+            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype, color_space=color_space
         )
         yield ArgsKwargs(image_loader)
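
The two new helpers generate reference inputs with non-uniform value distributions for any dtype. A standalone copy to illustrate what they produce (not imported from the test suite):

```python
import torch

# Copied from the diff above for illustration; `low_factor`/`high_factor` are
# fractions of the dtype's value range.
def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
    if dtype.is_floating_point:
        low, high = low_factor, high_factor
    else:
        max_value = torch.iinfo(dtype).max
        low, high = int(low_factor * max_value), int(high_factor * max_value)
    return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)

# For uint8, the (0.25, 0.75) band maps to integer values in [63, 191);
# for float32, it stays in [0.25, 0.75).
u8 = make_uniform_band_image((3, 4, 4), torch.uint8, "cpu", low_factor=0.25, high_factor=0.75)
f32 = make_uniform_band_image((3, 4, 4), torch.float32, "cpu", low_factor=0.25, high_factor=0.75)
```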
@@ -371,26 +371,26 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    if image.dtype != torch.uint8:
-        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
-
-    num_channels, height, width = get_dimensions_image_tensor(image)
-    if num_channels not in (1, 3):
-        raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
-
     if image.numel() == 0:
         return image
 
+    # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
+    #    would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
+    #    `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32`, and completely
+    #    unfeasible for `torch.int64`.
+    # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
+    #    could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
+    #    to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
+    #    and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
+    # Since we need to convert in most cases anyway and, out of the acceptable dtypes mentioned in 1., `torch.uint8` is
+    # by far the most common, we choose it as the base.
+    output_dtype = image.dtype
+    image = convert_dtype_image_tensor(image, torch.uint8)
+
     batch_shape = image.shape[:-2]
     flat_image = image.flatten(start_dim=-2).to(torch.long)
 
-    # The algorithm for histogram equalization is mirrored from PIL:
-    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
-    # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with uint8
-    # images here and thus the values are already binned, the computation is trivial. The histogram is computed by using
-    # the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding 1 to index 127
-    # in the histogram.
+    # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
+    # corresponds to adding 1 to index 127 in the histogram.
     hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
     hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
     cum_hist = hist.cumsum(dim=-1)
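
For readers unfamiliar with the indexing trick described in that comment, here is a self-contained sketch of the batched histogram (same technique as the kernel, with illustrative shapes):

```python
import torch

image = torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8)  # batch of 2 RGB images
flat_image = image.flatten(start_dim=-2).to(torch.long)         # shape (2, 3, 64)

# One 256-bin histogram per leading (batch, channel) index: every pixel value
# is used as the index whose bin gets incremented by 1.
hist = flat_image.new_zeros(flat_image.shape[:-1] + (256,), dtype=torch.int32)
hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))

# Sanity check: each histogram's counts sum to the number of pixels.
assert (hist.sum(dim=-1) == flat_image.shape[-1]).all()
```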
@@ -398,6 +398,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     # The simplest form of lookup-table (LUT) that also achieves histogram equalization is
     # `lut = cum_hist / flat_image.shape[-1] * 255`
     # However, PIL uses a more elaborate scheme:
+    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
     # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`
 
     # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
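
A hedged numeric sketch of the two LUT schemes quoted in that comment, for a single unbatched histogram. The `step`-based rounding below is a paraphrase of PIL's integer arithmetic, which the quoted formula expresses with real division; names are illustrative, not the kernel's:

```python
import torch

pixels = torch.randint(0, 256, (10_000,))
hist = torch.bincount(pixels, minlength=256)
cum_hist = hist.cumsum(dim=-1)

num_pixels = int(cum_hist[-1])
simple_lut = cum_hist * 255 // num_pixels  # naive LUT, reordered to stay integral

# PIL drops the pixels holding the last non-zero (i.e. maximum occurring) value.
num_non_max_pixels = num_pixels - int(hist[int(hist.nonzero()[-1])])
step = num_non_max_pixels // 255  # assumed non-zero here; the kernel masks step == 0
pil_lut = ((cum_hist + step // 2) // step).clamp_(0, 255)
```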
@@ -415,7 +416,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't,
     # we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to
     # pay the runtime cost for checking it every time.
-    no_equalization = step.eq(0).unsqueeze_(-1)
+    valid_equalization = step.ne(0).unsqueeze_(-1)
 
     # `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
     # computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards.
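
Because the kernel is batched, the `step == 0` edge case (e.g. constant images) is resolved per image with a mask rather than an early return. A small sketch of the masking, with stand-in values:

```python
import torch

step = torch.tensor([[0], [37]])  # per-image step, already unsqueezed for broadcasting
valid_equalization = step.ne(0)   # bool mask: which images actually get equalized

image = torch.randint(0, 256, (2, 64))
equalized_image = torch.randint(0, 256, (2, 64))  # stand-in for the LUT result

output = torch.where(valid_equalization, equalized_image, image)
assert (output[0] == image[0]).all()            # step == 0: input passed through
assert (output[1] == equalized_image[1]).all()  # step != 0: equalized
```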
@@ -434,7 +435,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)
     equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
 
-    return torch.where(no_equalization, image, equalized_image)
+    output = torch.where(valid_equalization, equalized_image, image)
+    return convert_dtype_image_tensor(output, output_dtype)
 
 
 equalize_image_pil = _FP.equalize
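
Finally, the `gather` call above applies each image's own LUT to its own pixels before the result is converted back to the input dtype. A standalone sketch of that remapping step:

```python
import torch

flat_image = torch.randint(0, 256, (2, 3, 64))  # batched, flattened pixel values
lut = torch.randint(0, 256, (2, 3, 256))        # one 256-entry LUT per (image, channel)

remapped = lut.gather(dim=-1, index=flat_image)

# Equivalent to routing each flattened image through its own LUT:
assert (remapped[0, 0] == lut[0, 0][flat_image[0, 0]]).all()
```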