Unverified Commit 9d15f2e6 authored by Philip Meier, committed by GitHub

fix prototype transforms for non features (#4942)


Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent a96de03d
@@ -38,7 +38,7 @@ def test_feature_type_support():
     [transform_type for transform_type in TRANSFORM_TYPES if transform_type is not transforms.Identity],
     ids=lambda transform_type: transform_type.__name__,
 )
-def test_no_op(transform_type):
+def test_feature_no_op_coverage(transform_type):
     unsupported_features = (
         FEATURE_TYPES - transform_type.supported_feature_types() - set(transform_type.NO_OP_FEATURE_TYPES)
     )
@@ -49,3 +49,13 @@ def test_no_op(transform_type):
             f"no-op for transform `{transform_type.__name__}`. Please either implement a feature transform for them, "
             f"or add them to the the `{transform_type.__name__}.NO_OP_FEATURE_TYPES` collection."
         )
+
+
+def test_non_feature_no_op():
+    class TestTransform(transforms.Transform):
+        @staticmethod
+        def image(input):
+            return input
+
+    no_op_sample = dict(int=0, float=0.0, bool=False, str="str")
+    assert TestTransform()(no_op_sample) == no_op_sample
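For context, a minimal sketch of the behavior the new test pins down, assuming the prototype transforms subclassing pattern exercised above; PassThroughImage and the sample contents are illustrative only, not part of the patch:

from torchvision.prototype import transforms  # prototype namespace, as assumed by the test above


class PassThroughImage(transforms.Transform):
    # Mirrors TestTransform above: only an image feature transform is registered.
    @staticmethod
    def image(input):
        return input


# Plain builtins (and containers of them) pass through: the transform recurses into
# the dict/list structure and returns non-feature leaves unchanged.
sample = dict(label=3, scores=[0.1, 0.9], name="img_0001")
assert PassThroughImage()(sample) == sample

Before this patch the same call raised a TypeError, because int, float, and str are neither supported feature types nor listed in NO_OP_FEATURE_TYPES.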
@@ -351,14 +351,16 @@ class Transform(nn.Module):
             sample: Sample.
             params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``.
         """
-        if isinstance(sample, collections.abc.Sequence):
+        # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
+        # "a" == "a"[0][0]...
+        if isinstance(sample, collections.abc.Sequence) and not isinstance(sample, str):
             return [self._transform_recursively(item, params=params) for item in sample]
         elif isinstance(sample, collections.abc.Mapping):
             return {name: self._transform_recursively(item, params=params) for name, item in sample.items()}
         else:
             feature_type = type(sample)
             if not self.supports(feature_type):
-                if feature_type in self.NO_OP_FEATURE_TYPES:
+                if not issubclass(feature_type, features.Feature) or feature_type in self.NO_OP_FEATURE_TYPES:
                     return sample

                 raise TypeError(
@@ -366,7 +368,7 @@ class Transform(nn.Module):
                     f"If you want it to be a no-op, add the feature type to {type(self).__name__}.NO_OP_FEATURE_TYPES."
                 )

-            return self.transform(sample, **params)
+            return self.transform(cast(Union[torch.Tensor, features.Feature], sample), **params)

     def get_params(self, sample: Any) -> Dict[str, Any]:
         """Returns the parameter dictionary used to transform the current sample.
...
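The new comment about strings being self-referential can be demonstrated in isolation. Below is a small, self-contained sketch in plain Python (no torchvision required) of why the Sequence branch needs the str guard; the walk helper is hypothetical and only mirrors the condition added above:

import collections.abc

# A str is itself a Sequence whose items are again (length-1) strs, so naive
# Sequence recursion never reaches a scalar base case: "a" == "a"[0] == "a"[0][0] ...
assert isinstance("a", collections.abc.Sequence)
assert "a"[0] == "a"


def walk(sample):
    # Mirrors the guard in _transform_recursively: treat strs as leaves, not containers.
    if isinstance(sample, collections.abc.Sequence) and not isinstance(sample, str):
        return [walk(item) for item in sample]
    return sample


print(walk(["a", [1, 2.0], "str"]))  # ['a', [1, 2.0], 'str'] -- terminates

Dropping the not isinstance(sample, str) check makes walk("a") recurse until it hits a RecursionError, which is exactly the failure mode the added guard avoids.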