Unverified Commit 52e6bd08 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add prototype transforms that use the prototype dispatchers (#5418)

* add prototype transforms that use the prototype dispatchers

Conflicts:
	torchvision/prototype/transforms/__init__.py

* simplify

* add logger

* remove legacy classes

Conflicts:
	torchvision/prototype/transforms/_augment.py
	torchvision/prototype/transforms/_auto_augment.py
	torchvision/prototype/transforms/_geometry.py

* make get_params private

* remove randbool method

* remove AutoAugmentDispatcher

* add high level kernels for meta conversion

* remove transforms meta abstraction from auto augment transforms

* appease mypy

* add smoke tests for transforms

* remove Query object

* remove extra_repr helper

* fix tests

* appease mypy

* revert some changes on the kernel tests

* fix dispatcher annotations

* remove float cast for torch.rand

* add helper to query image

* fix imports

* address auto augment comments

* cleanup
parent 144f0980
...@@ -24,8 +24,7 @@ from typing import ( ...@@ -24,8 +24,7 @@ from typing import (
Tuple, Tuple,
TypeVar, TypeVar,
Union, Union,
List, Optional,
Dict,
) )
import numpy as np import numpy as np
...@@ -42,6 +41,7 @@ __all__ = [ ...@@ -42,6 +41,7 @@ __all__ = [
"fromfile", "fromfile",
"ReadOnlyTensorBuffer", "ReadOnlyTensorBuffer",
"apply_recursively", "apply_recursively",
"query_recursively",
] ]
...@@ -305,22 +305,22 @@ def apply_recursively(fn: Callable, obj: Any) -> Any: ...@@ -305,22 +305,22 @@ def apply_recursively(fn: Callable, obj: Any) -> Any:
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
# "a" == "a"[0][0]... # "a" == "a"[0][0]...
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str): if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
sequence: List[Any] = [] return [apply_recursively(fn, item) for item in obj]
for item in obj:
result = apply_recursively(fn, item)
if isinstance(result, collections.abc.Sequence) and hasattr(result, "__inline__"):
sequence.extend(result)
else:
sequence.append(result)
return sequence
elif isinstance(obj, collections.abc.Mapping): elif isinstance(obj, collections.abc.Mapping):
mapping: Dict[Any, Any] = {} return {key: apply_recursively(fn, item) for key, item in obj.items()}
for name, item in obj.items():
result = apply_recursively(fn, item)
if isinstance(result, collections.abc.Mapping) and hasattr(result, "__inline__"):
mapping.update(result)
else:
mapping[name] = result
return mapping
else: else:
return fn(obj) return fn(obj)
def query_recursively(fn: Callable[[Any], Optional[D]], obj: Any) -> Iterator[D]:
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
# "a" == "a"[0][0]...
if (isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str)) or isinstance(
obj, collections.abc.Mapping
):
for item in obj.values() if isinstance(obj, collections.abc.Mapping) else obj:
yield from query_recursively(fn, item)
else:
result = fn(obj)
if result is not None:
yield result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment