"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6f6eef747ce2943d2b76d845b55220fd07fcf177"
Unverified Commit 205602af authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix v2.Lambda (#7566)

parent e5a1b71d
...@@ -2137,3 +2137,38 @@ def test_no_warnings_v1_namespace(): ...@@ -2137,3 +2137,38 @@ def test_no_warnings_v1_namespace():
from torchvision.datasets import ImageNet from torchvision.datasets import ImageNet
""" """
assert_run_python_script(textwrap.dedent(source)) assert_run_python_script(textwrap.dedent(source))
class TestLambda:
inputs = pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
@inputs
def test_default(self, input):
was_applied = False
def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input
transform = transforms.Lambda(was_applied_fn)
transform(input)
assert was_applied
@inputs
def test_with_types(self, input):
was_applied = False
def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input
types = (torch.Tensor, np.ndarray)
transform = transforms.Lambda(was_applied_fn, *types)
transform(input)
assert was_applied is isinstance(input, types)
...@@ -32,10 +32,12 @@ class Lambda(Transform): ...@@ -32,10 +32,12 @@ class Lambda(Transform):
lambd (function): Lambda/function to be used for transform. lambd (function): Lambda/function to be used for transform.
""" """
_transformed_types = (object,)
def __init__(self, lambd: Callable[[Any], Any], *types: Type): def __init__(self, lambd: Callable[[Any], Any], *types: Type):
super().__init__() super().__init__()
self.lambd = lambd self.lambd = lambd
self.types = types or (object,) self.types = types or self._transformed_types
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, self.types): if isinstance(inpt, self.types):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment