Unverified Commit 240210c9 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Add utility to draw bounding boxes (#2785)



* initital prototype

* flake

* Adds documentation

* minimal working bboxes

* Adds label display

* adds colors :-)

* adds suggestions and fixes CI

* handles image of dim 4

* fixes image handling

* removes dev file

* adds suggested changes

* Updating the API.

* Update test.

* Implementing code review improvements.

* Further refactoring and adding test.

* Replace random to white to reduce size and change font on tests.
Co-authored-by: default avatarVasilis Vryniotis <vvryniotis@fb.com>
parent b3adace6
......@@ -7,3 +7,4 @@ torchvision.utils
.. autofunction:: save_image
.. autofunction:: draw_bounding_boxes
\ No newline at end of file
......@@ -9,6 +9,7 @@ import io
import torch
import warnings
import __main__
import random
from numbers import Number
from torch._six import string_classes
......@@ -30,6 +31,12 @@ def get_tmp_dir(src=None, **kwargs):
shutil.rmtree(tmp_dir)
def set_rng_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
ACCEPT = os.getenv('EXPECTTEST_ACCEPT')
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
# TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job
......
from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state
from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed
from collections import OrderedDict
from itertools import product
import functools
import operator
import torch
import torch.nn as nn
import numpy as np
from torchvision import models
import unittest
import random
import warnings
def set_rng_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def get_available_classification_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
......
......@@ -6,6 +6,7 @@ import torchvision.utils as utils
import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from torchvision.io.image import read_image
from PIL import Image
......@@ -79,6 +80,21 @@ class Tester(unittest.TestCase):
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
'Pixel Image not stored in file object')
def test_draw_boxes(self):
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
labels = ["a", "b", "c", "d"]
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
if not os.path.exists(path):
Image.fromarray(result.permute(1, 2, 0).numpy()).save(path)
expected = read_image(path)
self.assertTrue(torch.equal(result, expected))
if __name__ == '__main__':
unittest.main()
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import io
import pathlib
import torch
import math
import numpy as np
from PIL import Image, ImageDraw
from PIL import ImageFont
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]
irange = range
......@@ -121,10 +126,64 @@ def save_image(
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``.
"""
from PIL import Image
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(fp, format=format)
@torch.no_grad()
def draw_bounding_boxes(
image: torch.Tensor,
boxes: torch.Tensor,
labels: Optional[List[str]] = None,
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
width: int = 1,
font: Optional[str] = None,
font_size: int = 10
) -> torch.Tensor:
"""
Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255.
Args:
image (Tensor): Tensor of shape (C x H x W)
bboxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`.
labels (List[str]): List containing the labels of bounding boxes.
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes. The colors can
be represented as `str` or `Tuple[int, int, int]`.
width (int): Width of bounding box.
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
font_size (int): The requested font size in points.
"""
if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
ndarr = image.permute(1, 2, 0).numpy()
img_to_draw = Image.fromarray(ndarr)
img_boxes = boxes.to(torch.int64).tolist()
draw = ImageDraw.Draw(img_to_draw)
for i, bbox in enumerate(img_boxes):
color = None if colors is None else colors[i]
draw.rectangle(bbox, width=width, outline=color)
if labels is not None:
txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment