Unverified Commit fd4a6939 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Add RandomEqualize prototype transforms (#5807)


Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 53c85328
......@@ -4,7 +4,7 @@ from ._transform import Transform # usort: skip
from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._color import ColorJitter, RandomPhotometricDistort
from ._color import ColorJitter, RandomPhotometricDistort, RandomEqualize
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import (
Resize,
......
......@@ -8,6 +8,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from torchvision.transforms import functional as _F
from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor, get_image_dimensions, query_image
T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
......@@ -188,3 +189,19 @@ class RandomPhotometricDistort(Transform):
if params["channel_shuffle"]:
input = self._channel_shuffle(input)
return input
class RandomEqualize(_RandomApplyTransform):
def __init__(self, p: float = 0.5):
super().__init__(p=p)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.equalize_image_tensor(input)
return features.Image.new_like(input, output)
elif is_simple_tensor(input):
return F.equalize_image_tensor(input)
elif isinstance(input, PIL.Image.Image):
return F.equalize_image_pil(input)
else:
return input
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment