Unverified Commit 55477476 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Improved functional tensor geom transforms to work on floatX dtype (#2661)

* Improved functional tensor geom transforms to work on floatX dtype
- Fixes #2600
- added tests
- refactored test_affine

* Removed float16/cpu case
parent 6662b30a
This diff is collapsed.
......@@ -718,9 +718,9 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
out_dtype = img.dtype
need_cast = False
if img.dtype not in (torch.float32, torch.float64):
if out_dtype != grid.dtype:
need_cast = True
img = img.to(torch.float32)
img = img.to(grid)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
......@@ -777,7 +777,8 @@ def affine(
_assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes)
theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
shape = img.shape
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
......@@ -842,7 +843,8 @@ def rotate(
_assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
mode = _interpolation_modes[resample]
......@@ -850,7 +852,7 @@ def rotate(
return _apply_grid_transform(img, grid, mode)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.device):
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device):
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
......@@ -858,23 +860,22 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.devic
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor([[
[coeffs[0], coeffs[1], coeffs[2]],
[coeffs[3], coeffs[4], coeffs[5]]
]], dtype=torch.float, device=device)
]], dtype=dtype, device=device)
theta2 = torch.tensor([[
[coeffs[6], coeffs[7], 1.0],
[coeffs[6], coeffs[7], 1.0]
]], dtype=torch.float, device=device)
]], dtype=dtype, device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=torch.float, device=device)
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow))
base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1))
base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=torch.float, device=device)
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
......@@ -915,7 +916,8 @@ def perspective(
)
ow, oh = img.shape[-1], img.shape[-2]
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, device=img.device)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
mode = _interpolation_modes[interpolation]
return _apply_grid_transform(img, grid, mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment