You need to sign in or sign up before continuing.
Unverified commit 2aa54743, authored by Philip Meier and committed by GitHub
Browse files

port `ScaleJitter` from detection reference to prototype transforms (#6411)

* port ScaleJitter from detection reference to prototype transforms

* add test

* use MagicMock as sentinel
parent 7a1281a6
...@@ -1125,3 +1125,42 @@ class TestCompose: ...@@ -1125,3 +1125,42 @@ class TestCompose:
inpt = torch.rand(1, 3, 32, 32) inpt = torch.rand(1, 3, 32, 32)
output = c(inpt) output = c(inpt)
assert isinstance(output, torch.Tensor) assert isinstance(output, torch.Tensor)
class TestScaleJitter:
    """Unit tests for ``transforms.ScaleJitter``."""

    def test__get_params(self, mocker):
        """``_get_params`` yields a ``size`` 2-tuple inside the scaled target bounds."""
        image_size = (24, 32)
        target_size = (16, 12)
        scale_range = (0.5, 1.5)
        min_scale, max_scale = scale_range

        jitter = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
        # A MagicMock specced as ``features.Image`` with fixed ``num_channels``
        # and ``image_size`` stands in for a real image.
        fake_image = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=image_size)

        params = jitter._get_params(fake_image)

        assert "size" in params
        size = params["size"]
        assert isinstance(size, tuple)
        assert len(size) == 2

        # Each output side must lie between the target side scaled by the
        # lower and the upper end of ``scale_range`` (truncated to int).
        for actual, target in zip(size, target_size):
            assert int(target * min_scale) <= actual <= int(target * max_scale)

    def test__transform(self, mocker):
        """Calling the transform forwards input, size and interpolation to ``F.resize``."""
        interpolation_sentinel = mocker.MagicMock()
        jitter = transforms.ScaleJitter(target_size=(16, 12), interpolation=interpolation_sentinel)
        # NOTE(review): presumably this makes the transform accept the MagicMock
        # input below -- confirm against the ``Transform`` base class.
        jitter._transformed_types = (mocker.MagicMock,)

        # Pin the sampled parameters so the forwarded ``size`` is a known object.
        size_sentinel = mocker.MagicMock()
        mocker.patch(
            "torchvision.prototype.transforms._geometry.ScaleJitter._get_params",
            return_value={"size": size_sentinel},
        )

        inpt_sentinel = mocker.MagicMock()
        resize_mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")

        jitter(inpt_sentinel)

        resize_mock.assert_called_once_with(
            inpt_sentinel,
            size=size_sentinel,
            interpolation=interpolation_sentinel,
        )
...@@ -30,6 +30,7 @@ from ._geometry import ( ...@@ -30,6 +30,7 @@ from ._geometry import (
RandomVerticalFlip, RandomVerticalFlip,
RandomZoomOut, RandomZoomOut,
Resize, Resize,
ScaleJitter,
TenCrop, TenCrop,
) )
from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype
......
...@@ -631,3 +631,29 @@ class ElasticTransform(Transform): ...@@ -631,3 +631,29 @@ class ElasticTransform(Transform):
fill=self.fill, fill=self.fill,
interpolation=self.interpolation, interpolation=self.interpolation,
) )
class ScaleJitter(Transform):
    """Resize the input to ``target_size`` multiplied by a random scale factor.

    One scale factor is drawn per call to ``_get_params`` and applied to both
    the target height and the target width; the scaled values are truncated to
    ``int`` and handed to ``F.resize``.

    Args:
        target_size: Base output size, indexed as ``(height, width)``.
        scale_range: ``(lower, upper)`` bounds from which the scale factor is
            drawn via ``torch.rand``.
        interpolation: Interpolation mode forwarded to ``F.resize``.
    """

    def __init__(
        self,
        target_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ):
        super().__init__()
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = interpolation

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        """Draw a random scale and return ``{"size": (new_height, new_width)}``."""
        # NOTE(review): the image dimensions are queried but their values are
        # never used below; the output size depends only on ``target_size`` and
        # the random scale. The calls are kept so any error they raise for an
        # unsuitable sample still surfaces -- confirm whether the original
        # size was meant to take part in the computation.
        get_image_dimensions(query_image(sample))

        lower = self.scale_range[0]
        upper = self.scale_range[1]
        # Single-element tensor holding ``lower`` plus a random fraction of
        # the ``upper - lower`` span.
        scale = torch.rand(1) * (upper - lower) + lower

        # ``int()`` on the one-element tensor truncates towards zero.
        scaled_height = int(self.target_size[0] * scale)
        scaled_width = int(self.target_size[1] * scale)
        return {"size": (scaled_height, scaled_width)}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Resize ``inpt`` to the sampled ``params["size"]``."""
        resized = F.resize(inpt, size=params["size"], interpolation=self.interpolation)
        return resized
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment