"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "c9b6c2df711d9e75607e7da55ca560d6a33f2c9f"
Unverified commit e1b21f9c authored by Philip Meier, committed by GitHub

introduce _check method for type checks on prototype transforms (#6503)

* introduce _check method for type checks on prototype transforms

* cleanup

* Update torchvision/prototype/transforms/_geometry.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* introduce _check on new transforms

* _check -> _check_inputs

* always check inputs in _RandomApplyTransform
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent bdc55567
@@ -107,17 +107,16 @@ class _BaseMixupCutmix(_RandomApplyTransform):
         self.alpha = alpha
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
 
-    def forward(self, *inputs: Any) -> Any:
+    def _check_inputs(self, sample: Any) -> None:
         if not (
-            has_any(inputs, features.Image, features.Video, features.is_simple_tensor)
-            and has_any(inputs, features.OneHotLabel)
+            has_any(sample, features.Image, features.Video, features.is_simple_tensor)
+            and has_any(sample, features.OneHotLabel)
         ):
             raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
-        if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
+        if has_any(sample, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
             )
-        return super().forward(*inputs)
 
     def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
         if inpt.ndim < 2:
...
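Throughout the diff, `has_any` and `has_all` are the prototype transforms' type-query helpers: they flatten the (possibly nested) sample and test whether any, respectively every, given check matches some leaf, where a check may be a class or a predicate callable such as `features.is_simple_tensor`. A minimal standalone sketch of that idea, not torchvision's actual implementation:

```python
from typing import Any, Callable, Type, Union

from torch.utils._pytree import tree_flatten

Check = Union[Type, Callable[[Any], bool]]


def _matches(leaf: Any, check: Check) -> bool:
    # A check is either a type to isinstance-test or a predicate to call.
    return isinstance(leaf, check) if isinstance(check, type) else bool(check(leaf))


def has_any(sample: Any, *checks: Check) -> bool:
    # True if at least one leaf of the sample matches at least one check.
    leaves, _ = tree_flatten(sample)
    return any(_matches(leaf, check) for leaf in leaves for check in checks)


def has_all(sample: Any, *checks: Check) -> bool:
    # True if every check is matched by at least one leaf of the sample.
    leaves, _ = tree_flatten(sample)
    return all(any(_matches(leaf, check) for leaf in leaves) for check in checks)
```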
@@ -184,10 +184,9 @@ class FiveCrop(Transform):
     ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
         return F.five_crop(inpt, self.size)
 
-    def forward(self, *inputs: Any) -> Any:
-        if has_any(inputs, features.BoundingBox, features.Mask):
+    def _check_inputs(self, sample: Any) -> None:
+        if has_any(sample, features.BoundingBox, features.Mask):
             raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
-        return super().forward(*inputs)
 
 
 class TenCrop(Transform):
@@ -202,16 +201,15 @@ class TenCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
 
+    def _check_inputs(self, sample: Any) -> None:
+        if has_any(sample, features.BoundingBox, features.Mask):
+            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
+
     def _transform(
         self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
    ) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
 
-    def forward(self, *inputs: Any) -> Any:
-        if has_any(inputs, features.BoundingBox, features.Mask):
-            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
-        return super().forward(*inputs)
-
 
 class Pad(Transform):
     def __init__(
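The payoff for FiveCrop and TenCrop: unsupported samples are now rejected as soon as forward runs, instead of failing deep inside a kernel. A hypothetical call against the prototype API of the time; the feature-class constructor arguments below are assumptions, since that API changed frequently:

```python
import torch
from torchvision.prototype import features, transforms

five_crop = transforms.FiveCrop(224)

# A plain tensor image passes the check and yields five crops ...
crops = five_crop(torch.rand(3, 256, 256))

# ... but a sample that also carries a BoundingBox or Mask is rejected up front:
sample = dict(
    image=torch.rand(3, 256, 256),
    box=features.BoundingBox([0, 0, 10, 10], format="XYXY", spatial_size=(256, 256)),
)
five_crop(sample)  # TypeError: BoundingBox'es and Mask's are not supported by FiveCrop()
```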
@@ -616,6 +614,17 @@ class RandomIoUCrop(Transform):
         self.options = sampler_options
         self.trials = trials
 
+    def _check_inputs(self, sample: Any) -> None:
+        if not (
+            has_all(sample, features.BoundingBox)
+            and has_any(sample, PIL.Image.Image, features.Image, features.is_simple_tensor)
+            and has_any(sample, features.Label, features.OneHotLabel)
+        ):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
+                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks."
+            )
+
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         orig_h, orig_w = query_spatial_size(sample)
         bboxes = query_bounding_box(sample)
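`_check_inputs` spells out RandomIoUCrop's data contract: every sample must carry bounding boxes, at least one image-like entry, and at least one label-like entry; masks are allowed but optional. A sample shape that would satisfy the check, with assumed constructor arguments (the prototype feature API was in flux):

```python
import torch
from torchvision.prototype import features

sample = dict(
    image=torch.rand(3, 256, 256),  # a simple tensor counts as an image here
    boxes=features.BoundingBox(
        [[10, 10, 50, 50]], format="XYXY", spatial_size=(256, 256)
    ),
    label=features.Label([1]),
    # optional: mask=features.Mask(torch.zeros(1, 256, 256, dtype=torch.bool)),
)
```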
@@ -688,18 +697,6 @@
         return output
 
-    def forward(self, *inputs: Any) -> Any:
-        if not (
-            has_all(inputs, features.BoundingBox)
-            and has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor)
-            and has_any(inputs, features.Label, features.OneHotLabel)
-        ):
-            raise TypeError(
-                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
-                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks."
-            )
-        return super().forward(*inputs)
-
 
 class ScaleJitter(Transform):
     def __init__(
@@ -774,6 +771,18 @@ class FixedSizeCrop(Transform):
         self.padding_mode = padding_mode
 
+    def _check_inputs(self, sample: Any) -> None:
+        if not has_any(sample, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
+            )
+
+        if has_any(sample, features.BoundingBox) and not has_any(sample, features.Label, features.OneHotLabel):
+            raise TypeError(
+                f"If a BoundingBox is contained in the input sample, "
+                f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
+            )
+
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         height, width = query_spatial_size(sample)
         new_height = min(height, self.crop_height)
@@ -850,20 +859,6 @@
         return inpt
 
-    def forward(self, *inputs: Any) -> Any:
-        if not has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
-            raise TypeError(
-                f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
-            )
-
-        if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel):
-            raise TypeError(
-                f"If a BoundingBox is contained in the input sample, "
-                f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
-            )
-        return super().forward(*inputs)
-
 
 class RandomResize(Transform):
     def __init__(
...
@@ -63,12 +63,10 @@ class LinearTransformation(Transform):
         self.transformation_matrix = transformation_matrix
         self.mean_vector = mean_vector
 
-    def forward(self, *inputs: Any) -> Any:
-        if has_any(inputs, PIL.Image.Image):
+    def _check_inputs(self, sample: Any) -> Any:
+        if has_any(sample, PIL.Image.Image):
             raise TypeError("LinearTransformation does not work on PIL Images")
-        return super().forward(*inputs)
 
     def _transform(
         self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
     ) -> torch.Tensor:
@@ -104,16 +102,15 @@ class Normalize(Transform):
         self.std = list(std)
         self.inplace = inplace
 
+    def _check_inputs(self, sample: Any) -> Any:
+        if has_any(sample, PIL.Image.Image):
+            raise TypeError(f"{type(self).__name__}() does not support PIL images.")
+
     def _transform(
         self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
     ) -> torch.Tensor:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
 
-    def forward(self, *inpts: Any) -> Any:
-        if has_any(inpts, PIL.Image.Image):
-            raise TypeError(f"{type(self).__name__}() does not support PIL images.")
-        return super().forward(*inpts)
-
 
 class GaussianBlur(Transform):
     def __init__(
...
@@ -23,6 +23,9 @@ class Transform(nn.Module):
         super().__init__()
         _log_api_usage_once(self)
 
+    def _check_inputs(self, sample: Any) -> None:
+        pass
+
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict()
@@ -32,6 +35,8 @@ class Transform(nn.Module):
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
 
+        self._check_inputs(sample)
+
         params = self._get_params(sample)
 
         flat_inputs, spec = tree_flatten(sample)
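With these two hunks, `Transform.forward` becomes a template method: validate via the `_check_inputs` hook (a no-op by default), sample params once, then map `_transform` over the flattened sample. A self-contained sketch of that control flow, independent of torchvision (class and attribute names below are illustrative, not the library's):

```python
from typing import Any, Dict

import torch.nn as nn
from torch.utils._pytree import tree_flatten, tree_unflatten


class SketchTransform(nn.Module):
    # Types of leaves that _transform should touch; everything else passes through.
    _transformed_types = (object,)

    def _check_inputs(self, sample: Any) -> None:
        pass  # subclasses raise TypeError on unsupported samples

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        return dict()

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        self._check_inputs(sample)          # fail fast, before any work is done
        params = self._get_params(sample)   # random params sampled once per call
        flat, spec = tree_flatten(sample)
        return tree_unflatten(
            [self._transform(i, params) if isinstance(i, self._transformed_types) else i for i in flat],
            spec,
        )
```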
@@ -64,9 +69,22 @@ class _RandomApplyTransform(Transform):
         self.p = p
 
     def forward(self, *inputs: Any) -> Any:
+        # We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return
+        # early afterwards in case the random check triggers. The same result could be achieved by calling
+        # `super().forward()` after the random check, but that would call `self._check_inputs` twice.
         sample = inputs if len(inputs) > 1 else inputs[0]
 
+        self._check_inputs(sample)
+
         if torch.rand(1) >= self.p:
             return sample
 
-        return super().forward(sample)
+        params = self._get_params(sample)
+
+        flat_inputs, spec = tree_flatten(sample)
+        flat_outputs = [
+            self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
+            for inpt in flat_inputs
+        ]
+
+        return tree_unflatten(flat_outputs, spec)
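The behavioral consequence of this last hunk: validation is unconditional, and only the transformation itself is subject to the random draw. A toy standalone reproduction of the control flow (not torchvision code) that makes the point:

```python
import torch
import torch.nn as nn


class NeverApplied(nn.Module):
    """Mimics _RandomApplyTransform with p = 0: the transform never fires,
    but _check_inputs still runs on every call."""

    def _check_inputs(self, sample):
        if not isinstance(sample, torch.Tensor):
            raise TypeError("only tensors are supported")

    def forward(self, sample):
        self._check_inputs(sample)   # unconditional, as in this PR
        if torch.rand(1) >= 0.0:     # p = 0.0 -> random check always skips
            return sample
        return sample * 2            # never reached


t = NeverApplied()
t(torch.rand(3, 4, 4))   # fine: sample returned unchanged
t("not a tensor")        # raises TypeError despite p = 0.0
```

Calling `super().forward()` after the random check instead would have validated twice, which is exactly why the body is duplicated, as the inline comment in the diff explains.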