Unverified Commit 1d7d92cb authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

refactor `prototype.transforms.RandomCrop` (#6640)



* refactor RandomCrop

* mypy

* fix test

* use padding directly rather than private attribute

* only compute type specific fill if padding is needed

* [DRAFT] don't use the diff trick

* fix error message
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>

* remove height and width diff

* reinstate separate diff checking

* introduce needs_crop flag
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent d7d90f56
...@@ -715,30 +715,38 @@ class TestRandomCrop: ...@@ -715,30 +715,38 @@ class TestRandomCrop:
if padding is not None: if padding is not None:
if isinstance(padding, int): if isinstance(padding, int):
h += 2 * padding pad_top = pad_bottom = pad_left = pad_right = padding
w += 2 * padding
elif isinstance(padding, list) and len(padding) == 2: elif isinstance(padding, list) and len(padding) == 2:
w += 2 * padding[0] pad_left = pad_right = padding[0]
h += 2 * padding[1] pad_top = pad_bottom = padding[1]
elif isinstance(padding, list) and len(padding) == 4: elif isinstance(padding, list) and len(padding) == 4:
w += padding[0] + padding[2] pad_left, pad_top, pad_right, pad_bottom = padding
h += padding[1] + padding[3]
expected_input_width = w h += pad_top + pad_bottom
expected_input_height = h w += pad_left + pad_right
else:
pad_left = pad_right = pad_top = pad_bottom = 0
if pad_if_needed: if pad_if_needed:
if w < size[1]: if w < size[1]:
w += 2 * (size[1] - w) diff = size[1] - w
pad_left += diff
pad_right += diff
w += 2 * diff
if h < size[0]: if h < size[0]:
h += 2 * (size[0] - h) diff = size[0] - h
pad_top += diff
pad_bottom += diff
h += 2 * diff
padding = [pad_left, pad_top, pad_right, pad_bottom]
assert 0 <= params["top"] <= h - size[0] + 1 assert 0 <= params["top"] <= h - size[0] + 1
assert 0 <= params["left"] <= w - size[1] + 1 assert 0 <= params["left"] <= w - size[1] + 1
assert params["height"] == size[0] assert params["height"] == size[0]
assert params["width"] == size[1] assert params["width"] == size[1]
assert params["input_width"] == expected_input_width assert params["needs_pad"] is any(padding)
assert params["input_height"] == expected_input_height assert params["padding"] == padding
@pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]]) @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
@pytest.mark.parametrize("pad_if_needed", [False, True]) @pytest.mark.parametrize("pad_if_needed", [False, True])
......
...@@ -966,7 +966,7 @@ class PadIfSmaller(prototype_transforms.Transform): ...@@ -966,7 +966,7 @@ class PadIfSmaller(prototype_transforms.Transform):
class TestRefSegTransforms: class TestRefSegTransforms:
def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
size = (256, 640) size = (256, 460)
num_categories = 21 num_categories = 21
conv_fns = [] conv_fns = []
......
...@@ -414,78 +414,80 @@ class RandomCrop(Transform): ...@@ -414,78 +414,80 @@ class RandomCrop(Transform):
_check_padding_arg(padding) _check_padding_arg(padding)
_check_padding_mode_arg(padding_mode) _check_padding_mode_arg(padding_mode)
self.padding = padding self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type]
self.pad_if_needed = pad_if_needed self.pad_if_needed = pad_if_needed
self.fill = _setup_fill_arg(fill) self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode self.padding_mode = padding_mode
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
_, height, width = query_chw(sample) _, padded_height, padded_width = query_chw(sample)
if self.padding is not None: if self.padding is not None:
# update height, width with static padding data pad_left, pad_right, pad_top, pad_bottom = self.padding
padding = self.padding padded_height += pad_top + pad_bottom
if isinstance(padding, Sequence): padded_width += pad_left + pad_right
padding = list(padding) else:
pad_left, pad_right, pad_top, pad_bottom = F._geometry._parse_pad_padding(padding) pad_left = pad_right = pad_top = pad_bottom = 0
height += pad_top + pad_bottom
width += pad_left + pad_right
output_height, output_width = self.size cropped_height, cropped_width = self.size
# We have to store maybe padded image size for pad_if_needed branch in _transform
input_height, input_width = height, width
if self.pad_if_needed: if self.pad_if_needed:
# pad width if needed if padded_height < cropped_height:
if width < output_width: diff = cropped_height - padded_height
width += 2 * (output_width - width)
# pad height if needed pad_top += diff
if height < output_height: pad_bottom += diff
height += 2 * (output_height - height) padded_height += 2 * diff
if height < output_height or width < output_width: if padded_width < cropped_width:
diff = cropped_width - padded_width
pad_left += diff
pad_right += diff
padded_width += 2 * diff
if padded_height < cropped_height or padded_width < cropped_width:
raise ValueError( raise ValueError(
f"Required crop size {(output_height, output_width)} is larger then input image size {(height, width)}" f"Required crop size {(cropped_height, cropped_width)} is larger than "
f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
) )
if width == output_width and height == output_height: # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
return dict(top=0, left=0, height=height, width=width, input_width=input_width, input_height=input_height) padding = [pad_left, pad_top, pad_right, pad_bottom]
needs_pad = any(padding)
top = torch.randint(0, height - output_height + 1, size=(1,)).item() needs_vert_crop, top = (
left = torch.randint(0, width - output_width + 1, size=(1,)).item() (True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
if padded_height > cropped_height
else (False, 0)
)
needs_horz_crop, left = (
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
if padded_width > cropped_width
else (False, 0)
)
return dict( return dict(
needs_crop=needs_vert_crop or needs_horz_crop,
top=top, top=top,
left=left, left=left,
height=output_height, height=cropped_height,
width=output_width, width=cropped_width,
input_width=input_width, needs_pad=needs_pad,
input_height=input_height, padding=padding,
) )
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: (PERF) check for speed optimization if we avoid repeated pad calls if params["needs_pad"]:
fill = self.fill[type(inpt)] fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill) fill = F._geometry._convert_fill_arg(fill)
if self.padding is not None: inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
padding = self.padding
if not isinstance(padding, int):
padding = list(padding)
inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) if params["needs_crop"]:
inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
if self.pad_if_needed: return inpt
input_width, input_height = params["input_width"], params["input_height"]
if input_width < self.size[1]:
padding = [self.size[1] - input_width, 0]
inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
if input_height < self.size[0]:
padding = [0, self.size[0] - input_height]
inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
return F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
class RandomPerspective(_RandomApplyTransform): class RandomPerspective(_RandomApplyTransform):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment