Commit 53435c62 authored by Yezhen Cong's avatar Yezhen Cong Committed by Tai-Wang
Browse files

[Refactor] Refactor code structure and docstrings (#803)

* refactor points_in_boxes

* Merge same functions of three boxes

* More docstring fixes and unify x/y/z size

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Remove None in function param type

* Fix unittest

* Add comments for NMS functions

* Merge methods of Points

* Add unittest

* Add optional and default value

* Fix box conversion and add unittest

* Fix comments

* Add unit test

* Indent

* Fix CI

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Add unit test for box bev

* More unit tests and refine docstrings in box_np_ops

* Fix comment

* Add deprecation warning
parent 4f36084f
......@@ -13,7 +13,7 @@ class S3DISData(object):
Args:
root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'Area_1'.
split (str, optional): Set split type of the data. Default: 'Area_1'.
"""
def __init__(self, root_path, split='Area_1'):
......@@ -48,9 +48,11 @@ class S3DISData(object):
This method gets information from the raw data.
Args:
num_workers (int): Number of threads to be used. Default: 4.
has_label (bool): Whether the data has label. Default: True.
sample_id_list (list[int]): Index list of the sample.
num_workers (int, optional): Number of threads to be used.
Default: 4.
has_label (bool, optional): Whether the data has label.
Default: True.
sample_id_list (list[int], optional): Index list of the sample.
Default: None.
Returns:
......@@ -154,10 +156,11 @@ class S3DISSegData(object):
Args:
data_root (str): Root path of the raw data.
ann_file (str): The generated scannet infos.
split (str): Set split type of the data. Default: 'train'.
num_points (int): Number of points in each data input. Default: 8192.
label_weight_func (function): Function to compute the label weight.
Default: None.
split (str, optional): Set split type of the data. Default: 'train'.
num_points (int, optional): Number of points in each data input.
Default: 8192.
label_weight_func (function, optional): Function to compute the
label weight. Default: None.
"""
def __init__(self,
......@@ -209,7 +212,7 @@ class S3DISSegData(object):
return label
def get_scene_idxs_and_label_weight(self):
"""Compute scene_idxs for data sampling and label weight for loss \
"""Compute scene_idxs for data sampling and label weight for loss
calculation.
We sample more times for scenes with more points. Label_weight is
......
......@@ -13,7 +13,7 @@ class ScanNetData(object):
Args:
root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'train'.
split (str, optional): Set split type of the data. Default: 'train'.
"""
def __init__(self, root_path, split='train'):
......@@ -90,9 +90,11 @@ class ScanNetData(object):
This method gets information from the raw data.
Args:
num_workers (int): Number of threads to be used. Default: 4.
has_label (bool): Whether the data has label. Default: True.
sample_id_list (list[int]): Index list of the sample.
num_workers (int, optional): Number of threads to be used.
Default: 4.
has_label (bool, optional): Whether the data has label.
Default: True.
sample_id_list (list[int], optional): Index list of the sample.
Default: None.
Returns:
......@@ -201,10 +203,11 @@ class ScanNetSegData(object):
Args:
data_root (str): Root path of the raw data.
ann_file (str): The generated scannet infos.
split (str): Set split type of the data. Default: 'train'.
num_points (int): Number of points in each data input. Default: 8192.
label_weight_func (function): Function to compute the label weight.
Default: None.
split (str, optional): Set split type of the data. Default: 'train'.
num_points (int, optional): Number of points in each data input.
Default: 8192.
label_weight_func (function, optional): Function to compute the
label weight. Default: None.
"""
def __init__(self,
......@@ -261,7 +264,7 @@ class ScanNetSegData(object):
return label
def get_scene_idxs_and_label_weight(self):
"""Compute scene_idxs for data sampling and label weight for loss \
"""Compute scene_idxs for data sampling and label weight for loss
calculation.
We sample more times for scenes with more points. Label_weight is
......
......@@ -42,7 +42,7 @@ class SUNRGBDInstance(object):
self.ymax = data[2] + data[4]
self.box2d = np.array([self.xmin, self.ymin, self.xmax, self.ymax])
self.centroid = np.array([data[5], data[6], data[7]])
# data[9] is dx (l), data[8] is dy (w), data[10] is dz (h)
# data[9] is x_size (l), data[8] is y_size (w), data[10] is z_size (h)
# in our depth coordinate system,
# l corresponds to the size along the x axis
self.size = np.array([data[9], data[8], data[10]]) * 2
......@@ -62,8 +62,8 @@ class SUNRGBDData(object):
Args:
root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'train'.
use_v1 (bool): Whether to use v1. Default: False.
split (str, optional): Set split type of the data. Default: 'train'.
use_v1 (bool, optional): Whether to use v1. Default: False.
"""
def __init__(self, root_path, split='train', use_v1=False):
......@@ -128,9 +128,11 @@ class SUNRGBDData(object):
This method gets information from the raw data.
Args:
num_workers (int): Number of threads to be used. Default: 4.
has_label (bool): Whether the data has label. Default: True.
sample_id_list (list[int]): Index list of the sample.
num_workers (int, optional): Number of threads to be used.
Default: 4.
has_label (bool, optional): Whether the data has label.
Default: True.
sample_id_list (list[int], optional): Index list of the sample.
Default: None.
Returns:
......
......@@ -31,8 +31,8 @@ class Waymo2KITTI(object):
save_dir (str): Directory to save data in KITTI format.
prefix (str): Prefix of filename. In general, 0 for training, 1 for
validation and 2 for testing.
workers (str): Number of workers for the parallel process.
test_mode (bool): Whether in the test_mode. Default: False.
workers (int, optional): Number of workers for the parallel process.
test_mode (bool, optional): Whether in the test_mode. Default: False.
"""
def __init__(self,
......@@ -402,8 +402,8 @@ class Waymo2KITTI(object):
camera projections corresponding with two returns.
range_image_top_pose (:obj:`Transform`): Range image pixel pose for
top lidar.
ri_index (int): 0 for the first return, 1 for the second return.
Default: 0.
ri_index (int, optional): 0 for the first return,
1 for the second return. Default: 0.
Returns:
tuple[list[np.ndarray]]: (List of points with shape [N, 3],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment