Unverified Commit c34a9145 authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Add API usage calls to utils (#5077)

* Add API usage calls to utils

* Update to the new api
parent 1b14829c
...@@ -42,6 +42,7 @@ def make_grid( ...@@ -42,6 +42,7 @@ def make_grid(
Returns: Returns:
grid (Tensor): the tensor containing grid of images. grid (Tensor): the tensor containing grid of images.
""" """
_log_api_usage_once("utils", "make_grid")
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
...@@ -130,6 +131,7 @@ def save_image( ...@@ -130,6 +131,7 @@ def save_image(
**kwargs: Other arguments are documented in ``make_grid``. **kwargs: Other arguments are documented in ``make_grid``.
""" """
_log_api_usage_once("utils", "save_image")
grid = make_grid(tensor, **kwargs) grid = make_grid(tensor, **kwargs)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
...@@ -174,6 +176,7 @@ def draw_bounding_boxes( ...@@ -174,6 +176,7 @@ def draw_bounding_boxes(
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
""" """
_log_api_usage_once("utils", "draw_bounding_boxes")
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}") raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8: elif image.dtype != torch.uint8:
...@@ -252,6 +255,7 @@ def draw_segmentation_masks( ...@@ -252,6 +255,7 @@ def draw_segmentation_masks(
img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
""" """
_log_api_usage_once("utils", "draw_segmentation_masks")
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}") raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8: elif image.dtype != torch.uint8:
...@@ -329,6 +333,7 @@ def draw_keypoints( ...@@ -329,6 +333,7 @@ def draw_keypoints(
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
""" """
_log_api_usage_once("utils", "draw_keypoints")
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}") raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8: elif image.dtype != torch.uint8:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment