Unverified Commit 55477476 authored by vfdev, committed by GitHub

Improved functional tensor geom transforms to work on floatX dtype (#2661)

* Improved functional tensor geom transforms to work on floatX dtype
- Fixes #2600
- added tests
- refactored test_affine

* Removed float16/cpu case
parent 6662b30a
@@ -718,9 +718,9 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
     out_dtype = img.dtype
     need_cast = False
-    if img.dtype not in (torch.float32, torch.float64):
+    if out_dtype != grid.dtype:
         need_cast = True
-        img = img.to(torch.float32)
+        img = img.to(grid)
 
     img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
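Note (not part of the commit): the hunk above makes _apply_grid_transform cast the image to the grid's dtype instead of always forcing float32. A minimal sketch of the cast-and-restore pattern, assuming the restore step later in the function is a plain cast back to the saved out_dtype (that part is outside this hunk):

import torch
from torch.nn.functional import grid_sample

def _apply_grid_transform_sketch(img: torch.Tensor, grid: torch.Tensor, mode: str) -> torch.Tensor:
    # img is expected as a batched (N, C, H, W) tensor, as in the torchvision helper
    out_dtype = img.dtype
    need_cast = out_dtype != grid.dtype
    if need_cast:
        # .to(grid) matches both the dtype and the device of the grid tensor
        img = img.to(grid)
    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
    if need_cast:
        # the exact rounding/clamping used upstream for integer dtypes is outside
        # this hunk; a plain cast back to the original dtype is assumed here
        img = img.to(out_dtype)
    return img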
@@ -777,7 +777,8 @@ def affine(
     _assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes)
 
-    theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3)
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
     shape = img.shape
     # grid will be generated on the same device as theta and img
     grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
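Illustration (not from the commit): with this change a floating-point input keeps its dtype through the affine grid and the sampled output, instead of being silently downcast to float32. The call below assumes the tensor-capable torchvision.transforms.functional.affine signature of that release:

import torch
import torchvision.transforms.functional as F

img64 = torch.rand(3, 32, 32, dtype=torch.float64)
out = F.affine(img64, angle=30.0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0])
assert out.dtype == torch.float64  # theta and grid are built in float64, matching the input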
@@ -842,7 +843,8 @@ def rotate(
     _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
     w, h = img.shape[-1], img.shape[-2]
     ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
-    theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3)
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
     # grid will be generated on the same device as theta and img
     grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
     mode = _interpolation_modes[resample]
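The same dtype selection is applied in rotate. For non-floating inputs theta still falls back to float32, so _apply_grid_transform samples in float32 and is expected to cast back to the input dtype on return. A usage sketch, assuming the tensor-capable F.rotate of that release:

import torch
import torchvision.transforms.functional as F

img_u8 = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
out = F.rotate(img_u8, angle=45.0)
assert out.dtype == torch.uint8  # processed as float32 internally, restored afterwards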
@@ -850,7 +852,7 @@ def rotate(
     return _apply_grid_transform(img, grid, mode)
 
 
-def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.device):
+def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device):
     # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
     # src/libImaging/Geometry.c#L394
@@ -858,23 +860,22 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.device):
     # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #
     theta1 = torch.tensor([[
         [coeffs[0], coeffs[1], coeffs[2]],
         [coeffs[3], coeffs[4], coeffs[5]]
-    ]], dtype=torch.float, device=device)
+    ]], dtype=dtype, device=device)
     theta2 = torch.tensor([[
         [coeffs[6], coeffs[7], 1.0],
         [coeffs[6], coeffs[7], 1.0]
-    ]], dtype=torch.float, device=device)
+    ]], dtype=dtype, device=device)
 
     d = 0.5
-    base_grid = torch.empty(1, oh, ow, 3, dtype=torch.float, device=device)
+    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
     base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow))
     base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1))
     base_grid[..., 2].fill_(1)
 
-    rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=torch.float, device=device)
+    rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
     output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
     output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
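For readers following the math: the commented formula above is the per-pixel form of the mapping, which _perspective_grid evaluates for all pixels at once via the two batched matrix multiplies and then rescales into grid_sample's [-1, 1] coordinate range. A scalar illustration of the same formula (not part of the commit):

def perspective_point(coeffs, x, y):
    # direct per-pixel evaluation of the mapping documented in the comment above
    denom = coeffs[6] * x + coeffs[7] * y + 1.0
    x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / denom
    y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / denom
    return x_out, y_out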
@@ -915,7 +916,8 @@ def perspective(
     )
 
     ow, oh = img.shape[-1], img.shape[-2]
-    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, device=img.device)
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
     mode = _interpolation_modes[interpolation]
 
     return _apply_grid_transform(img, grid, mode)
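Per the commit message, the float16/CPU case was removed from the tests, presumably because grid_sample does not support half precision on CPU, so float16 inputs are expected to work through these transforms on CUDA only. A hedged illustration:

import torch
import torchvision.transforms.functional as F

if torch.cuda.is_available():
    img16 = torch.rand(3, 32, 32, dtype=torch.float16, device="cuda")
    out = F.rotate(img16, angle=30.0)
    assert out.dtype == torch.float16 and out.is_cuda  # half-precision grid and output, CUDA only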