Unverified Commit 900982fc authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

extend support of posterize to all integer and floating dtypes (#6847)

* extend support of posterize to all integer and floating dtypes

* remove raise

* revert to fixed value range for integer dtypes
parent 6af796ab
......@@ -1446,16 +1446,14 @@ _POSTERIZE_BITS = [1, 4, 8]
def sample_inputs_posterize_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])
def reference_inputs_posterize_image_tensor():
for image_loader, bits in itertools.product(
make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
_POSTERIZE_BITS,
):
yield ArgsKwargs(image_loader, bits=bits)
......
......@@ -289,7 +289,18 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
posterize_image_tensor = _FT.posterize
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
if bits > 8:
return image
if image.is_floating_point():
levels = 1 << bits
return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
else:
mask = ((1 << bits) - 1) << (8 - bits)
return image & mask
posterize_image_pil = _FP.posterize
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment