Unverified commit af7c6c04 authored by vfdev and committed by GitHub

Fixed issues with dtype in geom functional transforms v2 (#7211)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent ea37cd38
@@ -304,7 +304,7 @@ def make_image_loaders(
         "RGBA",
     ),
     extra_dims=DEFAULT_EXTRA_DIMS,
-    dtypes=(torch.float32, torch.uint8),
+    dtypes=(torch.float32, torch.float64, torch.uint8),
     constant_alpha=True,
 ):
     for params in combinations_grid(size=sizes, color_space=color_spaces, extra_dims=extra_dims, dtype=dtypes):
@@ -426,7 +426,7 @@ def make_bounding_box_loaders(
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(datapoints.BoundingBoxFormat),
     spatial_size="random",
-    dtypes=(torch.float32, torch.int64),
+    dtypes=(torch.float32, torch.float64, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
         yield make_bounding_box_loader(**params, spatial_size=spatial_size)
@@ -618,7 +618,7 @@ def make_video_loaders(
     ),
     num_frames=(1, 0, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
-    dtypes=(torch.uint8,),
+    dtypes=(torch.uint8, torch.float32, torch.float64),
 ):
     for params in combinations_grid(
         size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes
...
@@ -109,6 +109,12 @@ def float32_vs_uint8_pixel_difference(atol=1, mae=False):
     }
 
 
+def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
+    return {
+        (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
+    }
+
+
 def pil_reference_wrapper(pil_kernel):
     @functools.wraps(pil_kernel)
     def wrapper(input_tensor, *other_args, **kwargs):
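
Side note: the new helper emits one tolerance entry per (test id, dtype, device) combination, and the KernelInfos further down simply merge those dictionaries into their closeness_kwargs. A minimal standalone sketch of how the merged dictionary looks, using only the helper defined in this hunk:

import torch

def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
    # one tolerance entry per (test id, dtype, device) combination
    return {
        (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
    }

closeness_kwargs = {
    **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6),
    **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
}
# two entries: float64 on CPU with tighter tolerances, float64 on CUDA with looser ones;
# the kernel tests later look these up by (test id, dtype, device)
assert len(closeness_kwargs) == 2
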
@@ -541,8 +547,10 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
     def transform(bbox, affine_matrix_, format_):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
+        if not torch.is_floating_point(bbox):
+            bbox = bbox.float()
         bbox_xyxy = F.convert_format_bounding_box(
-            bbox.float(), old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
         )
         points = np.array(
             [
@@ -560,6 +568,7 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
                 np.max(transformed_points[:, 0]).item(),
                 np.max(transformed_points[:, 1]).item(),
             ],
+            dtype=bbox_xyxy.dtype,
         )
         out_bbox = F.convert_format_bounding_box(
             out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
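
Side note: a short standalone illustration of the precision loss the comment above guards against. Converting an integer CXCYWH box whose W or H equals 1 to XYXY involves a half-pixel offset that integer arithmetic drops, which is why the helper now casts to float first and only restores the dtype at the end:

import torch

# hypothetical integer box in CXCYWH format with width and height of 1
cxcywh = torch.tensor([10, 10, 1, 1], dtype=torch.int64)
cx, _, w, _ = cxcywh.unbind()

x1_int = cx - w // 2                    # integer math: 10, the 0.5 offset is lost
x1_float = cx.float() - w.float() / 2   # float math: 9.5
print(x1_int.item(), x1_float.item())
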
@@ -844,6 +853,10 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.rotate_bounding_box,
             sample_inputs_fn=sample_inputs_rotate_bounding_box,
+            closeness_kwargs={
+                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6),
+                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+            },
         ),
         KernelInfo(
             F.rotate_mask,
@@ -1275,6 +1288,8 @@ KERNEL_INFOS.extend(
                 **pil_reference_pixel_difference(2, mae=True),
                 **cuda_vs_cpu_pixel_difference(),
                 **float32_vs_uint8_pixel_difference(),
+                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
             },
         ),
         KernelInfo(
@@ -1294,7 +1309,11 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.perspective_video,
             sample_inputs_fn=sample_inputs_perspective_video,
-            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
+            closeness_kwargs={
+                **cuda_vs_cpu_pixel_difference(),
+                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+            },
         ),
     ]
 )
...
@@ -138,17 +138,28 @@ CONSISTENCY_CONFIGS = [
             NotScriptableArgsKwargs(5, padding_mode="symmetric"),
         ],
     ),
-    ConsistencyConfig(
-        prototype_transforms.LinearTransformation,
-        legacy_transforms.LinearTransformation,
-        [
-            ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX, LINEAR_TRANSFORMATION_MEAN),
-        ],
-        # Make sure that the product of the height, width and number of channels matches the number of elements in
-        # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]),
-        supports_pil=False,
-    ),
+    *[
+        ConsistencyConfig(
+            prototype_transforms.LinearTransformation,
+            legacy_transforms.LinearTransformation,
+            [
+                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
+            ],
+            # Make sure that the product of the height, width and number of channels matches the number of elements in
+            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
+            make_images_kwargs=dict(
+                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
+            ),
+            supports_pil=False,
+        )
+        for matrix_dtype, image_dtype in [
+            (torch.float32, torch.float32),
+            (torch.float64, torch.float64),
+            (torch.float32, torch.uint8),
+            (torch.float64, torch.float32),
+            (torch.float32, torch.float64),
+        ]
+    ],
     ConsistencyConfig(
         prototype_transforms.Grayscale,
         legacy_transforms.Grayscale,
...
@@ -142,7 +142,7 @@ class TestKernels:
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
-           msg=parametrized_error_message(*other_args, **kwargs),
+           msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs),
        )

    def _unbatch(self, batch, *, data_dims):
...
@@ -64,6 +64,11 @@ class LinearTransformation(Transform):
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

+       if transformation_matrix.dtype != mean_vector.dtype:
+           raise ValueError(
+               f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
+           )
+
        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector
@@ -93,7 +98,9 @@ class LinearTransformation(Transform):
            )

        flat_tensor = inpt.reshape(-1, n) - self.mean_vector
-       transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
+
+       transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
+       transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
        return transformed_tensor.reshape(shape)
...
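
Side note: taken together, the two changes in this file mean that mismatched matrix/mean dtypes fail fast at construction time, while a matrix whose dtype differs from the input is cast on the fly inside the forward path. A minimal sketch of that behavior with standalone tensors rather than the transform itself:

import torch

transformation_matrix = torch.rand(36, 36, dtype=torch.float64)
mean_vector = torch.rand(36, dtype=torch.float64)

# constructor-time check added above
if transformation_matrix.dtype != mean_vector.dtype:
    raise ValueError("Input tensors should have the same dtype.")

inpt = torch.rand(2, 3, 2, 6)                       # float32 image batch, 3 * 2 * 6 == 36 elements per sample
flat_tensor = inpt.reshape(-1, 36) - mean_vector    # promoted to float64 by the subtraction
# the matrix is cast to the flat tensor's dtype before the matmul, so mixed dtypes no longer error
transformed_tensor = torch.mm(flat_tensor, transformation_matrix.to(flat_tensor.dtype))
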
@@ -404,9 +404,13 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
 def _apply_grid_transform(
-    float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT
+    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT
 ) -> torch.Tensor:
+
+    # We are using context knowledge that grid should have float dtype
+    fp = img.dtype == grid.dtype
+    float_img = img if fp else img.to(grid.dtype)
+
     shape = float_img.shape
     if shape[0] > 1:
         # Apply same grid to a batch of images
@@ -433,7 +437,9 @@ def _apply_grid_transform(
         # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
         float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
 
-    return float_img
+    img = float_img.round_().to(img.dtype) if not fp else float_img
+
+    return img
 
 
 def _assert_grid_transform_inputs(
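
Side note: since the float promotion and the rounding back now live inside _apply_grid_transform, every caller (affine, rotate, perspective, elastic) gets the same behavior: integer images are sampled in the grid's floating dtype and rounded back, floating-point images keep their dtype. A minimal sketch of that round trip with the actual grid sampling elided:

import torch

img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
grid_dtype = torch.float64                      # grids are always floating point here

fp = img.dtype == grid_dtype
float_img = img if fp else img.to(grid_dtype)
# ... grid_sample / fill handling would run on float_img ...
out = float_img if fp else float_img.round_().to(img.dtype)
assert out.dtype == img.dtype                   # uint8 in, uint8 out; a float64 image would pass through
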
@@ -511,7 +517,6 @@ def affine_image_tensor(
     shape = image.shape
     ndim = image.ndim
-    fp = torch.is_floating_point(image)
 
     if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
@@ -535,13 +540,10 @@ def affine_image_tensor(
     _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
 
-    dtype = image.dtype if fp else torch.float32
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
     grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
-    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
-
-    if not fp:
-        output = output.round_().to(image.dtype)
+    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
 
     if needs_unsquash:
         output = output.reshape(shape)
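
Side note: with the rounding removed here, the caller only keeps the grid dtype selection: integer images get a float32 theta/grid, floating-point images keep their own dtype, so float64 stays float64 end to end. A small illustration of that selection (theta construction only):

import torch

for image_dtype in (torch.uint8, torch.float32, torch.float64):
    image = torch.zeros(1, 3, 8, 8, dtype=image_dtype)
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=dtype, device=image.device).reshape(1, 2, 3)
    print(image_dtype, "->", theta.dtype)   # uint8 -> float32, float64 -> float64
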
@@ -612,7 +614,7 @@ def _affine_bounding_box_xyxy(
     # Single point structure is similar to
     # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
     points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
-    points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
+    points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
     # 2) Now let's transform the points using affine matrix
     transformed_points = torch.matmul(points, transposed_affine_matrix)
     # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
@@ -797,19 +799,15 @@ def rotate_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
 
     if image.numel() > 0:
-        fp = torch.is_floating_point(image)
         image = image.reshape(-1, num_channels, height, width)
 
         _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
 
         ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
-        dtype = image.dtype if fp else torch.float32
+        dtype = image.dtype if torch.is_floating_point(image) else torch.float32
         theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
         grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
-        output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
-
-        if not fp:
-            output = output.round_().to(image.dtype)
+        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
 
         new_height, new_width = output.shape[-2:]
     else:
@@ -1237,9 +1235,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     d = 0.5
     base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
-    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device)
+    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
     base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
+    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
     base_grid[..., 1].copy_(y_grid)
     base_grid[..., 2].fill_(1)
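
Side note: the base grid already honored the requested dtype, but the linspace rows were built in PyTorch's default dtype and only cast during copy_. Passing dtype through keeps a float64 grid genuinely float64 instead of filling it with values computed in float32. A tiny check:

import torch

ow, d = 8, 0.5
x_default = torch.linspace(d, ow + d - 1.0, steps=ow)                     # float32 under the default dtype
x_typed = torch.linspace(d, ow + d - 1.0, steps=ow, dtype=torch.float64)  # matches the grid dtype exactly
print(x_default.dtype, x_typed.dtype)
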
@@ -1283,7 +1281,6 @@ def perspective_image_tensor(
     shape = image.shape
     ndim = image.ndim
-    fp = torch.is_floating_point(image)
 
     if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
@@ -1304,12 +1301,9 @@ def perspective_image_tensor(
     )
 
     oh, ow = shape[-2:]
-    dtype = image.dtype if fp else torch.float32
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
-    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
-
-    if not fp:
-        output = output.round_().to(image.dtype)
+    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
 
     if needs_unsquash:
         output = output.reshape(shape)
@@ -1494,8 +1488,12 @@ def elastic_image_tensor(
     shape = image.shape
     ndim = image.ndim
     device = image.device
-    fp = torch.is_floating_point(image)
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+
+    # We are aware that if input image dtype is uint8 and displacement is float64 then
+    # displacement will be casted to float32 and all computations will be done with float32
+    # We can fix this later if needed
 
     if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
@@ -1506,12 +1504,12 @@ def elastic_image_tensor(
     else:
         needs_unsquash = False
 
-    image_height, image_width = shape[-2:]
-    grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
-    output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill)
-
-    if not fp:
-        output = output.round_().to(image.dtype)
+    if displacement.dtype != dtype or displacement.device != device:
+        displacement = displacement.to(dtype=dtype, device=device)
+
+    image_height, image_width = shape[-2:]
+    grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
+    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
 
     if needs_unsquash:
         output = output.reshape(shape)
@@ -1531,13 +1529,13 @@ def elastic_image_pil(
     return to_pil_image(output, mode=image.mode)
 
 
-def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch.Tensor:
+def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
     sy, sx = size
-    base_grid = torch.empty(1, sy, sx, 2, device=device)
-    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device)
+    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
+    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
     base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1)
+    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
     base_grid[..., 1].copy_(y_grid)
 
     return base_grid
@@ -1552,7 +1550,11 @@ def elastic_bounding_box(
         return bounding_box
 
     # TODO: add in docstring about approximation we are doing for grid inversion
-    displacement = displacement.to(bounding_box.device)
+    device = bounding_box.device
+    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
+
+    if displacement.dtype != dtype or displacement.device != device:
+        displacement = displacement.to(dtype=dtype, device=device)
 
     original_shape = bounding_box.shape
     bounding_box = (
@@ -1563,7 +1565,7 @@ def elastic_bounding_box(
     # Or add spatial_size arg and check displacement shape
     spatial_size = displacement.shape[-3], displacement.shape[-2]
 
-    id_grid = _create_identity_grid(spatial_size, bounding_box.device)
+    id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
     inv_grid = id_grid.sub_(displacement)
...
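
Side note: the elastic kernels now pick a compute dtype from the input (its own dtype when floating point, float32 otherwise) and move the displacement grid to it, which is exactly the uint8-plus-float64 case the in-code comment above acknowledges. A standalone sketch of that selection logic:

import torch

image = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
displacement = torch.randn(1, 16, 16, 2, dtype=torch.float64)

dtype = image.dtype if torch.is_floating_point(image) else torch.float32
if displacement.dtype != dtype or displacement.device != image.device:
    displacement = displacement.to(dtype=dtype, device=image.device)

# uint8 image + float64 displacement -> the whole computation runs in float32
assert displacement.dtype == torch.float32
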
@@ -1078,6 +1078,11 @@ class LinearTransformation(torch.nn.Module):
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

+       if transformation_matrix.dtype != mean_vector.dtype:
+           raise ValueError(
+               f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
+           )
+
        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector
@@ -1105,9 +1110,10 @@ class LinearTransformation(torch.nn.Module):
            )

        flat_tensor = tensor.view(-1, n) - self.mean_vector
-       transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
-       tensor = transformed_tensor.view(shape)
-       return tensor
+
+       transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
+       transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
+       return transformed_tensor.view(shape)

    def __repr__(self) -> str:
        s = (
...