Unverified Commit 205602af authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix v2.Lambda (#7566)

parent e5a1b71d
......@@ -2137,3 +2137,38 @@ def test_no_warnings_v1_namespace():
from torchvision.datasets import ImageNet
"""
assert_run_python_script(textwrap.dedent(source))
class TestLambda:
inputs = pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
@inputs
def test_default(self, input):
was_applied = False
def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input
transform = transforms.Lambda(was_applied_fn)
transform(input)
assert was_applied
@inputs
def test_with_types(self, input):
was_applied = False
def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input
types = (torch.Tensor, np.ndarray)
transform = transforms.Lambda(was_applied_fn, *types)
transform(input)
assert was_applied is isinstance(input, types)
......@@ -32,10 +32,12 @@ class Lambda(Transform):
lambd (function): Lambda/function to be used for transform.
"""
_transformed_types = (object,)
def __init__(self, lambd: Callable[[Any], Any], *types: Type):
super().__init__()
self.lambd = lambd
self.types = types or (object,)
self.types = types or self._transformed_types
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, self.types):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment