Unverified Commit a5b36d3b authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port RandomResize from segmentation references (#6561)

* port RandomResize from segmentation references

* mypy

* remove optional max_size

* add interpolation and antialias

* [SKIP CI] only CircleCI

* fix bug

* [SKIP CI] only CircleCI

* [SKIP CI] add test
parent a89b1957
...@@ -1670,3 +1670,43 @@ class TestLabelToOneHot: ...@@ -1670,3 +1670,43 @@ class TestLabelToOneHot:
assert isinstance(ohe_labels, features.OneHotLabel) assert isinstance(ohe_labels, features.OneHotLabel)
assert ohe_labels.shape == (4, 3) assert ohe_labels.shape == (4, 3)
assert ohe_labels.categories == labels.categories == categories assert ohe_labels.categories == labels.categories == categories
class TestRandomResize:
def test__get_params(self):
min_size = 3
max_size = 6
transform = transforms.RandomResize(min_size=min_size, max_size=max_size)
for _ in range(10):
params = transform._get_params(None)
assert isinstance(params["size"], list) and len(params["size"]) == 1
size = params["size"][0]
assert min_size <= size < max_size
def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
antialias_sentinel = mocker.MagicMock()
transform = transforms.RandomResize(
min_size=-1, max_size=-1, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
transform._transformed_types = (mocker.MagicMock,)
size_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.RandomResize._get_params",
return_value=dict(size=size_sentinel),
)
inpt_sentinel = mocker.MagicMock()
mock_resize = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
transform(inpt_sentinel)
mock_resize.assert_called_with(
inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
...@@ -28,6 +28,7 @@ from ._geometry import ( ...@@ -28,6 +28,7 @@ from ._geometry import (
RandomHorizontalFlip, RandomHorizontalFlip,
RandomIoUCrop, RandomIoUCrop,
RandomPerspective, RandomPerspective,
RandomResize,
RandomResizedCrop, RandomResizedCrop,
RandomRotation, RandomRotation,
RandomShortestSize, RandomShortestSize,
......
...@@ -866,3 +866,25 @@ class FixedSizeCrop(Transform): ...@@ -866,3 +866,25 @@ class FixedSizeCrop(Transform):
) )
return super().forward(*inputs) return super().forward(*inputs)
class RandomResize(Transform):
def __init__(
self,
min_size: int,
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> None:
super().__init__()
self.min_size = min_size
self.max_size = max_size
self.interpolation = interpolation
self.antialias = antialias
def _get_params(self, sample: Any) -> Dict[str, Any]:
size = int(torch.randint(self.min_size, self.max_size, ()))
return dict(size=[size])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment