Unverified commit f467349c, authored by Philip Meier and committed by GitHub

replace .view with .reshape (#6777)

parent e2fa1f9d
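For context on the swap (not part of the original commit message): Tensor.view never copies, so it only succeeds when the requested shape is compatible with the input's existing strides, while Tensor.reshape returns a view when it can and quietly falls back to a copy otherwise. Several tensors touched below come from slicing (e.g. m[:, 1]) or expand(), which can leave them non-contiguous, and reshape is robust to that. A minimal standalone sketch of the difference, not taken from the patched modules:

import torch

t = torch.arange(6).reshape(2, 3)
nc = t.t()  # transpose shares storage but is non-contiguous

print(nc.is_contiguous())  # False
print(nc.reshape(-1))      # fine: reshape copies when a zero-copy view is impossible
try:
    nc.view(-1)            # RuntimeError: view size is not compatible with input tensor's size and stride
except RuntimeError as err:
    print("view failed:", err)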
@@ -484,7 +484,7 @@ class AugMix(_AutoAugmentBase):
         orig_dims = list(image_or_video.shape)
         expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4
-        batch = image_or_video.view([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
+        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
         batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

         # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
@@ -497,9 +497,9 @@ class AugMix(_AutoAugmentBase):
         # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
         combined_weights = self._sample_dirichlet(
             torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
-        ) * m[:, 1].view([batch_dims[0], -1])
+        ) * m[:, 1].reshape([batch_dims[0], -1])

-        mix = m[:, 0].view(batch_dims) * batch
+        mix = m[:, 0].reshape(batch_dims) * batch
         for i in range(self.mixture_width):
             aug = batch
             depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
@@ -517,8 +517,8 @@ class AugMix(_AutoAugmentBase):
                 aug = self._apply_image_or_video_transform(
                     aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
                 )
-            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
-        mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)
+            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
+        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

         if isinstance(orig_image_or_video, (features.Image, features.Video)):
             mix = orig_image_or_video.wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
@@ -88,9 +88,9 @@ class LinearTransformation(Transform):
                 f"Got {inpt.device} vs {self.mean_vector.device}"
             )

-        flat_tensor = inpt.view(-1, n) - self.mean_vector
+        flat_tensor = inpt.reshape(-1, n) - self.mean_vector
         transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
-        return transformed_tensor.view(shape)
+        return transformed_tensor.reshape(shape)


 class Normalize(Transform):
@@ -69,7 +69,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     shape = image.shape

     if image.ndim > 4:
-        image = image.view(-1, num_channels, height, width)
+        image = image.reshape(-1, num_channels, height, width)
         needs_unsquash = True
     else:
         needs_unsquash = False
@@ -77,7 +77,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)

     if needs_unsquash:
-        output = output.view(shape)
+        output = output.reshape(shape)

     return output
@@ -213,7 +213,7 @@ def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
     zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
     lut = torch.cat([zeros, lut[:, :-1]], dim=1)

-    return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).view_as(img))
+    return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).reshape_as(img))


 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
@@ -38,13 +38,13 @@ def horizontal_flip_bounding_box(
     bounding_box = convert_format_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
-    ).view(-1, 4)
+    ).reshape(-1, 4)

     bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]

     return convert_format_bounding_box(
         bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-    ).view(shape)
+    ).reshape(shape)


 def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
@@ -75,13 +75,13 @@ def vertical_flip_bounding_box(
     bounding_box = convert_format_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
-    ).view(-1, 4)
+    ).reshape(-1, 4)

     bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]

     return convert_format_bounding_box(
         bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-    ).view(shape)
+    ).reshape(shape)


 def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
@@ -123,7 +123,7 @@ def resize_image_tensor(
     extra_dims = image.shape[:-3]

     if image.numel() > 0:
-        image = image.view(-1, num_channels, old_height, old_width)
+        image = image.reshape(-1, num_channels, old_height, old_width)

         image = _FT.resize(
             image,
@@ -132,7 +132,7 @@ def resize_image_tensor(
             antialias=antialias,
         )

-    return image.view(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(extra_dims + (num_channels, new_height, new_width))


 @torch.jit.unused
@@ -168,7 +168,7 @@ def resize_bounding_box(
     new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)
     ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
     return (
-        bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape),
+        bounding_box.reshape(-1, 2, 2).mul(ratios).to(bounding_box.dtype).reshape(bounding_box.shape),
         (new_height, new_width),
     )
@@ -270,7 +270,7 @@ def affine_image_tensor(
     num_channels, height, width = image.shape[-3:]
     extra_dims = image.shape[:-3]
-    image = image.view(-1, num_channels, height, width)
+    image = image.reshape(-1, num_channels, height, width)

     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -283,7 +283,7 @@ def affine_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

     output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
-    return output.view(extra_dims + (num_channels, height, width))
+    return output.reshape(extra_dims + (num_channels, height, width))


 @torch.jit.unused
@@ -338,20 +338,20 @@ def _affine_bounding_box_xyxy(
             dtype=dtype,
             device=device,
         )
-        .view(2, 3)
+        .reshape(2, 3)
         .T
     )
     # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
     # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
     # Single point structure is similar to
     # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
-    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
+    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
     points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
     # 2) Now let's transform the points using affine matrix
     transformed_points = torch.matmul(points, transposed_affine_matrix)
     # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
     # and compute bounding box from 4 transformed points:
-    transformed_points = transformed_points.view(-1, 4, 2)
+    transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, _ = torch.min(transformed_points, dim=1)
     out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
@@ -396,7 +396,7 @@ def affine_bounding_box(
     original_shape = bounding_box.shape

     bounding_box = convert_format_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
-    ).view(-1, 4)
+    ).reshape(-1, 4)

     out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)
@@ -404,7 +404,7 @@ def affine_bounding_box(
     return convert_format_bounding_box(
         out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-    ).view(original_shape)
+    ).reshape(original_shape)


 def affine_mask(
@@ -539,7 +539,7 @@ def rotate_image_tensor(
     if image.numel() > 0:
         image = _FT.rotate(
-            image.view(-1, num_channels, height, width),
+            image.reshape(-1, num_channels, height, width),
             matrix,
             interpolation=interpolation.value,
             expand=expand,
@@ -549,7 +549,7 @@ def rotate_image_tensor(
     else:
         new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)

-    return image.view(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(extra_dims + (num_channels, new_height, new_width))


 @torch.jit.unused
@@ -585,7 +585,7 @@ def rotate_bounding_box(
     original_shape = bounding_box.shape

     bounding_box = convert_format_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
-    ).view(-1, 4)
+    ).reshape(-1, 4)

     out_bboxes, spatial_size = _affine_bounding_box_xyxy(
         bounding_box,
@@ -601,7 +601,7 @@ def rotate_bounding_box(
     return (
         convert_format_bounding_box(
             out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-        ).view(original_shape),
+        ).reshape(original_shape),
         spatial_size,
     )
@@ -691,7 +691,7 @@ def _pad_with_scalar_fill(
     if image.numel() > 0:
         image = _FT.pad(
-            img=image.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
+            img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
         )
         new_height, new_width = image.shape[-2:]
     else:
@@ -699,7 +699,7 @@ def _pad_with_scalar_fill(
         new_height = height + top + bottom
         new_width = width + left + right

-    return image.view(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(extra_dims + (num_channels, new_height, new_width))


 # TODO: This should be removed once pytorch pad supports non-scalar padding values
@@ -714,7 +714,7 @@ def _pad_with_vector_fill(
     output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
     left, right, top, bottom = _parse_pad_padding(padding)
-    fill = torch.tensor(fill, dtype=image.dtype, device=image.device).view(-1, 1, 1)
+    fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

     if top > 0:
         output[..., :top, :] = fill
@@ -863,7 +863,7 @@ def perspective_image_tensor(
     shape = image.shape

     if image.ndim > 4:
-        image = image.view((-1,) + shape[-3:])
+        image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
     else:
         needs_unsquash = False
@@ -871,7 +871,7 @@ def perspective_image_tensor(
     output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)

     if needs_unsquash:
-        output = output.view(shape)
+        output = output.reshape(shape)

     return output
@@ -898,7 +898,7 @@ def perspective_bounding_box(
     original_shape = bounding_box.shape

     bounding_box = convert_format_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
-    ).view(-1, 4)
+    ).reshape(-1, 4)

     dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
     device = bounding_box.device
@@ -947,7 +947,7 @@ def perspective_bounding_box(
     # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
     # Single point structure is similar to
     # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
-    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
+    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
     points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
     # 2) Now let's transform the points using perspective matrices
     # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
@@ -959,7 +959,7 @@ def perspective_bounding_box(
     # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
     # and compute bounding box from 4 transformed points:
-    transformed_points = transformed_points.view(-1, 4, 2)
+    transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, _ = torch.min(transformed_points, dim=1)
     out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
@@ -968,7 +968,7 @@ def perspective_bounding_box(
     return convert_format_bounding_box(
         out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-    ).view(original_shape)
+    ).reshape(original_shape)


 def perspective_mask(
@@ -1027,7 +1027,7 @@ def elastic_image_tensor(
     shape = image.shape

     if image.ndim > 4:
-        image = image.view((-1,) + shape[-3:])
+        image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
     else:
         needs_unsquash = False
@@ -1035,7 +1035,7 @@ def elastic_image_tensor(
     output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)

     if needs_unsquash:
-        output = output.view(shape)
+        output = output.reshape(shape)

     return output
@@ -1063,7 +1063,7 @@ def elastic_bounding_box(
     original_shape = bounding_box.shape

     bounding_box = convert_format_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
-    ).view(-1, 4)
+    ).reshape(-1, 4)

     # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
     # Or add spatial_size arg and check displacement shape
@@ -1075,21 +1075,21 @@ def elastic_bounding_box(
     inv_grid = id_grid - displacement

     # Get points from bboxes
-    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
+    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
     index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
     index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)

     # Transform points:
     t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
     transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5

-    transformed_points = transformed_points.view(-1, 4, 2)
+    transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, _ = torch.min(transformed_points, dim=1)
     out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

     return convert_format_bounding_box(
         out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-    ).view(original_shape)
+    ).reshape(original_shape)


 def elastic_mask(
@@ -65,7 +65,7 @@ def gaussian_blur_image_tensor(
     shape = image.shape

     if image.ndim > 4:
-        image = image.view((-1,) + shape[-3:])
+        image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
     else:
         needs_unsquash = False
@@ -73,7 +73,7 @@ def gaussian_blur_image_tensor(
     output = _FT.gaussian_blur(image, kernel_size, sigma)

     if needs_unsquash:
-        output = output.view(shape)
+        output = output.reshape(shape)

     return output