Unverified Commit d97825ea authored by AdityaKhursale's avatar AdityaKhursale Committed by GitHub
Browse files

Fix: Improve the bounding boxes implementation (#3075)



* Fix: Improve the bounding boxes implementation

Use write_png instead of PIL in test_draw_boxes()
Initialize txt_font only once

* Remove channels permutation in test_draw_boxes
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: default avatarAditya Khursale <akhursale@nvidia.com>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 1b00af38
...@@ -6,7 +6,7 @@ import torchvision.utils as utils ...@@ -6,7 +6,7 @@ import torchvision.utils as utils
import unittest import unittest
from io import BytesIO from io import BytesIO
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from torchvision.io.image import read_image from torchvision.io.image import read_image, write_png
from PIL import Image from PIL import Image
...@@ -90,7 +90,7 @@ class Tester(unittest.TestCase): ...@@ -90,7 +90,7 @@ class Tester(unittest.TestCase):
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png") path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
if not os.path.exists(path): if not os.path.exists(path):
Image.fromarray(result.permute(1, 2, 0).numpy()).save(path) write_png(result, path)
expected = read_image(path) expected = read_image(path)
self.assertTrue(torch.equal(result, expected)) self.assertTrue(torch.equal(result, expected))
......
...@@ -177,13 +177,13 @@ def draw_bounding_boxes( ...@@ -177,13 +177,13 @@ def draw_bounding_boxes(
img_boxes = boxes.to(torch.int64).tolist() img_boxes = boxes.to(torch.int64).tolist()
draw = ImageDraw.Draw(img_to_draw) draw = ImageDraw.Draw(img_to_draw)
txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
for i, bbox in enumerate(img_boxes): for i, bbox in enumerate(img_boxes):
color = None if colors is None else colors[i] color = None if colors is None else colors[i]
draw.rectangle(bbox, width=width, outline=color) draw.rectangle(bbox, width=width, outline=color)
if labels is not None: if labels is not None:
txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment