Unverified Commit 3161841d authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Better handling of grayscale images in draw bbox (#4049)

* Adding a test for grayscale images.

* Doing the ops in torch, adding more checks.
parent 26d6080f
......@@ -131,6 +131,13 @@ def test_draw_boxes_vanilla():
assert_equal(img, img_cp)
def test_draw_boxes_grayscale():
img = torch.full((1, 4, 4), fill_value=255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 3, 3]], dtype=torch.int64)
bboxed_img = utils.draw_bounding_boxes(image=img, boxes=boxes, colors=["#1BBC9B"])
assert bboxed_img.size(0) == 3
def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
......@@ -143,6 +150,8 @@ def test_draw_invalid_boxes():
utils.draw_bounding_boxes(img_wrong1, boxes)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
utils.draw_bounding_boxes(img_wrong2, boxes)
with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
utils.draw_bounding_boxes(img_wrong2[0][:2], boxes)
@pytest.mark.parametrize('colors', [
......
......@@ -178,12 +178,13 @@ def draw_bounding_boxes(
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size(0) not in {1, 3}:
raise ValueError("Only grayscale and RGB images are supported")
if image.size(0) == 1:
image = torch.tile(image, (3, 1, 1))
ndarr = image.permute(1, 2, 0).numpy()
# allow single-channel-images
# shape: (1, H, W) with C = 1
if ndarr.shape[-1] == 1:
ndarr = np.tile(ndarr, (1, 1, 3))
img_to_draw = Image.fromarray(ndarr)
img_boxes = boxes.to(torch.int64).tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment