Unverified Commit 70edf96d authored by Vasilis Vryniotis, committed by GitHub

[prototype] Port elastic and minor cleanups (#6942)



* Port elastic and minor cleanups

* Update torchvision/prototype/transforms/functional/_geometry.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 0cc90808
@@ -388,7 +388,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
maximum = float_image.amax(dim=(-2, -1), keepdim=True)
eq_idxs = maximum == minimum
- inv_scale = maximum.sub_(minimum).div_(bound)
+ inv_scale = maximum.sub_(minimum).mul_(1.0 / bound)
minimum[eq_idxs] = 0.0
inv_scale[eq_idxs] = 1.0
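The only functional change in this hunk swaps an in-place division for a multiplication by the precomputed reciprocal of `bound`, which is typically cheaper than a tensor division. A minimal sketch of the surrounding autocontrast rescale, assuming a float image in `[0, 1]` so that `bound = 1.0` (the helper name below is illustrative, not the library's kernel):

```python
import torch

def autocontrast_sketch(image: torch.Tensor, bound: float = 1.0) -> torch.Tensor:
    # Per-channel min/max over the spatial dimensions.
    minimum = image.amin(dim=(-2, -1), keepdim=True)
    maximum = image.amax(dim=(-2, -1), keepdim=True)
    eq_idxs = maximum == minimum
    # Multiplying by the reciprocal of `bound` mirrors the .mul_(1.0 / bound)
    # rewrite above and avoids a tensor division.
    inv_scale = (maximum - minimum) * (1.0 / bound)
    # Constant channels are left untouched instead of dividing by zero.
    minimum = torch.where(eq_idxs, torch.zeros_like(minimum), minimum)
    inv_scale = torch.where(eq_idxs, torch.ones_like(inv_scale), inv_scale)
    return ((image - minimum) / inv_scale).clamp_(0.0, bound)
```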
@@ -390,7 +390,7 @@ def _affine_bounding_box_xyxy(
device=device,
)
new_points = torch.matmul(points, transposed_affine_matrix)
- tr, _ = torch.min(new_points, dim=0, keepdim=True)
+ tr = torch.amin(new_points, dim=0, keepdim=True)
# Translate bounding boxes
out_bboxes.sub_(tr.repeat((1, 2)))
# Estimate meta-data for image with inverted=True and with center=[0,0]
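`torch.min(..., dim=...)` returns a `(values, indices)` named tuple, whereas `torch.amin` returns only the values, so the index computation and the throwaway `_` disappear. A quick illustration (the tensor is made up):

```python
import torch

points = torch.tensor([[3.0, 1.0], [0.0, 2.0], [5.0, -1.0]])

# torch.min with a dim argument returns values *and* argmin indices ...
values, indices = torch.min(points, dim=0, keepdim=True)

# ... while torch.amin returns just the values, which is all the translation
# vector above needs.
tr = torch.amin(points, dim=0, keepdim=True)

assert torch.equal(values, tr)  # both are tensor([[0., -1.]])
```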
@@ -701,7 +701,7 @@ def pad_image_tensor(
# internally.
torch_padding = _parse_pad_padding(padding)
- if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+ if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
raise ValueError(
f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
f"but got `'{padding_mode}'`."
@@ -917,7 +917,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
# TODO: should we define them transposed?
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
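For reference, the two commented formulas describe how a point is mapped by the eight perspective coefficients; `theta1` packs the six numerator coefficients into a 2×3 matrix. A tiny, self-contained check of the formulas (the identity coefficients below are chosen arbitrarily for illustration):

```python
# Illustrative only: apply the perspective mapping from the comments above to a
# single (x, y) point, given the eight coefficients.
def perspective_point(coeffs, x, y):
    denom = coeffs[6] * x + coeffs[7] * y + 1.0
    x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / denom
    y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / denom
    return x_out, y_out

# With identity coefficients a point maps onto itself.
assert perspective_point([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], 3.0, 5.0) == (3.0, 5.0)
```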
@@ -925,9 +925,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
- x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
+ x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device)
base_grid[..., 0].copy_(x_grid)
- y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
+ y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
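The cleanup drops the redundant `ow * 1.0` / `oh * 1.0` factors: `d` is already a float, so the end point of the `linspace` is a float either way. A sketch of the resulting base grid for a tiny 4×3 output (sizes chosen arbitrarily), showing the half-integer pixel-center coordinates plus the channel of ones used for the projective division:

```python
import torch

ow, oh, d = 4, 3, 0.5
x_grid = torch.linspace(d, ow + d - 1.0, steps=ow)  # tensor([0.5, 1.5, 2.5, 3.5])
y_grid = torch.linspace(d, oh + d - 1.0, steps=oh)  # tensor([0.5, 1.5, 2.5])

base_grid = torch.empty(1, oh, ow, 3)
base_grid[..., 0].copy_(x_grid)                 # x varies along the width
base_grid[..., 1].copy_(y_grid.unsqueeze(-1))   # y varies along the height
base_grid[..., 2].fill_(1)                      # homogeneous coordinate
```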
@@ -1059,6 +1059,7 @@ def perspective_bounding_box(
(-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
]
+ # TODO: should we define them transposed?
theta1 = torch.tensor(
[[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
dtype=dtype,
@@ -1165,6 +1166,7 @@ def elastic_image_tensor(
return image
shape = image.shape
+ device = image.device
if image.ndim > 4:
image = image.reshape((-1,) + shape[-3:])
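The new `device = image.device` line is needed a few lines further down, where the displacement field is moved onto the image's device before being added to the identity grid. The reshape that follows is the usual squash/unsquash pattern: arbitrary leading batch dimensions are flattened so the kernel only has to deal with `(N, C, H, W)`. A small standalone illustration:

```python
import torch

image = torch.rand(2, 5, 3, 32, 32)            # e.g. extra leading batch dims
shape = image.shape
squashed = image.reshape((-1,) + shape[-3:])   # (10, 3, 32, 32)
# ... run the (N, C, H, W) kernel on `squashed` ...
restored = squashed.reshape(shape)             # back to (2, 5, 3, 32, 32)
assert restored.shape == image.shape
```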
@@ -1172,7 +1174,9 @@
else:
needs_unsquash = False
- output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
+ image_height, image_width = shape[-2:]
+ grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
+ output = _FT._apply_grid_transform(image, grid, interpolation.value, fill)
if needs_unsquash:
output = output.reshape(shape)
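This is the actual port: instead of delegating the whole operation to `_FT.elastic_transform`, the prototype kernel now builds the sampling grid itself, adding the displacement field to an identity grid and resampling through the shared grid-transform helper. A rough sketch of what those two new lines amount to, using only public PyTorch ops; the normalization of the identity grid and the `fill` handling done by `_apply_grid_transform` are simplified assumptions here, not the exact implementation:

```python
import torch
import torch.nn.functional as F

def elastic_sketch(image: torch.Tensor, displacement: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
    # image: (N, C, H, W) float tensor; displacement: (1, H, W, 2) offsets in
    # normalized [-1, 1] grid coordinates.
    n, _, h, w = image.shape
    ys = torch.linspace(-1.0, 1.0, h, device=image.device)
    xs = torch.linspace(-1.0, 1.0, w, device=image.device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    identity = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)  # (1, H, W, 2), (x, y) order
    grid = identity + displacement.to(image.device)                # plays the role of .add_(displacement.to(device))
    return F.grid_sample(image, grid.expand(n, -1, -1, -1), mode=mode, align_corners=False)
```

The sketch skips the fill value and dtype handling that the internal helper performs.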
@@ -1505,8 +1509,7 @@ def five_crop_image_tensor(
image_height, image_width = image.shape[-2:]
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
raise ValueError(msg.format(size, (image_height, image_width)))
raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width)
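Besides the f-string cleanup of the error message, this hunk shows the first two of the five crops. For orientation, the four corner crops are anchored at the `(top, left)` offsets below; the helper is only illustrative (the real kernel also takes a center crop):

```python
def five_crop_offsets(image_height, image_width, crop_height, crop_width):
    # (top, left) anchor of each corner crop for a crop_height x crop_width window.
    return {
        "tl": (0, 0),
        "tr": (0, image_width - crop_width),
        "bl": (image_height - crop_height, 0),
        "br": (image_height - crop_height, image_width - crop_width),
    }
```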
@@ -1525,8 +1528,7 @@ def five_crop_image_pil(
image_height, image_width = get_spatial_size_image_pil(image)
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
raise ValueError(msg.format(size, (image_height, image_width)))
raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)