Unverified Commit 95d41897 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add prototype AugMix transform (#5492)

* add prototype AugMix transform

* cleanup

* refactor auto augment subclasses to only trnasform a single image

* address review comments
parent 7767f120
...@@ -114,6 +114,7 @@ class TestSmoke: ...@@ -114,6 +114,7 @@ class TestSmoke:
transforms.RandAugment(), transforms.RandAugment(),
transforms.TrivialAugmentWide(), transforms.TrivialAugmentWide(),
transforms.AutoAugment(), transforms.AutoAugment(),
transforms.AugMix(),
) )
] ]
) )
......
...@@ -5,7 +5,7 @@ from . import functional # usort: skip ...@@ -5,7 +5,7 @@ from . import functional # usort: skip
from ._transform import Transform # usort: skip from ._transform import Transform # usort: skip
from ._augment import RandomErasing, RandomMixup, RandomCutmix from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
......
...@@ -9,14 +9,16 @@ from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_ ...@@ -9,14 +9,16 @@ from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_
def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]: def fn(
id: Tuple[Any, ...], input: Any
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
return input return id, input
return None return None
try: try:
return next(query_recursively(fn, sample)) return next(query_recursively(fn, sample))[1]
except StopIteration: except StopIteration:
raise TypeError("No image was found in the sample") raise TypeError("No image was found in the sample")
......
...@@ -312,15 +312,18 @@ def apply_recursively(fn: Callable, obj: Any) -> Any: ...@@ -312,15 +312,18 @@ def apply_recursively(fn: Callable, obj: Any) -> Any:
return fn(obj) return fn(obj)
def query_recursively(fn: Callable[[Any], Optional[D]], obj: Any) -> Iterator[D]: def query_recursively(
fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
) -> Iterator[D]:
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
# "a" == "a"[0][0]... # "a" == "a"[0][0]...
if (isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str)) or isinstance( if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
obj, collections.abc.Mapping for idx, item in enumerate(obj):
): yield from query_recursively(fn, item, id=(*id, idx))
for item in obj.values() if isinstance(obj, collections.abc.Mapping) else obj: elif isinstance(obj, collections.abc.Mapping):
yield from query_recursively(fn, item) for key, item in obj.items():
yield from query_recursively(fn, item, id=(*id, key))
else: else:
result = fn(obj) result = fn(id, obj)
if result is not None: if result is not None:
yield result yield result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment