Unverified Commit b28d480f authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Refactor transforms input checks (#2694)

* Refactor transforms input checks
- added _setup_size to check and setup size argument

* More refactor
parent c163fc42
......@@ -234,11 +234,7 @@ class Resize(torch.nn.Module):
def __init__(self, size, interpolation=Image.BILINEAR):
super().__init__()
if not isinstance(size, (int, Sequence)):
raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
if isinstance(size, Sequence) and len(size) not in (1, 2):
raise ValueError("If size is a sequence, it should have 1 or 2 values")
self.size = size
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
self.interpolation = interpolation
def forward(self, img):
......@@ -279,15 +275,7 @@ class CenterCrop(torch.nn.Module):
def __init__(self, size):
super().__init__()
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
elif isinstance(size, Sequence) and len(size) == 1:
self.size = (size[0], size[0])
else:
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")
self.size = size
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def forward(self, img):
"""
......@@ -523,16 +511,11 @@ class RandomCrop(torch.nn.Module):
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
super().__init__()
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
elif isinstance(size, Sequence) and len(size) == 1:
self.size = (size[0], size[0])
else:
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")
# cast to tuple for torchscript
self.size = tuple(size)
self.size = tuple(_setup_size(
size, error_msg="Please provide only two dimensions (h, w) for size."
))
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill = fill
......@@ -728,14 +711,7 @@ class RandomResizedCrop(torch.nn.Module):
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
super().__init__()
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
elif isinstance(size, Sequence) and len(size) == 1:
self.size = (size[0], size[0])
else:
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")
self.size = size
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
......@@ -856,15 +832,7 @@ class FiveCrop(torch.nn.Module):
def __init__(self, size):
super().__init__()
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
elif isinstance(size, Sequence) and len(size) == 1:
self.size = (size[0], size[0])
else:
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")
self.size = size
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def forward(self, img):
"""
......@@ -912,15 +880,7 @@ class TenCrop(torch.nn.Module):
def __init__(self, size, vertical_flip=False):
super().__init__()
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
elif isinstance(size, Sequence) and len(size) == 1:
self.size = (size[0], size[0])
else:
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")
self.size = size
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def forward(self, img):
......@@ -1143,23 +1103,10 @@ class RandomRotation(torch.nn.Module):
def __init__(self, degrees, resample=False, expand=False, center=None, fill=None):
super().__init__()
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
degrees = [-degrees, degrees]
else:
if not isinstance(degrees, Sequence):
raise TypeError("degrees should be a sequence of length 2.")
if len(degrees) != 2:
raise ValueError("If degrees is a sequence, it must be of len 2.")
self.degrees = [float(d) for d in degrees]
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
if center is not None:
if not isinstance(center, Sequence):
raise TypeError("center should be a sequence of length 2.")
if len(center) != 2:
raise ValueError("center should be a sequence of length 2.")
_check_sequence_input(center, "center", req_sizes=(2, ))
self.center = center
......@@ -1234,51 +1181,24 @@ class RandomAffine(torch.nn.Module):
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, fillcolor=0):
super().__init__()
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
degrees = [-degrees, degrees]
else:
if not isinstance(degrees, Sequence):
raise TypeError("degrees should be a sequence of length 2.")
if len(degrees) != 2:
raise ValueError("degrees should be sequence of length 2.")
self.degrees = [float(d) for d in degrees]
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
if translate is not None:
if not isinstance(translate, Sequence):
raise TypeError("translate should be a sequence of length 2.")
if len(translate) != 2:
raise ValueError("translate should be sequence of length 2.")
_check_sequence_input(translate, "translate", req_sizes=(2, ))
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
if not isinstance(scale, Sequence):
raise TypeError("scale should be a sequence of length 2.")
if len(scale) != 2:
raise ValueError("scale should be sequence of length 2.")
_check_sequence_input(scale, "scale", req_sizes=(2, ))
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError("If shear is a single number, it must be positive.")
shear = [-shear, shear]
else:
if not isinstance(shear, Sequence):
raise TypeError("shear should be a sequence of length 2 or 4.")
if len(shear) not in (2, 4):
raise ValueError("shear should be sequence of length 2 or 4.")
self.shear = [float(s) for s in shear]
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
else:
self.shear = shear
......@@ -1545,3 +1465,35 @@ class RandomErasing(torch.nn.Module):
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
return F.erase(img, x, y, h, w, v, self.inplace)
return img
def _setup_size(size, error_msg):
if isinstance(size, numbers.Number):
return int(size), int(size)
if isinstance(size, Sequence) and len(size) == 1:
return size[0], size[0]
if len(size) != 2:
raise ValueError(error_msg)
return size
def _check_sequence_input(x, name, req_sizes):
msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
if not isinstance(x, Sequence):
raise TypeError("{} should be a sequence of length {}.".format(name, msg))
if len(x) not in req_sizes:
raise ValueError("{} should be sequence of length {}.".format(name, msg))
def _setup_angle(x, name, req_sizes=(2, )):
if isinstance(x, numbers.Number):
if x < 0:
raise ValueError("If {} is a single number, it must be positive.".format(name))
x = [-x, x]
else:
_check_sequence_input(x, name, req_sizes)
return [float(d) for d in x]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment