Unverified Commit 13ea9018 authored by vfdev, committed by GitHub

[proto] Simplified code in overridden transform forward methods (#6504)


Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent d6444529
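
Every hunk below drops the same boilerplate, `sample = inputs if len(inputs) > 1 else inputs[0]`, from an overridden forward and passes the varargs straight through instead. That is only equivalent if the base class performs this normalization itself; a minimal sketch of the assumed base-class behaviour (abbreviated, not taken from this diff):

    from typing import Any

    import torch.nn as nn

    class Transform(nn.Module):
        def forward(self, *inputs: Any) -> Any:
            # Assumed base-class behaviour: a single positional argument is
            # unwrapped, several arguments are kept together as a tuple, so
            # overrides can simply call super().forward(*inputs).
            sample = inputs if len(inputs) > 1 else inputs[0]
            # ... flatten `sample` and apply self._transform to supported leaves ...
            return sample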
@@ -101,15 +101,14 @@ class _BaseMixupCutmix(_RandomApplyTransform):
         self.alpha = alpha
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
 
-    def forward(self, *inpts: Any) -> Any:
-        sample = inpts if len(inpts) > 1 else inpts[0]
-        if not (has_any(sample, features.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)):
+    def forward(self, *inputs: Any) -> Any:
+        if not (has_any(inputs, features.Image, is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
             raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
-        if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
+        if has_any(inputs, features.BoundingBox, features.SegmentationMask, features.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels."
             )
-        return super().forward(sample)
+        return super().forward(*inputs)
 
     def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
         if inpt.ndim < 2:
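
The checks above now receive the raw `inputs` tuple instead of the unwrapped sample. That is only safe if `has_any` (and `has_all`) flatten whatever nested structure they are given; a hypothetical minimal version of such a helper, assuming it accepts both types and predicate functions like `is_simple_tensor`:

    from typing import Any, Callable, Type, Union

    from torch.utils._pytree import tree_flatten

    def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
        # Flatten the (possibly nested) sample and report whether any leaf
        # matches at least one of the requested types or predicates.
        flat_sample, _ = tree_flatten(sample)
        for type_or_check in types_or_checks:
            for obj in flat_sample:
                matches = isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj)
                if matches:
                    return True
        return False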
@@ -325,9 +324,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
                 c3 += 1
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-
-        flat_sample, spec = tree_flatten(sample)
+        flat_sample, spec = tree_flatten(inputs)
 
         images, targets = self._extract_image_targets(flat_sample)
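
Flattening the varargs tuple directly yields exactly the same leaves as flattening the unwrapped sample; only the recorded spec gains an outer one-element tuple. A small stand-alone illustration of this (not part of the commit):

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    sample = {"image": torch.rand(3, 8, 8), "label": torch.tensor(1)}

    flat_from_tuple, spec = tree_flatten((sample,))  # what tree_flatten(inputs) sees
    flat_from_sample, _ = tree_flatten(sample)       # what the removed code produced

    assert all(a is b for a, b in zip(flat_from_tuple, flat_from_sample))  # identical leaves
    restored = tree_unflatten(flat_from_tuple, spec)  # the spec round-trips the outer 1-tuple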
@@ -15,10 +15,9 @@ class Compose(Transform):
         self.transforms = transforms
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
         for transform in self.transforms:
-            sample = transform(sample)
-        return sample
+            inputs = transform(*inputs)
+        return inputs
 
 
 class RandomApply(_RandomApplyTransform):
@@ -166,10 +166,9 @@ class FiveCrop(Transform):
         return F.five_crop(inpt, self.size)
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+        if has_any(inputs, features.BoundingBox, features.SegmentationMask):
             raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
-        return super().forward(sample)
+        return super().forward(*inputs)
 
 
 class TenCrop(Transform):
@@ -188,10 +187,9 @@ class TenCrop(Transform):
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+        if has_any(inputs, features.BoundingBox, features.SegmentationMask):
             raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
-        return super().forward(sample)
+        return super().forward(*inputs)
 
 
 def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
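
FiveCrop and TenCrop produce several crops per image, so spatial targets such as bounding boxes or segmentation masks cannot be mapped through them; the guard above rejects them up front. An illustrative use of the guard; the constructors of the prototype features shown here are assumptions:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype import transforms as T

    crop = T.FiveCrop(size=(100, 100))
    image = features.Image(torch.rand(3, 256, 256))
    mask = features.SegmentationMask(torch.zeros(1, 256, 256))

    crops = crop(image)  # passes the check and dispatches F.five_crop
    try:
        crop(image, mask)  # a SegmentationMask among the inputs triggers the TypeError above
    except TypeError as err:
        print(err)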
@@ -696,17 +694,16 @@ class RandomIoUCrop(Transform):
             return output
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
         if not (
-            has_all(sample, features.BoundingBox)
-            and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
-            and has_any(sample, features.Label, features.OneHotLabel)
+            has_all(inputs, features.BoundingBox)
+            and has_any(inputs, PIL.Image.Image, features.Image, is_simple_tensor)
+            and has_any(inputs, features.Label, features.OneHotLabel)
         ):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
                 "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
             )
-        return super().forward(sample)
+        return super().forward(*inputs)
 
 
 class ScaleJitter(Transform):
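
RandomIoUCrop only makes sense for a detection-style sample, which is what the combined has_all/has_any check enforces: bounding boxes, at least one image (tensor or PIL), and labels (plain or one-hot) must all be present. An illustrative call that satisfies it; the keyword names used to build the prototype features (format, image_size) are assumptions:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype import transforms as T

    image = features.Image(torch.rand(3, 320, 320))
    boxes = features.BoundingBox(
        torch.tensor([[10.0, 10.0, 100.0, 100.0]]), format="XYXY", image_size=(320, 320)
    )
    labels = features.Label(torch.tensor([1]))

    out = T.RandomIoUCrop()(image, boxes, labels)  # all three required kinds are present, so the check passes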
@@ -850,15 +847,13 @@ class FixedSizeCrop(Transform):
         return inpt
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        if not has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor):
+        if not has_any(inputs, PIL.Image.Image, features.Image, is_simple_tensor):
             raise TypeError(f"{type(self).__name__}() requires input sample to contain an tensor or PIL image.")
-        if has_any(sample, features.BoundingBox) and not has_any(sample, features.Label, features.OneHotLabel):
+        if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel):
             raise TypeError(
                 f"If a BoundingBox is contained in the input sample, "
                 f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
             )
-        return super().forward(sample)
+        return super().forward(*inputs)
@@ -63,11 +63,10 @@ class LinearTransformation(Transform):
         self.mean_vector = mean_vector
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        if has_any(sample, PIL.Image.Image):
+        if has_any(inputs, PIL.Image.Image):
             raise TypeError("LinearTransformation does not work on PIL Images")
-        return super().forward(sample)
+        return super().forward(*inputs)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         # Image instance after linear transformation is not Image anymore due to unknown data range