"src/vscode:/vscode.git/clone" did not exist on "d87ce2cefc6612fa95cb6d58fa3d74080d18b312"
Unverified Commit a26534c9 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Fixed rotate with expand inconsistency (#5677)

* Fixed rotate with expand inconsistency between torch vs PIL on odd-sized images

* Update functional_tensor.py
parent 71907be1
...@@ -67,7 +67,7 @@ class TestRotate: ...@@ -67,7 +67,7 @@ class TestRotate:
IMG_W = 26 IMG_W = 26
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("height, width", [(26, IMG_W), (32, IMG_W)]) @pytest.mark.parametrize("height, width", [(7, 33), (26, IMG_W), (32, IMG_W)])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"center", "center",
[ [
...@@ -77,7 +77,7 @@ class TestRotate: ...@@ -77,7 +77,7 @@ class TestRotate:
], ],
) )
@pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize("angle", range(-180, 180, 17)) @pytest.mark.parametrize("angle", range(-180, 180, 34))
@pytest.mark.parametrize("expand", [True, False]) @pytest.mark.parametrize("expand", [True, False])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fill", "fill",
......
...@@ -650,6 +650,8 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] ...@@ -650,6 +650,8 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
# Points are shifted due to affine matrix torch convention about
# the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
pts = torch.tensor( pts = torch.tensor(
[ [
[-0.5 * w, -0.5 * h, 1.0], [-0.5 * w, -0.5 * h, 1.0],
...@@ -658,11 +660,15 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] ...@@ -658,11 +660,15 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
[0.5 * w, -0.5 * h, 1.0], [0.5 * w, -0.5 * h, 1.0],
] ]
) )
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) new_pts = torch.matmul(pts, theta.T)
min_vals, _ = new_pts.min(dim=0) min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0) max_vals, _ = new_pts.max(dim=0)
# shift points to [0, w] and [0, h] interval to match PIL results
min_vals += torch.tensor((w * 0.5, h * 0.5))
max_vals += torch.tensor((w * 0.5, h * 0.5))
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0 # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
tol = 1e-4 tol = 1e-4
cmax = torch.ceil((max_vals / tol).trunc_() * tol) cmax = torch.ceil((max_vals / tol).trunc_() * tol)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment