Unverified Commit eac3dc7b authored by Kai Zhang's avatar Kai Zhang Committed by GitHub
Browse files

Simplified usage log API (#5095)



* log API v3

* make torchscript happy

* make torchscript happy

* add missing logs to constructor

* log ops C++ API as well

* fix type hint

* check function with isinstance
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 0b02d420
import math import math
import pathlib import pathlib
import warnings import warnings
from typing import Union, Optional, List, Tuple, BinaryIO from types import FunctionType
from typing import Any, Union, Optional, List, Tuple, BinaryIO
import numpy as np import numpy as np
import torch import torch
...@@ -42,7 +43,8 @@ def make_grid( ...@@ -42,7 +43,8 @@ def make_grid(
Returns: Returns:
grid (Tensor): the tensor containing grid of images. grid (Tensor): the tensor containing grid of images.
""" """
_log_api_usage_once("utils", "make_grid") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(make_grid)
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
...@@ -88,6 +90,7 @@ def make_grid( ...@@ -88,6 +90,7 @@ def make_grid(
else: else:
norm_range(tensor, value_range) norm_range(tensor, value_range)
assert isinstance(tensor, torch.Tensor)
if tensor.size(0) == 1: if tensor.size(0) == 1:
return tensor.squeeze(0) return tensor.squeeze(0)
...@@ -131,7 +134,8 @@ def save_image( ...@@ -131,7 +134,8 @@ def save_image(
**kwargs: Other arguments are documented in ``make_grid``. **kwargs: Other arguments are documented in ``make_grid``.
""" """
_log_api_usage_once("utils", "save_image") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(save_image)
grid = make_grid(tensor, **kwargs) grid = make_grid(tensor, **kwargs)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
...@@ -176,7 +180,8 @@ def draw_bounding_boxes( ...@@ -176,7 +180,8 @@ def draw_bounding_boxes(
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
""" """
_log_api_usage_once("utils", "draw_bounding_boxes") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_bounding_boxes)
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}") raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8: elif image.dtype != torch.uint8:
...@@ -255,7 +260,8 @@ def draw_segmentation_masks( ...@@ -255,7 +260,8 @@ def draw_segmentation_masks(
img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
""" """
_log_api_usage_once("utils", "draw_segmentation_masks") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_segmentation_masks)
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}") raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8: elif image.dtype != torch.uint8:
...@@ -333,7 +339,8 @@ def draw_keypoints( ...@@ -333,7 +339,8 @@ def draw_keypoints(
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
""" """
_log_api_usage_once("utils", "draw_keypoints") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_keypoints)
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}") raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8: elif image.dtype != torch.uint8:
...@@ -380,7 +387,10 @@ def _generate_color_palette(num_masks: int): ...@@ -380,7 +387,10 @@ def _generate_color_palette(num_masks: int):
return [tuple((i * palette) % 255) for i in range(num_masks)] return [tuple((i * palette) % 255) for i in range(num_masks)]
def _log_api_usage_once(module: str, name: str) -> None: def _log_api_usage_once(obj: Any) -> None:
if torch.jit.is_scripting() or torch.jit.is_tracing(): if not obj.__module__.startswith("torchvision"):
return return
torch._C._log_api_usage_once(f"torchvision.{module}.{name}") name = obj.__class__.__name__
if isinstance(obj, FunctionType):
name = obj.__name__
torch._C._log_api_usage_once(f"{obj.__module__}.{name}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment