Unverified commit e1b21f9c authored by Philip Meier, committed by GitHub

introduce _check method for type checks on prototype transforms (#6503)



* introduce _check method for type checks on prototype transforms

* cleanup

* Update torchvision/prototype/transforms/_geometry.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* introduce _check on new transforms

* _check -> _check_inputs

* always check inputs in _RandomApplyTransform
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent bdc55567
@@ -107,17 +107,16 @@ class _BaseMixupCutmix(_RandomApplyTransform):
         self.alpha = alpha
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
 
-    def forward(self, *inputs: Any) -> Any:
+    def _check_inputs(self, sample: Any) -> None:
         if not (
-            has_any(inputs, features.Image, features.Video, features.is_simple_tensor)
-            and has_any(inputs, features.OneHotLabel)
+            has_any(sample, features.Image, features.Video, features.is_simple_tensor)
+            and has_any(sample, features.OneHotLabel)
         ):
             raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
-        if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
+        if has_any(sample, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
             )
-        return super().forward(*inputs)
 
     def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
         if inpt.ndim < 2:
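All of these `_check_inputs` implementations lean on the `has_any`/`has_all` helpers. For readers unfamiliar with them, here is a minimal re-implementation sketch of their semantics; the real helpers live in torchvision's prototype utilities, so the details below are assumptions, not the actual source. An entry may be a plain type or a predicate such as `features.is_simple_tensor`.

    from typing import Any, Callable, Union

    from torch.utils._pytree import tree_flatten

    # An entry may be a plain type or a predicate such as features.is_simple_tensor.
    TypeOrCheck = Union[type, Callable[[Any], bool]]

    def _matches(obj: Any, type_or_check: TypeOrCheck) -> bool:
        # Plain types are matched with isinstance; callables act as predicates.
        if isinstance(type_or_check, type):
            return isinstance(obj, type_or_check)
        return type_or_check(obj)

    def has_any(sample: Any, *types_or_checks: TypeOrCheck) -> bool:
        # True if at least one leaf of the (possibly nested) sample matches any entry.
        flat_sample, _ = tree_flatten(sample)
        return any(_matches(obj, t) for obj in flat_sample for t in types_or_checks)

    def has_all(sample: Any, *types_or_checks: TypeOrCheck) -> bool:
        # True if every entry is matched by at least one leaf of the sample.
        flat_sample, _ = tree_flatten(sample)
        return all(any(_matches(obj, t) for obj in flat_sample) for t in types_or_checks)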
@@ -184,10 +184,9 @@ class FiveCrop(Transform):
     ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
         return F.five_crop(inpt, self.size)
 
-    def forward(self, *inputs: Any) -> Any:
-        if has_any(inputs, features.BoundingBox, features.Mask):
+    def _check_inputs(self, sample: Any) -> None:
+        if has_any(sample, features.BoundingBox, features.Mask):
             raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
-        return super().forward(*inputs)
 
 
 class TenCrop(Transform):
@@ -202,16 +201,15 @@ class TenCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
 
+    def _check_inputs(self, sample: Any) -> None:
+        if has_any(sample, features.BoundingBox, features.Mask):
+            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
+
     def _transform(
         self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
     ) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
 
-    def forward(self, *inputs: Any) -> Any:
-        if has_any(inputs, features.BoundingBox, features.Mask):
-            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
-        return super().forward(*inputs)
-
 
 class Pad(Transform):
     def __init__(
@@ -616,6 +614,17 @@ class RandomIoUCrop(Transform):
         self.options = sampler_options
         self.trials = trials
 
+    def _check_inputs(self, sample: Any) -> None:
+        if not (
+            has_all(sample, features.BoundingBox)
+            and has_any(sample, PIL.Image.Image, features.Image, features.is_simple_tensor)
+            and has_any(sample, features.Label, features.OneHotLabel)
+        ):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
+                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks."
+            )
+
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         orig_h, orig_w = query_spatial_size(sample)
         bboxes = query_bounding_box(sample)
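For concreteness, the contract encoded above can be exercised with the `has_any`/`has_all` sketch from earlier; the stand-in classes below are purely illustrative, not the real `features` types:

    # Illustrative stand-ins for the real feature classes (assumptions, not torchvision API).
    class Image: pass
    class BoundingBox: pass
    class Label: pass

    ok = {"image": Image(), "boxes": BoundingBox(), "labels": Label()}
    bad = {"image": Image()}  # no boxes and no labels

    # Mirrors the condition in RandomIoUCrop._check_inputs above.
    assert has_all(ok, BoundingBox) and has_any(ok, Image) and has_any(ok, Label)
    assert not (has_all(bad, BoundingBox) and has_any(bad, Image) and has_any(bad, Label))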
@@ -688,18 +697,6 @@ class RandomIoUCrop(Transform):
 
         return output
 
-    def forward(self, *inputs: Any) -> Any:
-        if not (
-            has_all(inputs, features.BoundingBox)
-            and has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor)
-            and has_any(inputs, features.Label, features.OneHotLabel)
-        ):
-            raise TypeError(
-                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
-                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks."
-            )
-        return super().forward(*inputs)
-
 
 class ScaleJitter(Transform):
     def __init__(
@@ -774,6 +771,18 @@ class FixedSizeCrop(Transform):
         self.padding_mode = padding_mode
 
+    def _check_inputs(self, sample: Any) -> None:
+        if not has_any(sample, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
+            )
+
+        if has_any(sample, features.BoundingBox) and not has_any(sample, features.Label, features.OneHotLabel):
+            raise TypeError(
+                f"If a BoundingBox is contained in the input sample, "
+                f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
+            )
+
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         height, width = query_spatial_size(sample)
         new_height = min(height, self.crop_height)
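The second branch above is a conditional rule: bounding boxes are allowed, but only together with some kind of label. A tiny demonstration using the same illustrative stand-ins as before:

    boxes_without_labels = {"image": Image(), "boxes": BoundingBox()}

    # Mirrors FixedSizeCrop._check_inputs: boxes present, labels absent -> TypeError.
    assert has_any(boxes_without_labels, BoundingBox) and not has_any(boxes_without_labels, Label)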
@@ -850,20 +859,6 @@ class FixedSizeCrop(Transform):
 
         return inpt
 
-    def forward(self, *inputs: Any) -> Any:
-        if not has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
-            raise TypeError(
-                f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
-            )
-
-        if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel):
-            raise TypeError(
-                f"If a BoundingBox is contained in the input sample, "
-                f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
-            )
-        return super().forward(*inputs)
-
 
 class RandomResize(Transform):
     def __init__(
@@ -63,12 +63,10 @@ class LinearTransformation(Transform):
         self.transformation_matrix = transformation_matrix
         self.mean_vector = mean_vector
 
-    def forward(self, *inputs: Any) -> Any:
-        if has_any(inputs, PIL.Image.Image):
+    def _check_inputs(self, sample: Any) -> Any:
+        if has_any(sample, PIL.Image.Image):
             raise TypeError("LinearTransformation does not work on PIL Images")
-        return super().forward(*inputs)
 
     def _transform(
         self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
     ) -> torch.Tensor:
@@ -104,16 +102,15 @@ class Normalize(Transform):
         self.std = list(std)
         self.inplace = inplace
 
+    def _check_inputs(self, sample: Any) -> Any:
+        if has_any(sample, PIL.Image.Image):
+            raise TypeError(f"{type(self).__name__}() does not support PIL images.")
+
     def _transform(
         self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
     ) -> torch.Tensor:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
 
-    def forward(self, *inpts: Any) -> Any:
-        if has_any(inpts, PIL.Image.Image):
-            raise TypeError(f"{type(self).__name__}() does not support PIL images.")
-        return super().forward(*inpts)
-
 
 class GaussianBlur(Transform):
     def __init__(
@@ -23,6 +23,9 @@ class Transform(nn.Module):
         super().__init__()
         _log_api_usage_once(self)
 
+    def _check_inputs(self, sample: Any) -> None:
+        pass
+
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict()
@@ -32,6 +35,8 @@ class Transform(nn.Module):
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
 
+        self._check_inputs(sample)
+
         params = self._get_params(sample)
 
         flat_inputs, spec = tree_flatten(sample)
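With this wiring, input validation becomes a template-method hook: `forward()` always calls `_check_inputs()` before sampling parameters, so subclasses only declare their constraints. Below is a condensed, self-contained sketch of that control flow; it is a simplification under assumed names, not the actual torchvision source:

    from typing import Any, Dict

    import torch
    import torch.nn as nn
    from torch.utils._pytree import tree_flatten, tree_unflatten

    class SketchTransform(nn.Module):
        _transformed_types = (torch.Tensor,)

        def _check_inputs(self, sample: Any) -> None:
            pass  # hook: subclasses raise TypeError for unsupported inputs

        def _get_params(self, sample: Any) -> Dict[str, Any]:
            return dict()

        def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            return inpt

        def forward(self, *inputs: Any) -> Any:
            sample = inputs if len(inputs) > 1 else inputs[0]
            self._check_inputs(sample)  # validation runs exactly once per call
            params = self._get_params(sample)
            flat_inputs, spec = tree_flatten(sample)
            flat_outputs = [
                self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt
                for inpt in flat_inputs
            ]
            return tree_unflatten(flat_outputs, spec)

    class TensorsOnlySketch(SketchTransform):
        # Example constraint in the spirit of Normalize._check_inputs above.
        def _check_inputs(self, sample: Any) -> None:
            flat_sample, _ = tree_flatten(sample)
            if any(not isinstance(obj, torch.Tensor) for obj in flat_sample):
                raise TypeError(f"{type(self).__name__}() only supports tensors.")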
@@ -64,9 +69,22 @@ class _RandomApplyTransform(Transform):
         self.p = p
 
     def forward(self, *inputs: Any) -> Any:
+        # We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return
+        # early afterwards in case the random check triggers. The same result could be achieved by calling
+        # `super().forward()` after the random check, but that would call `self._check_inputs` twice.
         sample = inputs if len(inputs) > 1 else inputs[0]
 
+        self._check_inputs(sample)
+
         if torch.rand(1) >= self.p:
             return sample
 
-        return super().forward(sample)
+        params = self._get_params(sample)
+
+        flat_inputs, spec = tree_flatten(sample)
+        flat_outputs = [
+            self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
+            for inpt in flat_inputs
+        ]
+
+        return tree_unflatten(flat_outputs, spec)
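As the comment in the hunk explains, the duplication exists so that validation fires even when the random draw skips the transform. A toy version of that guarantee, building on `TensorsOnlySketch` from the earlier sketch (hypothetical names throughout):

    class RandomApplySketch(TensorsOnlySketch):
        def __init__(self, p: float = 0.5) -> None:
            super().__init__()
            self.p = p

        def forward(self, *inputs: Any) -> Any:
            sample = inputs if len(inputs) > 1 else inputs[0]
            self._check_inputs(sample)  # always validate first
            if torch.rand(1) >= self.p:
                return sample  # randomly skipped, but already validated
            # Duplicate the remainder of SketchTransform.forward() rather than calling
            # super().forward(sample), which would run _check_inputs() a second time.
            params = self._get_params(sample)
            flat_inputs, spec = tree_flatten(sample)
            flat_outputs = [
                self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt
                for inpt in flat_inputs
            ]
            return tree_unflatten(flat_outputs, spec)

    # Even with p=0.0 (the transform never applies), bad inputs fail deterministically:
    try:
        RandomApplySketch(p=0.0)("not a tensor")
    except TypeError as e:
        print(e)  # RandomApplySketch() only supports tensors.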