Unverified Commit cdb6fba5 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Rewrite test and fix masks_to_boxes implementation (#4469)


Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 021df7a1
import os.path
import PIL.Image
import numpy
import torch
from torchvision.ops import masks_to_boxes
ASSETS_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
def test_masks_to_boxes():
with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)
for index in range(image.n_frames):
image.seek(index)
frame = numpy.array(image)
masks[index] = torch.tensor(frame)
expected = torch.tensor(
[[127, 2, 165, 40],
[2, 50, 44, 92],
[56, 63, 98, 100],
[139, 68, 175, 104],
[160, 112, 198, 145],
[49, 138, 99, 182],
[108, 148, 152, 213]],
dtype=torch.int32
)
torch.testing.assert_close(masks_to_boxes(masks), expected)
...@@ -4,7 +4,9 @@ from abc import ABC, abstractmethod ...@@ -4,7 +4,9 @@ from abc import ABC, abstractmethod
import pytest import pytest
import numpy as np import numpy as np
import os
from PIL import Image
import torch import torch
from functools import lru_cache from functools import lru_cache
from torch import Tensor from torch import Tensor
...@@ -1000,6 +1002,38 @@ class TestGenBoxIou: ...@@ -1000,6 +1002,38 @@ class TestGenBoxIou:
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3) gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)
class TestMasksToBoxes:
def test_masks_box(self):
def masks_box_check(masks, expected, tolerance=1e-4):
out = ops.masks_to_boxes(masks)
assert out.dtype == torch.float
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
# Check for int type boxes.
def _get_image():
assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
mask_path = os.path.join(assets_directory, "masks.tiff")
image = Image.open(mask_path)
return image
def _create_masks(image, masks):
for index in range(image.n_frames):
image.seek(index)
frame = np.array(image)
masks[index] = torch.tensor(frame)
return masks
expected = torch.tensor([[127, 2, 165, 40], [2, 50, 44, 92], [56, 63, 98, 100], [139, 68, 175, 104],
[160, 112, 198, 145], [49, 138, 99, 182], [108, 148, 152, 213]], dtype=torch.float)
image = _get_image()
for dtype in [torch.float16, torch.float32, torch.float64]:
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
masks = _create_masks(image, masks)
masks_box_check(masks, expected)
class TestStochasticDepth: class TestStochasticDepth:
@pytest.mark.parametrize('p', [0.2, 0.5, 0.8]) @pytest.mark.parametrize('p', [0.2, 0.5, 0.8])
@pytest.mark.parametrize('mode', ["batch", "row"]) @pytest.mark.parametrize('mode', ["batch", "row"])
......
...@@ -301,24 +301,24 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: ...@@ -301,24 +301,24 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
""" """
Compute the bounding boxes around the provided masks Compute the bounding boxes around the provided masks.
Returns a [N, 4] tensor. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``. ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
Args: Args:
masks (Tensor[N, H, W]): masks to transform where N is the number of masks (Tensor[N, H, W]): masks to transform where N is the number of masks
masks and (H, W) are the spatial dimensions. and (H, W) are the spatial dimensions.
Returns: Returns:
Tensor[N, 4]: bounding boxes Tensor[N, 4]: bounding boxes
""" """
if masks.numel() == 0: if masks.numel() == 0:
return torch.zeros((0, 4)) return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
n = masks.shape[0] n = masks.shape[0]
bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.int) bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
for index, mask in enumerate(masks): for index, mask in enumerate(masks):
y, x = torch.where(masks[index] != 0) y, x = torch.where(masks[index] != 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment