Unverified Commit cc0a8385 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Use F.pad in RandomCrop (#6392)

* [proto] Use F.pad in RandomCrop

* Update torchvision/prototype/transforms/_geometry.py
parent 5521e9d0
...@@ -434,18 +434,17 @@ class RandomCrop(Transform): ...@@ -434,18 +434,17 @@ class RandomCrop(Transform):
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if pad_if_needed or padding is not None:
if padding is not None:
_check_padding_arg(padding)
_check_fill_arg(fill)
_check_padding_mode_arg(padding_mode)
self.padding = padding self.padding = padding
self.pad_if_needed = pad_if_needed self.pad_if_needed = pad_if_needed
self.fill = fill self.fill = fill
self.padding_mode = padding_mode self.padding_mode = padding_mode
self._pad_op = None
if self.padding is not None:
self._pad_op = Pad(self.padding, fill=self.fill, padding_mode=self.padding_mode)
if self.pad_if_needed:
self._pad_op = Pad(0, fill=self.fill, padding_mode=self.padding_mode)
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample) image = query_image(sample)
_, height, width = get_image_dimensions(image) _, height, width = get_image_dimensions(image)
...@@ -469,28 +468,21 @@ class RandomCrop(Transform): ...@@ -469,28 +468,21 @@ class RandomCrop(Transform):
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
if self._pad_op is not None: if self.padding is not None:
sample = self._pad_op(sample) sample = F.pad(sample, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)
image = query_image(sample) image = query_image(sample)
_, height, width = get_image_dimensions(image) _, height, width = get_image_dimensions(image)
if self.pad_if_needed: if self.pad_if_needed:
# This check is to explicitly ensure that self._pad_op is defined
if self._pad_op is None:
raise RuntimeError(
"Internal error, self._pad_op is None. "
"Please, fill an issue about that on https://github.com/pytorch/vision/issues"
)
# pad the width if needed # pad the width if needed
if width < self.size[1]: if width < self.size[1]:
self._pad_op.padding = [self.size[1] - width, 0] padding = [self.size[1] - width, 0]
sample = self._pad_op(sample) sample = F.pad(sample, padding=padding, fill=self.fill, padding_mode=self.padding_mode)
# pad the height if needed # pad the height if needed
if height < self.size[0]: if height < self.size[0]:
self._pad_op.padding = [0, self.size[0] - height] padding = [0, self.size[0] - height]
sample = self._pad_op(sample) sample = F.pad(sample, padding=padding, fill=self.fill, padding_mode=self.padding_mode)
return super().forward(sample) return super().forward(sample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment