Unverified Commit c3573c88 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port `RandomShortestSize` from detection references to prototype transforms (#6418)

* port `RandomShortestSize` from detection references to prototype transforms

* mypy

* add test
parent c0ba3ec8
...@@ -1164,3 +1164,48 @@ class TestScaleJitter: ...@@ -1164,3 +1164,48 @@ class TestScaleJitter:
transform(inpt_sentinel) transform(inpt_sentinel)
mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel) mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
class TestRandomShortestSize:
def test__get_params(self, mocker):
image_size = (3, 10)
min_size = [5, 9]
max_size = 20
transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)
sample = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=image_size)
params = transform._get_params(sample)
assert "size" in params
size = params["size"]
assert isinstance(size, tuple) and len(size) == 2
longer = max(size)
assert longer <= max_size
shorter = min(size)
if longer == max_size:
assert shorter <= max_size
else:
assert shorter in min_size
def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
transform = transforms.RandomShortestSize(min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel)
transform._transformed_types = (mocker.MagicMock,)
size_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.RandomShortestSize._get_params",
return_value=dict(size=size_sentinel),
)
inpt_sentinel = mocker.MagicMock()
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
transform(inpt_sentinel)
mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
...@@ -27,6 +27,7 @@ from ._geometry import ( ...@@ -27,6 +27,7 @@ from ._geometry import (
RandomPerspective, RandomPerspective,
RandomResizedCrop, RandomResizedCrop,
RandomRotation, RandomRotation,
RandomShortestSize,
RandomVerticalFlip, RandomVerticalFlip,
RandomZoomOut, RandomZoomOut,
Resize, Resize,
......
...@@ -644,3 +644,31 @@ class ScaleJitter(Transform): ...@@ -644,3 +644,31 @@ class ScaleJitter(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation) return F.resize(inpt, size=params["size"], interpolation=self.interpolation)
class RandomShortestSize(Transform):
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
):
super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = interpolation
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
_, orig_height, orig_width = get_image_dimensions(image)
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r)
new_height = int(orig_height * r)
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment