Unverified Commit 95d41897 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add prototype AugMix transform (#5492)

* add prototype AugMix transform

* cleanup

* refactor auto augment subclasses to only trnasform a single image

* address review comments
parent 7767f120
......@@ -114,6 +114,7 @@ class TestSmoke:
transforms.RandAugment(),
transforms.TrivialAugmentWide(),
transforms.AutoAugment(),
transforms.AugMix(),
)
]
)
......
......@@ -5,7 +5,7 @@ from . import functional # usort: skip
from ._transform import Transform # usort: skip
from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
......
......@@ -9,14 +9,16 @@ from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_
def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
def fn(
id: Tuple[Any, ...], input: Any
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
return input
return id, input
return None
try:
return next(query_recursively(fn, sample))
return next(query_recursively(fn, sample))[1]
except StopIteration:
raise TypeError("No image was found in the sample")
......
......@@ -312,15 +312,18 @@ def apply_recursively(fn: Callable, obj: Any) -> Any:
return fn(obj)
def query_recursively(fn: Callable[[Any], Optional[D]], obj: Any) -> Iterator[D]:
def query_recursively(
fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
) -> Iterator[D]:
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
# "a" == "a"[0][0]...
if (isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str)) or isinstance(
obj, collections.abc.Mapping
):
for item in obj.values() if isinstance(obj, collections.abc.Mapping) else obj:
yield from query_recursively(fn, item)
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
for idx, item in enumerate(obj):
yield from query_recursively(fn, item, id=(*id, idx))
elif isinstance(obj, collections.abc.Mapping):
for key, item in obj.items():
yield from query_recursively(fn, item, id=(*id, key))
else:
result = fn(obj)
result = fn(id, obj)
if result is not None:
yield result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment