"tests/python/common/data/test_data.py" did not exist on "c9c6171b9fb6b4903400ca03cdc3f6497498e516"
Unverified Commit ba64d65b authored by Thien Tran's avatar Thien Tran Committed by GitHub
Browse files

Fast rotation for right angles (#8295)


Co-authored-by: default avatarThien Tran <thien.tran@parallelchain.io>
parent c7bcfada
...@@ -1782,6 +1782,17 @@ class TestRotate: ...@@ -1782,6 +1782,17 @@ class TestRotate:
with pytest.raises(TypeError, match="Got inappropriate fill arg"): with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill") transforms.RandomAffine(degrees=0, fill="fill")
@pytest.mark.parametrize("size", [(11, 17), (16, 16)])
@pytest.mark.parametrize("angle", [0, 90, 180, 270])
@pytest.mark.parametrize("expand", [False, True])
def test_functional_image_fast_path_correctness(self, size, angle, expand):
image = make_image(size, dtype=torch.uint8, device="cpu")
actual = F.rotate(image, angle=angle, expand=expand)
expected = F.to_image(F.rotate(F.to_pil_image(image), angle=angle, expand=expand))
torch.testing.assert_close(actual, expected)
class TestContainerTransforms: class TestContainerTransforms:
class BuiltinTransform(transforms.Transform): class BuiltinTransform(transforms.Transform):
......
...@@ -997,6 +997,21 @@ def rotate_image( ...@@ -997,6 +997,21 @@ def rotate_image(
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: _FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> torch.Tensor: ) -> torch.Tensor:
angle = angle % 360 # shift angle to [0, 360) range
# fast path: transpose without affine transform
if center is None:
if angle == 0:
return image.clone()
if angle == 180:
return torch.rot90(image, k=2, dims=(-2, -1))
if expand or image.shape[-1] == image.shape[-2]:
if angle == 90:
return torch.rot90(image, k=1, dims=(-2, -1))
if angle == 270:
return torch.rot90(image, k=3, dims=(-2, -1))
interpolation = _check_interpolation(interpolation) interpolation = _check_interpolation(interpolation)
input_height, input_width = image.shape[-2:] input_height, input_width = image.shape[-2:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment