Unverified Commit cdb6fba5 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Rewrite test and fix masks_to_boxes implementation (#4469)


Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 021df7a1
import os.path
import PIL.Image
import numpy
import torch
from torchvision.ops import masks_to_boxes
ASSETS_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
def test_masks_to_boxes():
with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)
for index in range(image.n_frames):
image.seek(index)
frame = numpy.array(image)
masks[index] = torch.tensor(frame)
expected = torch.tensor(
[[127, 2, 165, 40],
[2, 50, 44, 92],
[56, 63, 98, 100],
[139, 68, 175, 104],
[160, 112, 198, 145],
[49, 138, 99, 182],
[108, 148, 152, 213]],
dtype=torch.int32
)
torch.testing.assert_close(masks_to_boxes(masks), expected)
......@@ -4,7 +4,9 @@ from abc import ABC, abstractmethod
import pytest
import numpy as np
import os
from PIL import Image
import torch
from functools import lru_cache
from torch import Tensor
......@@ -1000,6 +1002,38 @@ class TestGenBoxIou:
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)
class TestMasksToBoxes:
def test_masks_box(self):
def masks_box_check(masks, expected, tolerance=1e-4):
out = ops.masks_to_boxes(masks)
assert out.dtype == torch.float
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
# Check for int type boxes.
def _get_image():
assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
mask_path = os.path.join(assets_directory, "masks.tiff")
image = Image.open(mask_path)
return image
def _create_masks(image, masks):
for index in range(image.n_frames):
image.seek(index)
frame = np.array(image)
masks[index] = torch.tensor(frame)
return masks
expected = torch.tensor([[127, 2, 165, 40], [2, 50, 44, 92], [56, 63, 98, 100], [139, 68, 175, 104],
[160, 112, 198, 145], [49, 138, 99, 182], [108, 148, 152, 213]], dtype=torch.float)
image = _get_image()
for dtype in [torch.float16, torch.float32, torch.float64]:
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
masks = _create_masks(image, masks)
masks_box_check(masks, expected)
class TestStochasticDepth:
@pytest.mark.parametrize('p', [0.2, 0.5, 0.8])
@pytest.mark.parametrize('mode', ["batch", "row"])
......
......@@ -301,24 +301,24 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
"""
Compute the bounding boxes around the provided masks
Compute the bounding boxes around the provided masks.
Returns a [N, 4] tensor. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
Args:
masks (Tensor[N, H, W]): masks to transform where N is the number of
masks and (H, W) are the spatial dimensions.
masks (Tensor[N, H, W]): masks to transform where N is the number of masks
and (H, W) are the spatial dimensions.
Returns:
Tensor[N, 4]: bounding boxes
"""
if masks.numel() == 0:
return torch.zeros((0, 4))
return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
n = masks.shape[0]
bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.int)
bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
for index, mask in enumerate(masks):
y, x = torch.where(masks[index] != 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment