Unverified Commit ed081770, authored by wushufan and committed by GitHub

[Doc]: Update docstrings and typehint for modules in models (#2146)

* Add typehint for models/utils and others

* fix line too long

* fix format error by pre-commit

* fix format error by pre-commit

* update with review suggestions
parent 32ab994d
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch

from mmdet3d.structures import bbox3d2result, bbox3d_mapping_back, xywhr2xyxyr
from mmdet3d.utils import ConfigType
from ..layers import nms_bev, nms_normal_bev


def merge_aug_bboxes_3d(aug_results: List[dict],
                        aug_batch_input_metas: List[dict],
                        test_cfg: ConfigType) -> dict:
    """Merge augmented detection 3D bboxes and scores.

    Args:
        aug_results (List[dict]): The dict of detection results.
            The dict contains the following keys:

            - bbox_3d (:obj:`BaseInstance3DBoxes`): Detection bbox.
            - scores_3d (Tensor): Detection scores.
            - labels_3d (Tensor): Predicted box labels.
        aug_batch_input_metas (List[dict]): Meta information of each sample.
        test_cfg (dict or :obj:`ConfigDict`): Test config.

    Returns:
        dict: Bounding boxes results in cpu mode, containing merged results.
    ...
    """
    assert len(aug_results) == len(aug_batch_input_metas), \
        '"aug_results" should have the same length as ' \
        f'"aug_batch_input_metas", got len(aug_results)={len(aug_results)} ' \
        f'and len(aug_batch_input_metas)={len(aug_batch_input_metas)}'
    recovered_bboxes = []
    recovered_scores = []
    ...
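For orientation, a hedged sketch of driving the merged interface. The box class, the meta keys (`pcd_scale_factor`, `pcd_horizontal_flip`, `pcd_vertical_flip`) and the test_cfg fields are assumptions based on typical mmdet3d test-time augmentation, not taken from this diff, and the BEV NMS kernels may require a CUDA build:

# Hedged usage sketch (not part of this commit).
from mmengine.config import ConfigDict
from mmdet3d.structures import DepthInstance3DBoxes  # assumed box class

aug_results = [
    dict(bbox_3d=DepthInstance3DBoxes(torch.rand(4, 7)),
         scores_3d=torch.rand(4),
         labels_3d=torch.zeros(4, dtype=torch.long)),
    dict(bbox_3d=DepthInstance3DBoxes(torch.rand(2, 7)),
         scores_3d=torch.rand(2),
         labels_3d=torch.zeros(2, dtype=torch.long)),
]
# Assumed meta keys consumed by bbox3d_mapping_back when undoing the
# point-cloud augmentation of each view.
aug_batch_input_metas = [
    dict(pcd_scale_factor=1.0, pcd_horizontal_flip=False,
         pcd_vertical_flip=False),
    dict(pcd_scale_factor=1.0, pcd_horizontal_flip=True,
         pcd_vertical_flip=False),
]
# Assumed config fields; thresholds are illustrative.
test_cfg = ConfigDict(use_rotate_nms=True, nms_thr=0.05, score_thr=0.01,
                      max_num=50)
merged = merge_aug_bboxes_3d(aug_results, aug_batch_input_metas, test_cfg)
# `merged` is the result dict produced by bbox3d2result, in cpu mode.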
# Copyright (c) OpenMMLab. All rights reserved.
def add_prefix(inputs: dict, prefix: str) -> dict:
    """Add prefix for dict.

    Args:
    ...
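The docstring is truncated here; a short sketch of the expected behavior follows (the dot separator is an assumption based on the common OpenMMLab helper of the same name):

losses = dict(loss_cls=0.5, loss_bbox=1.2)
aux_losses = add_prefix(losses, 'aux')
# assumed result: {'aux.loss_cls': 0.5, 'aux.loss_bbox': 1.2}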
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor


def clip_sigmoid(x: Tensor, eps: float = 1e-4) -> Tensor:
    """Sigmoid function for input feature.

    Args:
        x (Tensor): Input feature map with the shape of [B, N, H, W].
        eps (float): Lower bound of the range to be clamped to.
            Defaults to 1e-4.

    Returns:
        Tensor: Feature map after sigmoid.
    """
    y = torch.clamp(x.sigmoid_(), min=eps, max=1 - eps)
    return y
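Since the body is shown in full, a small usage sketch illustrating why the clamp matters; the values follow from the defaults above:

x = torch.tensor([-100.0, 0.0, 100.0]).view(1, 1, 1, 3)
# Note that sigmoid_() is in-place, so pass a clone if the raw logits
# must be preserved.
y = clip_sigmoid(x.clone())
# y is approximately [1e-4, 0.5, 1 - 1e-4]: saturated logits are pulled
# away from exact 0/1 so that a downstream log() stays finite.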
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import numpy as np
import torch
from torch import Tensor


def get_edge_indices(img_metas: List[dict],
                     downsample_ratio: int,
                     step: int = 1,
                     pad_mode: str = 'default',
                     dtype: type = np.float32,
                     device: str = 'cpu') -> List[Tensor]:
    """Function to generate edge indices along the image border.

    The edge_indices are generated using numpy on cpu rather
    than on CUDA due to the latency issue. When batch size = 8,
    ...
    with CUDA tensor (0.09s and 0.72s in 100 runs).

    Args:
        img_metas (List[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        downsample_ratio (int): Downsample ratio of output feature.
        step (int): Step size used for generating edge indices.
            Defaults to 1.
        pad_mode (str): Padding mode during data pipeline.
            Defaults to 'default'.
        dtype (type): Dtype of edge indices tensor.
            Defaults to np.float32.
        device (str): Device of edge indices tensor.
            Defaults to 'cpu'.

    Returns:
        List[Tensor]: Edge indices for each image in batch data.
    """
    edge_indices_list = []
    for i in range(len(img_metas)):
    ...
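The elided body walks the border of the padded image at the output stride; a self-contained illustration of that idea follows (this is a sketch, not the library code, and the `img_shape` argument shape is an assumption):

def edge_indices_sketch(img_shape, downsample_ratio, step=1):
    # Illustrative only: (x, y) indices along the border of the
    # downsampled image, traversed edge by edge.
    h = img_shape[0] // downsample_ratio
    w = img_shape[1] // downsample_ratio
    edges = []
    edges += [(x, h - 1) for x in range(0, w, step)]        # bottom
    edges += [(w - 1, y) for y in range(h - 1, -1, -step)]  # right
    edges += [(x, 0) for x in range(w - 1, -1, -step)]      # top
    edges += [(0, y) for y in range(0, h, step)]            # left
    return torch.as_tensor(np.array(edges, dtype=np.float32))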
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import numpy as np
import torch
from torch import Tensor


def gaussian_2d(shape: Tuple[int, int], sigma: float = 1) -> np.ndarray:
    """Generate gaussian map.

    Args:
        shape (Tuple[int]): Shape of the map.
        sigma (float): Sigma to generate gaussian map.
            Defaults to 1.

    Returns:
    ...
    return h


def draw_heatmap_gaussian(heatmap: Tensor,
                          center: Tensor,
                          radius: int,
                          k: int = 1) -> Tensor:
    """Get gaussian masked heatmap.

    Args:
        heatmap (Tensor): Heatmap to be masked.
        center (Tensor): Center coord of the heatmap.
        radius (int): Radius of gaussian.
        k (int): Multiple of masked_gaussian. Defaults to 1.

    Returns:
        Tensor: Masked heatmap.
    """
    diameter = 2 * radius + 1
    gaussian = gaussian_2d((diameter, diameter), sigma=diameter / 6)
    ...
    return heatmap


def gaussian_radius(det_size: Tuple[Tensor, Tensor],
                    min_overlap: float = 0.5) -> Tensor:
    """Get radius of gaussian.

    Args:
        det_size (Tuple[Tensor]): Size of the detection result.
        min_overlap (float): Gaussian overlap. Defaults to 0.5.

    Returns:
        Tensor: Computed radius.
    """
    height, width = det_size
    ...
    return min(r1, r2, r3)


def get_ellip_gaussian_2D(heatmap: Tensor,
                          center: List[int],
                          radius_x: int,
                          radius_y: int,
                          k: int = 1) -> Tensor:
    """Generate 2D ellipse gaussian heatmap.

    Args:
        heatmap (Tensor): Input heatmap, the gaussian kernel will cover on
            it and maintain the max value.
        center (List[int]): Coord of gaussian kernel's center.
        radius_x (int): X-axis radius of gaussian kernel.
        radius_y (int): Y-axis radius of gaussian kernel.
        k (int): Coefficient of gaussian kernel. Defaults to 1.

    Returns:
        out_heatmap (Tensor): Updated heatmap covered by gaussian kernel.
    """
    diameter_x, diameter_y = 2 * radius_x + 1, 2 * radius_y + 1
    gaussian_kernel = ellip_gaussian2D((radius_x, radius_y),
                                       sigma_x=diameter_x // 6,
                                       sigma_y=diameter_y // 6,
                                       dtype=heatmap.dtype,
                                       device=heatmap.device)
    ...
    return out_heatmap


def ellip_gaussian2D(radius: Tuple[int, int],
                     sigma_x: int,
                     sigma_y: int,
                     dtype: torch.dtype = torch.float32,
                     device: str = 'cpu') -> Tensor:
    """Generate 2D ellipse gaussian kernel.

    Args:
        radius (Tuple[int]): Ellipse radius (radius_x, radius_y) of gaussian
            kernel.
        sigma_x (int): X-axis sigma of gaussian function.
        sigma_y (int): Y-axis sigma of gaussian function.
        dtype (torch.dtype): Dtype of gaussian tensor.
            Defaults to torch.float32.
        device (str): Device of gaussian tensor.
            Defaults to 'cpu'.

    Returns:
        h (Tensor): Gaussian kernel with a
    ...
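A minimal sketch of how these helpers compose when building a center heatmap target; the box size and overlap are illustrative values, not taken from this diff:

heatmap = torch.zeros(128, 128)
# Pick a radius such that any box overlapping the ground truth with
# IoU >= min_overlap still covers the gaussian peak.
radius = gaussian_radius((torch.tensor(20.), torch.tensor(30.)),
                         min_overlap=0.5)
radius = max(0, int(radius))
center = torch.tensor([64, 64])
heatmap = draw_heatmap_gaussian(heatmap, center, radius)
# heatmap now peaks at 1.0 at (64, 64) and decays with distance.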
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from torch import Tensor

from mmdet3d.structures import CameraInstance3DBoxes, points_cam2img


def get_keypoints(
        gt_bboxes_3d_list: List[CameraInstance3DBoxes],
        centers2d_list: List[Tensor],
        img_metas: List[dict],
        use_local_coords: bool = True) -> Tuple[List[Tensor], List[Tensor]]:
    """Function to get keypoints of each 3D bbox projected onto the image.

    Args:
        gt_bboxes_3d_list (List[:obj:`CameraInstance3DBoxes`]): Ground truth
            bboxes of each image.
        centers2d_list (List[Tensor]): Projected 3D centers onto 2D image,
            shape (num_gt, 2).
        img_metas (List[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        use_local_coords (bool): Whether to use local coordinates
            for keypoints. Defaults to True.

    Returns:
        Tuple[List[Tensor], List[Tensor]]: It contains two elements,
        the first is the keypoints for each projected 2D bbox in batch data.
        The second is the visible mask of depth calculated by keypoints.
    """
    assert len(gt_bboxes_3d_list) == len(centers2d_list)
    ...
    keypoints_z_visible = (keypoints3d[..., -1] > 0)

    # (N, 10)
    keypoints_visible = \
        keypoints_x_visible & keypoints_y_visible & keypoints_z_visible
    # center, diag-02, diag-13
    keypoints_depth_valid = torch.stack(
        (keypoints_visible[:, [8, 9]].all(dim=1),
    ...
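A hedged sketch of calling `get_keypoints`; the `cam2img` and `pad_shape` meta keys are assumptions about how camera-based heads usually populate `img_metas`, not taken from this diff:

gt_bboxes_3d = CameraInstance3DBoxes(torch.rand(3, 7))
centers2d = torch.rand(3, 2) * 100
img_metas = [dict(cam2img=torch.eye(4), pad_shape=(128, 128))]  # assumed
keypoints2d_list, depth_mask_list = get_keypoints(
    [gt_bboxes_3d], [centers2d], img_metas)
# keypoints2d_list[0]: per-box keypoints (8 corners plus centers)
# projected onto the image; depth_mask_list[0]: which keypoint groups
# (center, diag-02, diag-13) yield a valid depth, per the code above.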
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from torch import Tensor

from mmdet3d.structures import CameraInstance3DBoxes


def filter_outside_objs(gt_bboxes_list: List[Tensor],
                        gt_labels_list: List[Tensor],
                        gt_bboxes_3d_list: List[CameraInstance3DBoxes],
                        gt_labels_3d_list: List[Tensor],
                        centers2d_list: List[Tensor],
                        img_metas: List[dict]) -> None:
    """Function to filter the objects label outside the image.

    Args:
        gt_bboxes_list (List[Tensor]): Ground truth bboxes of each image,
            each has shape (num_gt, 4).
        gt_labels_list (List[Tensor]): Ground truth labels of each box,
            each has shape (num_gt,).
        gt_bboxes_3d_list (List[:obj:`CameraInstance3DBoxes`]): 3D ground
            truth bboxes of each image, each has shape
            (num_gt, bbox_code_size).
        gt_labels_3d_list (List[Tensor]): 3D ground truth labels of each
            box, each has shape (num_gt,).
        centers2d_list (List[Tensor]): Projected 3D centers onto 2D image,
            each has shape (num_gt, 2).
        img_metas (List[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
    ...
        gt_labels_3d_list[i] = gt_labels_3d_list[i][keep_inds]


def get_centers2d_target(centers2d: Tensor, centers: Tensor,
                         img_shape: tuple) -> Tensor:
    """Function to get target centers2d.

    Args:
    ...
    return centers2d_target


def handle_proj_objs(
    centers2d_list: List[Tensor], gt_bboxes_list: List[Tensor],
    img_metas: List[dict]
) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
    """Function to handle projected object centers2d, generate target
    centers2d.

    Args:
        gt_bboxes_list (List[Tensor]): Ground truth bboxes of each image,
            shape (num_gt, 4).
        centers2d_list (List[Tensor]): Projected 3D centers onto 2D image,
            shape (num_gt, 2).
        img_metas (List[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.

    Returns:
        Tuple[List[Tensor], List[Tensor], List[Tensor]]: It contains three
        elements. The first is the target centers2d after handling the
        truncated objects. The second is the offsets between target
        centers2d and round int dtype centers2d, and the last is the
        truncation mask for each object in batch data.
    """
    bs = len(centers2d_list)
    centers2d_target_list = []
    ...
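For intuition, a sketch of what `handle_proj_objs` returns for one image, with shapes following the docstring; the `img_shape` meta key and the values are illustrative assumptions:

centers2d = torch.tensor([[10.3, 20.7], [-5.0, 40.0]])  # 2nd truncated
gt_bboxes = torch.tensor([[0., 0., 30., 40.], [0., 30., 10., 50.]])
img_metas = [dict(img_shape=(128, 128, 3))]  # assumed meta key
centers2d_target_list, offsets_list, trunc_mask_list = handle_proj_objs(
    [centers2d], [gt_bboxes], img_metas)
# Per the docstring: centers snapped back into the image for truncated
# objects, sub-pixel offsets w.r.t. the rounded-int centers, and a bool
# mask marking which objects were truncated.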
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from inspect import getfullargspec
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch

TemplateArrayType = Union[tuple, list, int, float, np.ndarray, torch.Tensor]
OptArrayType = Optional[Union[np.ndarray, torch.Tensor]]


def array_converter(to_torch: bool = True,
                    apply_to: Tuple[str, ...] = tuple(),
                    template_arg_name_: Optional[str] = None,
                    recover: bool = True) -> Callable:
    """Wrapper function for data-type agnostic processing.

    First converts input arrays to PyTorch tensors or NumPy ndarrays
    ...
    `recover=True`.

    Args:
        to_torch (bool): Whether to convert to PyTorch tensors
            for middle calculation. Defaults to True.
        apply_to (Tuple[str, ...]): The arguments to which we apply
            data-type conversion. Defaults to an empty tuple.
        template_arg_name_ (str, optional): Argument serving as the template
            (return arrays should have the same dtype and device
            as the template). Defaults to None. If None, we will use the
            first argument in `apply_to` as the template argument.
        recover (bool): Whether or not to recover the wrapped function
            outputs to the `template_arg_name_` type. Defaults to True.

    Raises:
    ...


class ArrayConverter:
    """Utility class for data-type agnostic processing.

    Args:
        template_array (tuple | list | int | float | np.ndarray |
            torch.Tensor, optional): Template array. Defaults to None.
    """

    SUPPORTED_NON_ARRAY_TYPES = (int, float, np.int8, np.int16, np.int32,
                                 np.int64, np.uint8, np.uint16, np.uint32,
                                 np.uint64, np.float16, np.float32, np.float64)

    def __init__(self,
                 template_array: Optional[TemplateArrayType] = None) -> None:
        if template_array is not None:
            self.set_template(template_array)

    def set_template(self, array: TemplateArrayType) -> None:
        """Set template array.

        Args:
        ...
        raise TypeError(f'Template type {self.array_type}'
                        f' is not supported.')

    def convert(
        self,
        input_array: TemplateArrayType,
        target_type: Optional[type] = None,
        target_array: OptArrayType = None
    ) -> Union[np.ndarray, torch.Tensor]:
        """Convert input array to target data type.

        Args:
            input_array (tuple | list | int | float | np.ndarray |
                torch.Tensor): Input array.
            target_type (:class:`np.ndarray` or :class:`torch.Tensor`,
                optional): Type to which input array is converted.
                Defaults to None.
            target_array (np.ndarray | torch.Tensor, optional):
                Template array to which input array is converted.
                Defaults to None.
        ...
        Raises:
            TypeError: If input type does not belong to the above range,
                or the contents of a list or tuple do not share the
                same data type, a TypeError is raised.

        Returns:
            np.ndarray or torch.Tensor: The converted array.
        """
        if isinstance(input_array, (list, tuple)):
            try:
    ...
        converted_array = target_array.new_tensor(input_array)
        return converted_array

    def recover(
        self, input_array: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:
        """Recover input type to original array type.

        Args:
            input_array (np.ndarray | torch.Tensor): Input array.

        Returns:
            np.ndarray or torch.Tensor: Converted array.
        """
        assert isinstance(input_array, (np.ndarray, torch.Tensor)), \
            'invalid input array type'
        if isinstance(input_array, self.array_type):
    ...
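A usage sketch of the decorator on a free function; this mirrors how mmdet3d's own `limit_period` applies it, so the pattern below is representative rather than invented:

@array_converter(apply_to=('val', ))
def limit_period(val, offset=0.5, period=np.pi):
    # Inside the wrapped function, `val` is a torch.Tensor even if the
    # caller passed a list or an ndarray.
    return val - torch.floor(val / period + offset) * period

out = limit_period(np.array([3 * np.pi / 2]))
# `out` is recovered to np.ndarray, matching the template argument.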