Unverified Commit 66d777e7 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Improved utilites, adds examples, tests (#3594)



* start adding tests

* add return type and doc

* adds tests

* add no fill tests

* add rgb test

* check inplace

* bug fix

* bug fix

* rewrite make grid

* add plotting demos

* rename file

* remove

* updt

* Add viz

* updt

* update readme, add links

* complte bounding boxes

* Complete the examples!

* link fix

* link fixed
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 20a771e5
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
[Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb) [Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/video_api.ipynb) - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/video_api.ipynb)
[Example of VideoAPI](https://github.com/pytorch/vision/blob/master/examples/python/video_api.ipynb) [Example of VideoAPI](https://github.com/pytorch/vision/blob/master/examples/python/video_api.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb)
[Example of Visualization Utils](https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb)
Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric and presented multiple limitations due to Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric and presented multiple limitations due to
...@@ -16,3 +18,5 @@ features: ...@@ -16,3 +18,5 @@ features:
- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats) - read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)
Furthermore, previously we used to provide a very high-level API for video decoding which left little control to the user. We're now expanding that API (and replacing it in the future) with a lower-level API that allows the user a frame-based access to a video. Furthermore, previously we used to provide a very high-level API for video decoding which left little control to the user. We're now expanding that API (and replacing it in the future) with a lower-level API that allows the user a frame-based access to a video.
Torchvision also provides utilities to visualize results. You can make grid of images, plot bounding boxes as well as segmentation masks. Thse utilities work standalone as well as with torchvision models for detection and segmentation.
This diff is collapsed.
...@@ -9,6 +9,9 @@ from io import BytesIO ...@@ -9,6 +9,9 @@ from io import BytesIO
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from PIL import Image from PIL import Image
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
masks = torch.tensor([ masks = torch.tensor([
[ [
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
...@@ -106,8 +109,8 @@ class Tester(unittest.TestCase): ...@@ -106,8 +109,8 @@ class Tester(unittest.TestCase):
def test_draw_boxes(self): def test_draw_boxes(self):
img = torch.full((3, 100, 100), 255, dtype=torch.uint8) img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], img_cp = img.clone()
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) boxes_cp = boxes.clone()
labels = ["a", "b", "c", "d"] labels = ["a", "b", "c", "d"]
colors = ["green", "#FF00FF", (0, 255, 0), "red"] colors = ["green", "#FF00FF", (0, 255, 0), "red"]
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True) result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
...@@ -119,9 +122,41 @@ class Tester(unittest.TestCase): ...@@ -119,9 +122,41 @@ class Tester(unittest.TestCase):
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected)) self.assertTrue(torch.equal(result, expected))
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
def test_draw_boxes_vanilla(self):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
def test_draw_invalid_boxes(self):
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
self.assertRaises(TypeError, utils.draw_bounding_boxes, img_tp, boxes)
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)
def test_draw_segmentation_masks_colors(self): def test_draw_segmentation_masks_colors(self):
img = torch.full((3, 5, 5), 255, dtype=torch.uint8) img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
img_cp = img.clone()
masks_cp = masks.clone()
colors = ["#FF00FF", (0, 255, 0), "red"] colors = ["#FF00FF", (0, 255, 0), "red"]
result = utils.draw_segmentation_masks(img, masks, colors=colors) result = utils.draw_segmentation_masks(img, masks, colors=colors)
...@@ -134,9 +169,14 @@ class Tester(unittest.TestCase): ...@@ -134,9 +169,14 @@ class Tester(unittest.TestCase):
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected)) self.assertTrue(torch.equal(result, expected))
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())
def test_draw_segmentation_masks_no_colors(self): def test_draw_segmentation_masks_no_colors(self):
img = torch.full((3, 20, 20), 255, dtype=torch.uint8) img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
img_cp = img.clone()
masks_cp = masks.clone()
result = utils.draw_segmentation_masks(img, masks, colors=None) result = utils.draw_segmentation_masks(img, masks, colors=None)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
...@@ -148,6 +188,20 @@ class Tester(unittest.TestCase): ...@@ -148,6 +188,20 @@ class Tester(unittest.TestCase):
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected)) self.assertTrue(torch.equal(result, expected))
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())
def test_draw_invalid_masks(self):
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8)
self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks)
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks)
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks)
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -20,7 +20,8 @@ def make_grid( ...@@ -20,7 +20,8 @@ def make_grid(
pad_value: int = 0, pad_value: int = 0,
**kwargs **kwargs
) -> torch.Tensor: ) -> torch.Tensor:
"""Make a grid of images. """
Make a grid of images.
Args: Args:
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
...@@ -37,9 +38,12 @@ def make_grid( ...@@ -37,9 +38,12 @@ def make_grid(
images separately rather than the (min, max) over all images. Default: ``False``. images separately rather than the (min, max) over all images. Default: ``False``.
pad_value (float, optional): Value for the padded pixels. Default: ``0``. pad_value (float, optional): Value for the padded pixels. Default: ``0``.
Example: Returns:
See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_ grid (Tensor): the tensor containing grid of images.
Example:
See this notebook
`here <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
""" """
if not (torch.is_tensor(tensor) or if not (torch.is_tensor(tensor) or
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
...@@ -117,7 +121,8 @@ def save_image( ...@@ -117,7 +121,8 @@ def save_image(
format: Optional[str] = None, format: Optional[str] = None,
**kwargs **kwargs
) -> None: ) -> None:
"""Save a given Tensor into an image file. """
Save a given Tensor into an image file.
Args: Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
...@@ -150,7 +155,7 @@ def draw_bounding_boxes( ...@@ -150,7 +155,7 @@ def draw_bounding_boxes(
""" """
Draws bounding boxes on given image. Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255. The values of the input image should be uint8 between 0 and 255.
If filled, Resulting Tensor should be saved as PNG image. If fill is True, Resulting Tensor should be saved as PNG image.
Args: Args:
image (Tensor): Tensor of shape (C x H x W) and dtype uint8. image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
...@@ -166,6 +171,13 @@ def draw_bounding_boxes( ...@@ -166,6 +171,13 @@ def draw_bounding_boxes(
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
font_size (int): The requested font size in points. font_size (int): The requested font size in points.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
Example:
See this notebook
`linked <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
""" """
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
...@@ -209,7 +221,7 @@ def draw_bounding_boxes( ...@@ -209,7 +221,7 @@ def draw_bounding_boxes(
if labels is not None: if labels is not None:
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
@torch.no_grad() @torch.no_grad()
...@@ -230,6 +242,13 @@ def draw_segmentation_masks( ...@@ -230,6 +242,13 @@ def draw_segmentation_masks(
alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks. alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks.
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
be represented as `str` or `Tuple[int, int, int]`. be represented as `str` or `Tuple[int, int, int]`.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted.
Example:
See this notebook
`attached <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
""" """
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment