Unverified commit 73206486, authored by vfdev, committed by GitHub

[proto] Small optims for elastic op on bboxes (#6897)

* [proto] Small optims for elastic op on bboxes

* More inplace ops according to the review

* Create grid on device directly. This should be faster (see the sketch below)

* PR Review update. Apply ceil on float input
parent 9b0da0c3
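To make the "Create grid on device directly" point concrete, here is a minimal standalone sketch. It is not part of the diff and the function names are illustrative: it contrasts building the identity grid on the CPU and then copying it to the target device with allocating it on that device from the start, which is what the new _create_identity_grid helper in the diff below does.

import torch

def grid_via_cpu_then_copy(size, device):
    # Old pattern: allocate and fill on CPU, then pay for a host-to-device copy.
    sy, sx = size
    base_grid = torch.empty(1, sy, sx, 2)
    base_grid[..., 0].copy_(torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx))
    base_grid[..., 1].copy_(torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy).unsqueeze_(-1))
    return base_grid.to(device)

def grid_on_device(size, device):
    # New pattern: allocate and fill directly on the device that will consume the grid.
    sy, sx = size
    base_grid = torch.empty(1, sy, sx, 2, device=device)
    base_grid[..., 0].copy_(torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device))
    base_grid[..., 1].copy_(torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1))
    return base_grid

device = "cuda" if torch.cuda.is_available() else "cpu"
assert torch.allclose(grid_via_cpu_then_copy((32, 32), device), grid_on_device((32, 32), device))

Both produce the same grid; the second simply skips the intermediate CPU tensor and the extra copy, which is where the expected speed-up comes from.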
@@ -1108,6 +1108,18 @@ def elastic_image_pil(
     return to_pil_image(output, mode=image.mode)
 
 
+def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch.Tensor:
+    sy, sx = size
+    base_grid = torch.empty(1, sy, sx, 2, device=device)
+    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device)
+    base_grid[..., 0].copy_(x_grid)
+    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1)
+    base_grid[..., 1].copy_(y_grid)
+    return base_grid
+
+
 def elastic_bounding_box(
     bounding_box: torch.Tensor,
     format: features.BoundingBoxFormat,
@@ -1125,22 +1137,24 @@ def elastic_bounding_box(
     # Or add spatial_size arg and check displacement shape
     spatial_size = displacement.shape[-3], displacement.shape[-2]
 
-    id_grid = _FT._create_identity_grid(list(spatial_size)).to(bounding_box.device)
+    id_grid = _create_identity_grid(spatial_size, bounding_box.device)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
-    inv_grid = id_grid - displacement
+    inv_grid = id_grid.sub_(displacement)
 
     # Get points from bboxes
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
-    index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
-    index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)
+    if points.is_floating_point():
+        points = points.ceil_()
+    index_xy = points.to(dtype=torch.long)
+    index_x, index_y = index_xy[:, 0], index_xy[:, 1]
 
     # Transform points:
     t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
-    transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5
+    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
     transformed_points = transformed_points.reshape(-1, 4, 2)
 
-    out_bbox_mins, _ = torch.min(transformed_points, dim=1)
-    out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
+    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
 
     return convert_format_bounding_box(
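The remaining changes in the second hunk are micro-optimizations: torch.aminmax fuses the two separate min/max reductions into one pass, and the chained in-place ops (add_/mul_/sub_) reuse one buffer instead of allocating a new tensor per arithmetic step. A minimal self-contained sketch (the example tensors are made up, not taken from the repo) showing that both rewrites are equivalent to the replaced expressions:

import torch

corners = torch.tensor([[[0.0, 1.0], [4.0, 1.0], [4.0, 3.0], [0.0, 3.0]]])  # (N, 4, 2) box corners

# Before: two separate reductions, each also returning indices we discard.
mins_old, _ = torch.min(corners, dim=1)
maxs_old, _ = torch.max(corners, dim=1)

# After: a single fused reduction returning both extremes at once.
mins_new, maxs_new = torch.aminmax(corners, dim=1)
assert torch.equal(mins_old, mins_new) and torch.equal(maxs_old, maxs_new)

# In-place chaining: advanced indexing (inv_grid[0, index_y, index_x, :]) already yields
# a fresh tensor, so mutating it in place is safe and skips per-op allocations.
t_size = torch.tensor([10.0, 8.0])
grid_slice = torch.rand(4, 2)  # stands in for inv_grid[0, index_y, index_x, :]
out_of_place = (grid_slice + 1) * 0.5 * t_size - 0.5
in_place = grid_slice.clone().add_(1).mul_(0.5 * t_size).sub_(0.5)
assert torch.allclose(out_of_place, in_place)

The indexing change requested in review is separate from these: floating-point boxes are now rounded up with ceil_ before the cast to long instead of the previous round-half-up (floor(x + 0.5)), and integer boxes skip the rounding entirely.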