Unverified Commit 13ea9018 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Simplified code in overridden transform forward methods (#6504)


Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent d6444529
...@@ -101,15 +101,14 @@ class _BaseMixupCutmix(_RandomApplyTransform): ...@@ -101,15 +101,14 @@ class _BaseMixupCutmix(_RandomApplyTransform):
self.alpha = alpha self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def forward(self, *inpts: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inpts if len(inpts) > 1 else inpts[0] if not (has_any(inputs, features.Image, is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
if not (has_any(sample, features.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.") raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label): if has_any(inputs, features.BoundingBox, features.SegmentationMask, features.Label):
raise TypeError( raise TypeError(
f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels." f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels."
) )
return super().forward(sample) return super().forward(*inputs)
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel: def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
if inpt.ndim < 2: if inpt.ndim < 2:
...@@ -325,9 +324,7 @@ class SimpleCopyPaste(_RandomApplyTransform): ...@@ -325,9 +324,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
c3 += 1 c3 += 1
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] flat_sample, spec = tree_flatten(inputs)
flat_sample, spec = tree_flatten(sample)
images, targets = self._extract_image_targets(flat_sample) images, targets = self._extract_image_targets(flat_sample)
......
...@@ -15,10 +15,9 @@ class Compose(Transform): ...@@ -15,10 +15,9 @@ class Compose(Transform):
self.transforms = transforms self.transforms = transforms
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for transform in self.transforms: for transform in self.transforms:
sample = transform(sample) inputs = transform(*inputs)
return sample return inputs
class RandomApply(_RandomApplyTransform): class RandomApply(_RandomApplyTransform):
......
...@@ -166,10 +166,9 @@ class FiveCrop(Transform): ...@@ -166,10 +166,9 @@ class FiveCrop(Transform):
return F.five_crop(inpt, self.size) return F.five_crop(inpt, self.size)
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] if has_any(inputs, features.BoundingBox, features.SegmentationMask):
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
return super().forward(sample) return super().forward(*inputs)
class TenCrop(Transform): class TenCrop(Transform):
...@@ -188,10 +187,9 @@ class TenCrop(Transform): ...@@ -188,10 +187,9 @@ class TenCrop(Transform):
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] if has_any(inputs, features.BoundingBox, features.SegmentationMask):
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
return super().forward(sample) return super().forward(*inputs)
def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None: def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
...@@ -696,17 +694,16 @@ class RandomIoUCrop(Transform): ...@@ -696,17 +694,16 @@ class RandomIoUCrop(Transform):
return output return output
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if not ( if not (
has_all(sample, features.BoundingBox) has_all(inputs, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor) and has_any(inputs, PIL.Image.Image, features.Image, is_simple_tensor)
and has_any(sample, features.Label, features.OneHotLabel) and has_any(inputs, features.Label, features.OneHotLabel)
): ):
raise TypeError( raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, " f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks." "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
) )
return super().forward(sample) return super().forward(*inputs)
class ScaleJitter(Transform): class ScaleJitter(Transform):
...@@ -850,15 +847,13 @@ class FixedSizeCrop(Transform): ...@@ -850,15 +847,13 @@ class FixedSizeCrop(Transform):
return inpt return inpt
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] if not has_any(inputs, PIL.Image.Image, features.Image, is_simple_tensor):
if not has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor):
raise TypeError(f"{type(self).__name__}() requires input sample to contain an tensor or PIL image.") raise TypeError(f"{type(self).__name__}() requires input sample to contain an tensor or PIL image.")
if has_any(sample, features.BoundingBox) and not has_any(sample, features.Label, features.OneHotLabel): if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel):
raise TypeError( raise TypeError(
f"If a BoundingBox is contained in the input sample, " f"If a BoundingBox is contained in the input sample, "
f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel." f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
) )
return super().forward(sample) return super().forward(*inputs)
...@@ -63,11 +63,10 @@ class LinearTransformation(Transform): ...@@ -63,11 +63,10 @@ class LinearTransformation(Transform):
self.mean_vector = mean_vector self.mean_vector = mean_vector
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] if has_any(inputs, PIL.Image.Image):
if has_any(sample, PIL.Image.Image):
raise TypeError("LinearTransformation does not work on PIL Images") raise TypeError("LinearTransformation does not work on PIL Images")
return super().forward(sample) return super().forward(*inputs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
# Image instance after linear transformation is not Image anymore due to unknown data range # Image instance after linear transformation is not Image anymore due to unknown data range
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment