Unverified commit 96640af0, authored by G and committed by GitHub
Browse files

add float support to `utils.draw_bounding_boxes()` (#8328)


Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 0367c219
...@@ -116,6 +116,23 @@ def test_draw_boxes(): ...@@ -116,6 +116,23 @@ def test_draw_boxes():
assert_equal(img, img_cp) assert_equal(img, img_cp)
@pytest.mark.parametrize("fill", [True, False])
def test_draw_boxes_dtypes(fill):
    """Check that uint8 and float images are both accepted and give matching drawings."""
    uint8_input = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    uint8_result = utils.draw_bounding_boxes(uint8_input, boxes, fill=fill)

    # The result must be a new tensor object with the input dtype preserved.
    assert uint8_result is not uint8_input
    assert uint8_result.dtype == torch.uint8

    float_input = to_dtype(uint8_input, torch.float, scale=True)
    float_result = utils.draw_bounding_boxes(float_input, boxes, fill=fill)

    assert float_result is not float_input
    assert float_result.is_floating_point()

    # Converting the float result back to uint8 may differ from the uint8
    # result by at most one intensity level.
    float_result_as_uint8 = to_dtype(float_result, torch.uint8, scale=True)
    torch.testing.assert_close(uint8_result, float_result_as_uint8, rtol=0, atol=1)
@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)]) @pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
def test_draw_boxes_colors(colors): def test_draw_boxes_colors(colors):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
...@@ -152,7 +169,6 @@ def test_draw_boxes_grayscale(): ...@@ -152,7 +169,6 @@ def test_draw_boxes_grayscale():
def test_draw_invalid_boxes(): def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3)) img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8) img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
...@@ -162,8 +178,6 @@ def test_draw_invalid_boxes(): ...@@ -162,8 +178,6 @@ def test_draw_invalid_boxes():
with pytest.raises(TypeError, match="Tensor expected"): with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes) utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"):
utils.draw_bounding_boxes(img_wrong1, boxes)
with pytest.raises(ValueError, match="Pass individual images, not batches"): with pytest.raises(ValueError, match="Pass individual images, not batches"):
utils.draw_bounding_boxes(img_wrong2, boxes) utils.draw_bounding_boxes(img_wrong2, boxes)
with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"): with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
......
...@@ -164,12 +164,12 @@ def draw_bounding_boxes( ...@@ -164,12 +164,12 @@ def draw_bounding_boxes(
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Draws bounding boxes on given image. Draws bounding boxes on given RGB image.
The values of the input image should be uint8 between 0 and 255. The image values should be uint8 in [0, 255] or float in [0, 1].
If fill is True, Resulting Tensor should be saved as PNG image. If fill is True, Resulting Tensor should be saved as PNG image.
Args: Args:
image (Tensor): Tensor of shape (C x H x W) and dtype uint8. image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float.
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`. `0 <= ymin < ymax < H`.
...@@ -188,13 +188,14 @@ def draw_bounding_boxes( ...@@ -188,13 +188,14 @@ def draw_bounding_boxes(
Returns: Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
""" """
import torchvision.transforms.v2.functional as F # noqa
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_bounding_boxes) _log_api_usage_once(draw_bounding_boxes)
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}") raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8: elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"Tensor uint8 expected, got {image.dtype}") raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3: elif image.dim() != 3:
raise ValueError("Pass individual images, not batches") raise ValueError("Pass individual images, not batches")
elif image.size(0) not in {1, 3}: elif image.size(0) not in {1, 3}:
...@@ -230,8 +231,11 @@ def draw_bounding_boxes( ...@@ -230,8 +231,11 @@ def draw_bounding_boxes(
if image.size(0) == 1: if image.size(0) == 1:
image = torch.tile(image, (3, 1, 1)) image = torch.tile(image, (3, 1, 1))
ndarr = image.permute(1, 2, 0).cpu().numpy() original_dtype = image.dtype
img_to_draw = Image.fromarray(ndarr) if original_dtype.is_floating_point:
image = F.to_dtype(image, dtype=torch.uint8, scale=True)
img_to_draw = F.to_pil_image(image)
img_boxes = boxes.to(torch.int64).tolist() img_boxes = boxes.to(torch.int64).tolist()
if fill: if fill:
...@@ -250,7 +254,10 @@ def draw_bounding_boxes( ...@@ -250,7 +254,10 @@ def draw_bounding_boxes(
margin = width + 1 margin = width + 1
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) out = F.pil_to_tensor(img_to_draw)
if original_dtype.is_floating_point:
out = F.to_dtype(out, dtype=original_dtype, scale=True)
return out
@torch.no_grad() @torch.no_grad()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment