Unverified Commit a8c4875c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port tests for transforms.ScaleJitter (#8001)

parent c5dea7d4
...@@ -519,34 +519,6 @@ class TestRandomIoUCrop: ...@@ -519,34 +519,6 @@ class TestRandomIoUCrop:
assert isinstance(output_masks, tv_tensors.Mask) assert isinstance(output_masks, tv_tensors.Mask)
class TestScaleJitter:
def test__get_params(self):
canvas_size = (24, 32)
target_size = (16, 12)
scale_range = (0.5, 1.5)
transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
sample = make_image(canvas_size)
n_samples = 5
for _ in range(n_samples):
params = transform._get_params([sample])
assert "size" in params
size = params["size"]
assert isinstance(size, tuple) and len(size) == 2
height, width = size
r_min = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[0]
r_max = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[1]
assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max)
assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max)
class TestRandomShortestSize: class TestRandomShortestSize:
@pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)]) @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
def test__get_params(self, min_size, max_size): def test__get_params(self, min_size, max_size):
......
...@@ -4788,3 +4788,41 @@ class TestRandomPhotometricDistort: ...@@ -4788,3 +4788,41 @@ class TestRandomPhotometricDistort:
), ),
make_input(dtype=dtype, device=device), make_input(dtype=dtype, device=device),
) )
class TestScaleJitter:
# Tests are light because this largely relies on the already tested `resize` kernels.
INPUT_SIZE = (17, 11)
TARGET_SIZE = (12, 13)
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, device):
if make_input is make_image_pil and device != "cpu":
pytest.skip("PIL image tests with parametrization device!='cpu' will degenerate to that anyway.")
check_transform(transforms.ScaleJitter(self.TARGET_SIZE), make_input(self.INPUT_SIZE, device=device))
def test__get_params(self):
input_size = self.INPUT_SIZE
target_size = self.TARGET_SIZE
scale_range = (0.5, 1.5)
transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
params = transform._get_params([make_image(input_size)])
assert "size" in params
size = params["size"]
assert isinstance(size, tuple) and len(size) == 2
height, width = size
r_min = min(target_size[1] / input_size[0], target_size[0] / input_size[1]) * scale_range[0]
r_max = min(target_size[1] / input_size[0], target_size[0] / input_size[1]) * scale_range[1]
assert int(input_size[0] * r_min) <= height <= int(input_size[0] * r_max)
assert int(input_size[1] * r_min) <= width <= int(input_size[1] * r_max)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment