Unverified Commit af7c6c04 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Fixed issues with dtype in geom functional transforms v2 (#7211)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent ea37cd38
......@@ -304,7 +304,7 @@ def make_image_loaders(
"RGBA",
),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.float32, torch.uint8),
dtypes=(torch.float32, torch.float64, torch.uint8),
constant_alpha=True,
):
for params in combinations_grid(size=sizes, color_space=color_spaces, extra_dims=extra_dims, dtype=dtypes):
......@@ -426,7 +426,7 @@ def make_bounding_box_loaders(
extra_dims=DEFAULT_EXTRA_DIMS,
formats=tuple(datapoints.BoundingBoxFormat),
spatial_size="random",
dtypes=(torch.float32, torch.int64),
dtypes=(torch.float32, torch.float64, torch.int64),
):
for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
yield make_bounding_box_loader(**params, spatial_size=spatial_size)
......@@ -618,7 +618,7 @@ def make_video_loaders(
),
num_frames=(1, 0, "random"),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.uint8,),
dtypes=(torch.uint8, torch.float32, torch.float64),
):
for params in combinations_grid(
size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes
......
......@@ -109,6 +109,12 @@ def float32_vs_uint8_pixel_difference(atol=1, mae=False):
}
def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
return {
(("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
}
def pil_reference_wrapper(pil_kernel):
@functools.wraps(pil_kernel)
def wrapper(input_tensor, *other_args, **kwargs):
......@@ -541,8 +547,10 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
def transform(bbox, affine_matrix_, format_):
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
in_dtype = bbox.dtype
if not torch.is_floating_point(bbox):
bbox = bbox.float()
bbox_xyxy = F.convert_format_bounding_box(
bbox.float(), old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
)
points = np.array(
[
......@@ -560,6 +568,7 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
np.max(transformed_points[:, 0]).item(),
np.max(transformed_points[:, 1]).item(),
],
dtype=bbox_xyxy.dtype,
)
out_bbox = F.convert_format_bounding_box(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
......@@ -844,6 +853,10 @@ KERNEL_INFOS.extend(
KernelInfo(
F.rotate_bounding_box,
sample_inputs_fn=sample_inputs_rotate_bounding_box,
closeness_kwargs={
**scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6),
**scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
},
),
KernelInfo(
F.rotate_mask,
......@@ -1275,6 +1288,8 @@ KERNEL_INFOS.extend(
**pil_reference_pixel_difference(2, mae=True),
**cuda_vs_cpu_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
**scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
},
),
KernelInfo(
......@@ -1294,7 +1309,11 @@ KERNEL_INFOS.extend(
KernelInfo(
F.perspective_video,
sample_inputs_fn=sample_inputs_perspective_video,
closeness_kwargs=cuda_vs_cpu_pixel_difference(),
closeness_kwargs={
**cuda_vs_cpu_pixel_difference(),
**scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
},
),
]
)
......
......@@ -138,17 +138,28 @@ CONSISTENCY_CONFIGS = [
NotScriptableArgsKwargs(5, padding_mode="symmetric"),
],
),
ConsistencyConfig(
prototype_transforms.LinearTransformation,
legacy_transforms.LinearTransformation,
[
ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX, LINEAR_TRANSFORMATION_MEAN),
],
# Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]),
supports_pil=False,
),
*[
ConsistencyConfig(
prototype_transforms.LinearTransformation,
legacy_transforms.LinearTransformation,
[
ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
],
# Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
),
supports_pil=False,
)
for matrix_dtype, image_dtype in [
(torch.float32, torch.float32),
(torch.float64, torch.float64),
(torch.float32, torch.uint8),
(torch.float64, torch.float32),
(torch.float32, torch.float64),
]
],
ConsistencyConfig(
prototype_transforms.Grayscale,
legacy_transforms.Grayscale,
......
......@@ -142,7 +142,7 @@ class TestKernels:
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*other_args, **kwargs),
msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs),
)
def _unbatch(self, batch, *, data_dims):
......
......@@ -64,6 +64,11 @@ class LinearTransformation(Transform):
f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
)
if transformation_matrix.dtype != mean_vector.dtype:
raise ValueError(
f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
)
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
......@@ -93,7 +98,9 @@ class LinearTransformation(Transform):
)
flat_tensor = inpt.reshape(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
return transformed_tensor.reshape(shape)
......
......@@ -404,9 +404,13 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
def _apply_grid_transform(
float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT
img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT
) -> torch.Tensor:
# We are using context knowledge that grid should have float dtype
fp = img.dtype == grid.dtype
float_img = img if fp else img.to(grid.dtype)
shape = float_img.shape
if shape[0] > 1:
# Apply same grid to a batch of images
......@@ -433,7 +437,9 @@ def _apply_grid_transform(
# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
return float_img
img = float_img.round_().to(img.dtype) if not fp else float_img
return img
def _assert_grid_transform_inputs(
......@@ -511,7 +517,6 @@ def affine_image_tensor(
shape = image.shape
ndim = image.ndim
fp = torch.is_floating_point(image)
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
......@@ -535,13 +540,10 @@ def affine_image_tensor(
_assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
dtype = image.dtype if fp else torch.float32
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
if not fp:
output = output.round_().to(image.dtype)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
if needs_unsquash:
output = output.reshape(shape)
......@@ -612,7 +614,7 @@ def _affine_bounding_box_xyxy(
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
# 2) Now let's transform the points using affine matrix
transformed_points = torch.matmul(points, transposed_affine_matrix)
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
......@@ -797,19 +799,15 @@ def rotate_image_tensor(
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
if image.numel() > 0:
fp = torch.is_floating_point(image)
image = image.reshape(-1, num_channels, height, width)
_assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
dtype = image.dtype if fp else torch.float32
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
if not fp:
output = output.round_().to(image.dtype)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
new_height, new_width = output.shape[-2:]
else:
......@@ -1237,9 +1235,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device)
x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
......@@ -1283,7 +1281,6 @@ def perspective_image_tensor(
shape = image.shape
ndim = image.ndim
fp = torch.is_floating_point(image)
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
......@@ -1304,12 +1301,9 @@ def perspective_image_tensor(
)
oh, ow = shape[-2:]
dtype = image.dtype if fp else torch.float32
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
if not fp:
output = output.round_().to(image.dtype)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
if needs_unsquash:
output = output.reshape(shape)
......@@ -1494,8 +1488,12 @@ def elastic_image_tensor(
shape = image.shape
ndim = image.ndim
device = image.device
fp = torch.is_floating_point(image)
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
# We are aware that if input image dtype is uint8 and displacement is float64 then
# displacement will be casted to float32 and all computations will be done with float32
# We can fix this later if needed
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
......@@ -1506,12 +1504,12 @@ def elastic_image_tensor(
else:
needs_unsquash = False
image_height, image_width = shape[-2:]
grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill)
if displacement.dtype != dtype or displacement.device != device:
displacement = displacement.to(dtype=dtype, device=device)
if not fp:
output = output.round_().to(image.dtype)
image_height, image_width = shape[-2:]
grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
if needs_unsquash:
output = output.reshape(shape)
......@@ -1531,13 +1529,13 @@ def elastic_image_pil(
return to_pil_image(output, mode=image.mode)
def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch.Tensor:
def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sy, sx = size
base_grid = torch.empty(1, sy, sx, 2, device=device)
x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device)
base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1)
y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
return base_grid
......@@ -1552,7 +1550,11 @@ def elastic_bounding_box(
return bounding_box
# TODO: add in docstring about approximation we are doing for grid inversion
displacement = displacement.to(bounding_box.device)
device = bounding_box.device
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
if displacement.dtype != dtype or displacement.device != device:
displacement = displacement.to(dtype=dtype, device=device)
original_shape = bounding_box.shape
bounding_box = (
......@@ -1563,7 +1565,7 @@ def elastic_bounding_box(
# Or add spatial_size arg and check displacement shape
spatial_size = displacement.shape[-3], displacement.shape[-2]
id_grid = _create_identity_grid(spatial_size, bounding_box.device)
id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
# This is not an exact inverse of the grid
inv_grid = id_grid.sub_(displacement)
......
......@@ -1078,6 +1078,11 @@ class LinearTransformation(torch.nn.Module):
f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
)
if transformation_matrix.dtype != mean_vector.dtype:
raise ValueError(
f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
)
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
......@@ -1105,9 +1110,10 @@ class LinearTransformation(torch.nn.Module):
)
flat_tensor = tensor.view(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
tensor = transformed_tensor.view(shape)
return tensor
transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
return transformed_tensor.view(shape)
def __repr__(self) -> str:
s = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment