You need to sign in or sign up before continuing.
Unverified Commit 738fa133 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Throw ValueError in draw bounding boxes for invalid boxes (#6123)

* Fix the issue :)

* Intellij vs ufmt battle

* remove .item()
parent c890a7e7
...@@ -152,6 +152,7 @@ def test_draw_invalid_boxes(): ...@@ -152,6 +152,7 @@ def test_draw_invalid_boxes():
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8) img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
boxes_wrong = torch.tensor([[10, 10, 4, 5], [30, 20, 10, 5]], dtype=torch.float)
labels_wrong = ["one", "two"] labels_wrong = ["one", "two"]
colors_wrong = ["pink", "blue"] colors_wrong = ["pink", "blue"]
...@@ -167,6 +168,8 @@ def test_draw_invalid_boxes(): ...@@ -167,6 +168,8 @@ def test_draw_invalid_boxes():
utils.draw_bounding_boxes(img_correct, boxes, labels_wrong) utils.draw_bounding_boxes(img_correct, boxes, labels_wrong)
with pytest.raises(ValueError, match="Number of colors"): with pytest.raises(ValueError, match="Number of colors"):
utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong) utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong)
with pytest.raises(ValueError, match="Boxes need to be in"):
utils.draw_bounding_boxes(img_correct, boxes_wrong)
def test_draw_boxes_warning(): def test_draw_boxes_warning():
......
...@@ -208,6 +208,10 @@ def draw_bounding_boxes( ...@@ -208,6 +208,10 @@ def draw_bounding_boxes(
raise ValueError("Pass individual images, not batches") raise ValueError("Pass individual images, not batches")
elif image.size(0) not in {1, 3}: elif image.size(0) not in {1, 3}:
raise ValueError("Only grayscale and RGB images are supported") raise ValueError("Only grayscale and RGB images are supported")
elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any():
raise ValueError(
"Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
)
num_boxes = boxes.shape[0] num_boxes = boxes.shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment