"...matting-anything_pytorch.git" did not exist on "ce0e5303897d0add4c557ae55dd93a36fbfb793d"
Commit 53435c62 authored by Yezhen Cong's avatar Yezhen Cong Committed by Tai-Wang
Browse files

[Refactor] Refactor code structure and docstrings (#803)

* refactor points_in_boxes

* Merge same functions of three boxes

* More docstring fixes and unify x/y/z size

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Remove None in function param type

* Fix unittest

* Add comments for NMS functions

* Merge methods of Points

* Add unittest

* Add optional and default value

* Fix box conversion and add unittest

* Fix comments

* Add unit test

* Indent

* Fix CI

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Add unit test for box bev

* More unit tests and refine docstrings in box_np_ops

* Fix comment

* Add deprecation warning
parent 4f36084f
...@@ -13,7 +13,7 @@ class S3DISData(object): ...@@ -13,7 +13,7 @@ class S3DISData(object):
Args: Args:
root_path (str): Root path of the raw data. root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'Area_1'. split (str, optional): Set split type of the data. Default: 'Area_1'.
""" """
def __init__(self, root_path, split='Area_1'): def __init__(self, root_path, split='Area_1'):
...@@ -48,9 +48,11 @@ class S3DISData(object): ...@@ -48,9 +48,11 @@ class S3DISData(object):
This method gets information from the raw data. This method gets information from the raw data.
Args: Args:
num_workers (int): Number of threads to be used. Default: 4. num_workers (int, optional): Number of threads to be used.
has_label (bool): Whether the data has label. Default: True. Default: 4.
sample_id_list (list[int]): Index list of the sample. has_label (bool, optional): Whether the data has label.
Default: True.
sample_id_list (list[int], optional): Index list of the sample.
Default: None. Default: None.
Returns: Returns:
...@@ -154,10 +156,11 @@ class S3DISSegData(object): ...@@ -154,10 +156,11 @@ class S3DISSegData(object):
Args: Args:
data_root (str): Root path of the raw data. data_root (str): Root path of the raw data.
ann_file (str): The generated scannet infos. ann_file (str): The generated scannet infos.
split (str): Set split type of the data. Default: 'train'. split (str, optional): Set split type of the data. Default: 'train'.
num_points (int): Number of points in each data input. Default: 8192. num_points (int, optional): Number of points in each data input.
label_weight_func (function): Function to compute the label weight. Default: 8192.
Default: None. label_weight_func (function, optional): Function to compute the
label weight. Default: None.
""" """
def __init__(self, def __init__(self,
...@@ -209,7 +212,7 @@ class S3DISSegData(object): ...@@ -209,7 +212,7 @@ class S3DISSegData(object):
return label return label
def get_scene_idxs_and_label_weight(self): def get_scene_idxs_and_label_weight(self):
"""Compute scene_idxs for data sampling and label weight for loss \ """Compute scene_idxs for data sampling and label weight for loss
calculation. calculation.
We sample more times for scenes with more points. Label_weight is We sample more times for scenes with more points. Label_weight is
......
...@@ -13,7 +13,7 @@ class ScanNetData(object): ...@@ -13,7 +13,7 @@ class ScanNetData(object):
Args: Args:
root_path (str): Root path of the raw data. root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'train'. split (str, optional): Set split type of the data. Default: 'train'.
""" """
def __init__(self, root_path, split='train'): def __init__(self, root_path, split='train'):
...@@ -90,9 +90,11 @@ class ScanNetData(object): ...@@ -90,9 +90,11 @@ class ScanNetData(object):
This method gets information from the raw data. This method gets information from the raw data.
Args: Args:
num_workers (int): Number of threads to be used. Default: 4. num_workers (int, optional): Number of threads to be used.
has_label (bool): Whether the data has label. Default: True. Default: 4.
sample_id_list (list[int]): Index list of the sample. has_label (bool, optional): Whether the data has label.
Default: True.
sample_id_list (list[int], optional): Index list of the sample.
Default: None. Default: None.
Returns: Returns:
...@@ -201,10 +203,11 @@ class ScanNetSegData(object): ...@@ -201,10 +203,11 @@ class ScanNetSegData(object):
Args: Args:
data_root (str): Root path of the raw data. data_root (str): Root path of the raw data.
ann_file (str): The generated scannet infos. ann_file (str): The generated scannet infos.
split (str): Set split type of the data. Default: 'train'. split (str, optional): Set split type of the data. Default: 'train'.
num_points (int): Number of points in each data input. Default: 8192. num_points (int, optional): Number of points in each data input.
label_weight_func (function): Function to compute the label weight. Default: 8192.
Default: None. label_weight_func (function, optional): Function to compute the
label weight. Default: None.
""" """
def __init__(self, def __init__(self,
...@@ -261,7 +264,7 @@ class ScanNetSegData(object): ...@@ -261,7 +264,7 @@ class ScanNetSegData(object):
return label return label
def get_scene_idxs_and_label_weight(self): def get_scene_idxs_and_label_weight(self):
"""Compute scene_idxs for data sampling and label weight for loss \ """Compute scene_idxs for data sampling and label weight for loss
calculation. calculation.
We sample more times for scenes with more points. Label_weight is We sample more times for scenes with more points. Label_weight is
......
...@@ -42,7 +42,7 @@ class SUNRGBDInstance(object): ...@@ -42,7 +42,7 @@ class SUNRGBDInstance(object):
self.ymax = data[2] + data[4] self.ymax = data[2] + data[4]
self.box2d = np.array([self.xmin, self.ymin, self.xmax, self.ymax]) self.box2d = np.array([self.xmin, self.ymin, self.xmax, self.ymax])
self.centroid = np.array([data[5], data[6], data[7]]) self.centroid = np.array([data[5], data[6], data[7]])
# data[9] is dx (l), data[8] is dy (w), data[10] is dz (h) # data[9] is x_size (l), data[8] is y_size (w), data[10] is z_size (h)
# in our depth coordinate system, # in our depth coordinate system,
# l corresponds to the size along the x axis # l corresponds to the size along the x axis
self.size = np.array([data[9], data[8], data[10]]) * 2 self.size = np.array([data[9], data[8], data[10]]) * 2
...@@ -62,8 +62,8 @@ class SUNRGBDData(object): ...@@ -62,8 +62,8 @@ class SUNRGBDData(object):
Args: Args:
root_path (str): Root path of the raw data. root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'train'. split (str, optional): Set split type of the data. Default: 'train'.
use_v1 (bool): Whether to use v1. Default: False. use_v1 (bool, optional): Whether to use v1. Default: False.
""" """
def __init__(self, root_path, split='train', use_v1=False): def __init__(self, root_path, split='train', use_v1=False):
...@@ -128,9 +128,11 @@ class SUNRGBDData(object): ...@@ -128,9 +128,11 @@ class SUNRGBDData(object):
This method gets information from the raw data. This method gets information from the raw data.
Args: Args:
num_workers (int): Number of threads to be used. Default: 4. num_workers (int, optional): Number of threads to be used.
has_label (bool): Whether the data has label. Default: True. Default: 4.
sample_id_list (list[int]): Index list of the sample. has_label (bool, optional): Whether the data has label.
Default: True.
sample_id_list (list[int], optional): Index list of the sample.
Default: None. Default: None.
Returns: Returns:
......
...@@ -31,8 +31,8 @@ class Waymo2KITTI(object): ...@@ -31,8 +31,8 @@ class Waymo2KITTI(object):
save_dir (str): Directory to save data in KITTI format. save_dir (str): Directory to save data in KITTI format.
prefix (str): Prefix of filename. In general, 0 for training, 1 for prefix (str): Prefix of filename. In general, 0 for training, 1 for
validation and 2 for testing. validation and 2 for testing.
workers (str): Number of workers for the parallel process. workers (int, optional): Number of workers for the parallel process.
test_mode (bool): Whether in the test_mode. Default: False. test_mode (bool, optional): Whether in the test_mode. Default: False.
""" """
def __init__(self, def __init__(self,
...@@ -402,8 +402,8 @@ class Waymo2KITTI(object): ...@@ -402,8 +402,8 @@ class Waymo2KITTI(object):
camera projections corresponding with two returns. camera projections corresponding with two returns.
range_image_top_pose (:obj:`Transform`): Range image pixel pose for range_image_top_pose (:obj:`Transform`): Range image pixel pose for
top lidar. top lidar.
ri_index (int): 0 for the first return, 1 for the second return. ri_index (int, optional): 0 for the first return,
Default: 0. 1 for the second return. Default: 0.
Returns: Returns:
tuple[list[np.ndarray]]: (List of points with shape [N, 3], tuple[list[np.ndarray]]: (List of points with shape [N, 3],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment