Unverified Commit 240210c9 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Add utility to draw bounding boxes (#2785)



* initial prototype

* flake

* Adds documentation

* minimal working bboxes

* Adds label display

* adds colors :-)

* adds suggestions and fixes CI

* handles image of dim 4

* fixes image handling

* removes dev file

* adds suggested changes

* Updating the API.

* Update test.

* Implementing code review improvements.

* Further refactoring and adding test.

* Replace random to white to reduce size and change font on tests.
Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
parent b3adace6
...@@ -7,3 +7,4 @@ torchvision.utils ...@@ -7,3 +7,4 @@ torchvision.utils
.. autofunction:: save_image .. autofunction:: save_image
.. autofunction:: draw_bounding_boxes
\ No newline at end of file
...@@ -9,6 +9,7 @@ import io ...@@ -9,6 +9,7 @@ import io
import torch import torch
import warnings import warnings
import __main__ import __main__
import random
from numbers import Number from numbers import Number
from torch._six import string_classes from torch._six import string_classes
...@@ -30,6 +31,12 @@ def get_tmp_dir(src=None, **kwargs): ...@@ -30,6 +31,12 @@ def get_tmp_dir(src=None, **kwargs):
shutil.rmtree(tmp_dir) shutil.rmtree(tmp_dir)
def set_rng_seed(seed):
    """Seed the torch, ``random`` and NumPy generators with the same value.

    Args:
        seed: Value handed unchanged to each library's seeding function.
    """
    # Same three calls, in the same order: torch first, then the stdlib
    # generator, then NumPy's global generator.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
ACCEPT = os.getenv('EXPECTTEST_ACCEPT') ACCEPT = os.getenv('EXPECTTEST_ACCEPT')
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
# TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job # TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job
......
from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed
from collections import OrderedDict from collections import OrderedDict
from itertools import product from itertools import product
import functools import functools
import operator import operator
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
from torchvision import models from torchvision import models
import unittest import unittest
import random
import warnings import warnings
def set_rng_seed(seed):
    """Seed the torch, ``random`` and NumPy generators with the same value.

    Args:
        seed: Value handed unchanged to each library's seeding function.
    """
    # Same three calls, in the same order: torch first, then the stdlib
    # generator, then NumPy's global generator.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
def get_available_classification_models():
    """Return the names of the classification model builders in ``models``.

    A name is kept when the attribute it refers to is callable, its first
    character is unchanged by ``str.lower`` (so it is not an upper-case
    letter), and it does not start with an underscore.

    Returns:
        list: Matching attribute names, in ``models.__dict__`` iteration order.
    """
    # TODO add a registration mechanism to torchvision.models
    return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
......
...@@ -6,6 +6,7 @@ import torchvision.utils as utils ...@@ -6,6 +6,7 @@ import torchvision.utils as utils
import unittest import unittest
from io import BytesIO from io import BytesIO
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from torchvision.io.image import read_image
from PIL import Image from PIL import Image
...@@ -79,6 +80,21 @@ class Tester(unittest.TestCase): ...@@ -79,6 +80,21 @@ class Tester(unittest.TestCase):
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
'Pixel Image not stored in file object') 'Pixel Image not stored in file object')
def test_draw_boxes(self):
    """Check ``draw_bounding_boxes`` output against the stored reference PNG."""
    white_canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    box_coords = torch.tensor(
        [[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]],
        dtype=torch.float,
    )
    drawn = utils.draw_bounding_boxes(
        white_canvas,
        box_coords,
        labels=["a", "b", "c", "d"],
        colors=["green", "#FF00FF", (0, 255, 0), "red"],
    )

    test_dir = os.path.dirname(os.path.abspath(__file__))
    reference_png = os.path.join(test_dir, "assets", "fakedata", "draw_boxes_util.png")
    # When the reference image is missing it is generated from the current
    # output, so the comparison below then checks the PNG round trip only.
    if not os.path.exists(reference_png):
        as_hwc = drawn.permute(1, 2, 0).numpy()
        Image.fromarray(as_hwc).save(reference_png)

    reference = read_image(reference_png)
    self.assertTrue(torch.equal(drawn, reference))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
from typing import Union, Optional, List, Tuple, Text, BinaryIO from typing import Union, Optional, List, Tuple, Text, BinaryIO
import io
import pathlib import pathlib
import torch import torch
import math import math
import numpy as np
from PIL import Image, ImageDraw
from PIL import ImageFont
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]
irange = range irange = range
...@@ -121,10 +126,64 @@ def save_image( ...@@ -121,10 +126,64 @@ def save_image(
If a file object was used instead of a filename, this parameter should always be used. If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``. **kwargs: Other arguments are documented in ``make_grid``.
""" """
from PIL import Image
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each) normalize=normalize, range=range, scale_each=scale_each)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr) im = Image.fromarray(ndarr)
im.save(fp, format=format) im.save(fp, format=format)
@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10
) -> torch.Tensor:
    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (C x H x W)
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`. Coordinates are converted to int64 before drawing.
        labels (List[str]): List containing the labels of bounding boxes. ``labels[i]`` is drawn at the top-left
            corner of box ``i``.
        colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes. The colors can
            be represented as `str` or `Tuple[int, int, int]`. ``colors[i]`` is used for both the outline and the
            label of box ``i``. When omitted, no explicit color is passed to PIL.
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. When omitted, PIL's default font is used.
        font_size (int): The requested font size in points. Only used when ``font`` is given.

    Returns:
        Tensor: A new uint8 CPU tensor of shape (C x H x W) containing the image with the boxes drawn on it.

    Raises:
        TypeError: If ``image`` is not a tensor.
        ValueError: If ``image`` is not of dtype uint8 or is not 3-dimensional.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")

    # PIL works on host memory in H x W x C layout; ``.cpu()`` is a no-op for
    # tensors that already live on the CPU.
    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)

    img_boxes = boxes.to(torch.int64).tolist()

    # The font does not depend on the box, so load it once instead of once per
    # box. It is only loaded when at least one label will actually be drawn.
    txt_font = None
    if labels is not None and img_boxes:
        txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)

    draw = ImageDraw.Draw(img_to_draw)
    for i, bbox in enumerate(img_boxes):
        color = None if colors is None else colors[i]
        draw.rectangle(bbox, width=width, outline=color)

        if labels is not None:
            draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)

    # Back from H x W x C to the C x H x W layout of the input.
    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment