"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "4c4da32c0ab90ad7ec8b08e42f6d2005e9de4acb"
Commit 80b39bd0 authored by zhangwenwei's avatar zhangwenwei
Browse files

Reformat docstrings in code

parent 64d7fbc2
...@@ -10,7 +10,7 @@ def box3d_multiclass_nms(mlvl_bboxes, ...@@ -10,7 +10,7 @@ def box3d_multiclass_nms(mlvl_bboxes,
max_num, max_num,
cfg, cfg,
mlvl_dir_scores=None): mlvl_dir_scores=None):
"""Multi-class nms for 3D boxes """Multi-class nms for 3D boxes.
Args: Args:
mlvl_bboxes (torch.Tensor): Multi-level boxes with shape (N, M). mlvl_bboxes (torch.Tensor): Multi-level boxes with shape (N, M).
......
import os.path as osp
import mmcv import mmcv
import numpy as np import numpy as np
import trimesh import trimesh
from os import path as osp
def _write_ply(points, out_filename): def _write_ply(points, out_filename):
"""Write points into ply format for meshlab visualization """Write points into ply format for meshlab visualization.
Args: Args:
points (np.ndarray): Points in shape (N, dim). points (np.ndarray): Points in shape (N, dim).
...@@ -28,7 +27,7 @@ def _write_ply(points, out_filename): ...@@ -28,7 +27,7 @@ def _write_ply(points, out_filename):
def _write_oriented_bbox(scene_bbox, out_filename): def _write_oriented_bbox(scene_bbox, out_filename):
"""Export oriented (around Z axis) scene bbox to meshes """Export oriented (around Z axis) scene bbox to meshes.
Args: Args:
scene_bbox(list[ndarray] or ndarray): xyz pos of center and scene_bbox(list[ndarray] or ndarray): xyz pos of center and
......
...@@ -4,7 +4,7 @@ from . import voxel_generator ...@@ -4,7 +4,7 @@ from . import voxel_generator
def build_voxel_generator(cfg, **kwargs): def build_voxel_generator(cfg, **kwargs):
"""Builder of voxel generator""" """Builder of voxel generator."""
if isinstance(cfg, voxel_generator.VoxelGenerator): if isinstance(cfg, voxel_generator.VoxelGenerator):
return cfg return cfg
elif isinstance(cfg, dict): elif isinstance(cfg, dict):
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
class VoxelGenerator(object): class VoxelGenerator(object):
"""Voxel generator in numpy implementation """Voxel generator in numpy implementation.
Args: Args:
voxel_size (list[float]): Size of a single voxel voxel_size (list[float]): Size of a single voxel
...@@ -33,7 +33,7 @@ class VoxelGenerator(object): ...@@ -33,7 +33,7 @@ class VoxelGenerator(object):
self._grid_size = grid_size self._grid_size = grid_size
def generate(self, points): def generate(self, points):
"""Generate voxels given points""" """Generate voxels given points."""
return points_to_voxel(points, self._voxel_size, return points_to_voxel(points, self._voxel_size,
self._point_cloud_range, self._max_num_points, self._point_cloud_range, self._max_num_points,
True, self._max_voxels) True, self._max_voxels)
......
import os.path as osp
import tempfile
import mmcv import mmcv
import numpy as np import numpy as np
import tempfile
from os import path as osp
from torch.utils.data import Dataset from torch.utils.data import Dataset
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
...@@ -12,7 +11,7 @@ from .pipelines import Compose ...@@ -12,7 +11,7 @@ from .pipelines import Compose
@DATASETS.register_module() @DATASETS.register_module()
class Custom3DDataset(Dataset): class Custom3DDataset(Dataset):
"""Customized 3D dataset """Customized 3D dataset.
This is the base dataset of SUNRGB-D, ScanNet, nuScenes, and KITTI This is the base dataset of SUNRGB-D, ScanNet, nuScenes, and KITTI
dataset. dataset.
...@@ -179,7 +178,7 @@ class Custom3DDataset(Dataset): ...@@ -179,7 +178,7 @@ class Custom3DDataset(Dataset):
from mmdet3d.core.evaluation import indoor_eval from mmdet3d.core.evaluation import indoor_eval
assert isinstance( assert isinstance(
results, list), f'Expect results to be list, got {type(results)}.' results, list), f'Expect results to be list, got {type(results)}.'
assert len(results) > 0, f'Expect length of results > 0.' assert len(results) > 0, 'Expect length of results > 0.'
assert len(results) == len(self.data_infos) assert len(results) == len(self.data_infos)
assert isinstance( assert isinstance(
results[0], dict results[0], dict
...@@ -220,8 +219,7 @@ class Custom3DDataset(Dataset): ...@@ -220,8 +219,7 @@ class Custom3DDataset(Dataset):
"""Set flag according to image aspect ratio. """Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1, Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0. otherwise group 0. In 3D datasets, they are all the same, thus are all
In 3D datasets, they are all the same, thus are all zeros zeros
""" """
self.flag = np.zeros(len(self), dtype=np.uint8) self.flag = np.zeros(len(self), dtype=np.uint8)
import copy import copy
import os
import os.path as osp
import tempfile
import mmcv import mmcv
import numpy as np import numpy as np
import os
import tempfile
import torch import torch
from mmcv.utils import print_log from mmcv.utils import print_log
from os import path as osp
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
from ..core import show_result from ..core import show_result
...@@ -16,7 +15,7 @@ from .custom_3d import Custom3DDataset ...@@ -16,7 +15,7 @@ from .custom_3d import Custom3DDataset
@DATASETS.register_module() @DATASETS.register_module()
class KittiDataset(Custom3DDataset): class KittiDataset(Custom3DDataset):
"""KITTI Dataset """KITTI Dataset.
This class serves as the API for experiments on the KITTI Dataset. This class serves as the API for experiments on the KITTI Dataset.
......
import os.path as osp
import tempfile
import mmcv import mmcv
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import tempfile
from lyft_dataset_sdk.lyftdataset import LyftDataset as Lyft from lyft_dataset_sdk.lyftdataset import LyftDataset as Lyft
from lyft_dataset_sdk.utils.data_classes import Box as LyftBox from lyft_dataset_sdk.utils.data_classes import Box as LyftBox
from os import path as osp
from pyquaternion import Quaternion from pyquaternion import Quaternion
from mmdet3d.core.evaluation.lyft_eval import lyft_eval from mmdet3d.core.evaluation.lyft_eval import lyft_eval
...@@ -16,7 +15,7 @@ from .custom_3d import Custom3DDataset ...@@ -16,7 +15,7 @@ from .custom_3d import Custom3DDataset
@DATASETS.register_module() @DATASETS.register_module()
class LyftDataset(Custom3DDataset): class LyftDataset(Custom3DDataset):
"""Lyft Dataset """Lyft Dataset.
This class serves as the API for experiments on the Lyft Dataset. This class serves as the API for experiments on the Lyft Dataset.
......
import os.path as osp
import tempfile
import mmcv import mmcv
import numpy as np import numpy as np
import pyquaternion import pyquaternion
import tempfile
from nuscenes.utils.data_classes import Box as NuScenesBox from nuscenes.utils.data_classes import Box as NuScenesBox
from os import path as osp
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
from ..core import show_result from ..core import show_result
...@@ -14,7 +13,7 @@ from .custom_3d import Custom3DDataset ...@@ -14,7 +13,7 @@ from .custom_3d import Custom3DDataset
@DATASETS.register_module() @DATASETS.register_module()
class NuScenesDataset(Custom3DDataset): class NuScenesDataset(Custom3DDataset):
"""NuScenes Dataset """NuScenes Dataset.
This class serves as the API for experiments on the NuScenes Dataset. This class serves as the API for experiments on the NuScenes Dataset.
......
import warnings
import numba import numba
import numpy as np import numpy as np
import warnings
from numba.errors import NumbaPerformanceWarning from numba.errors import NumbaPerformanceWarning
from mmdet3d.core.bbox import box_np_ops from mmdet3d.core.bbox import box_np_ops
...@@ -44,11 +43,11 @@ def box_collision_test(boxes, qboxes, clockwise=True): ...@@ -44,11 +43,11 @@ def box_collision_test(boxes, qboxes, clockwise=True):
max(boxes_standup[i, 1], qboxes_standup[j, 1])) max(boxes_standup[i, 1], qboxes_standup[j, 1]))
if ih > 0: if ih > 0:
for k in range(4): for k in range(4):
for l in range(4): for box_l in range(4):
A = lines_boxes[i, k, 0] A = lines_boxes[i, k, 0]
B = lines_boxes[i, k, 1] B = lines_boxes[i, k, 1]
C = lines_qboxes[j, l, 0] C = lines_qboxes[j, box_l, 0]
D = lines_qboxes[j, l, 1] D = lines_qboxes[j, box_l, 1]
acd = (D[1] - A[1]) * (C[0] - acd = (D[1] - A[1]) * (C[0] -
A[0]) > (C[1] - A[1]) * ( A[0]) > (C[1] - A[1]) * (
D[0] - A[0]) D[0] - A[0])
...@@ -71,15 +70,15 @@ def box_collision_test(boxes, qboxes, clockwise=True): ...@@ -71,15 +70,15 @@ def box_collision_test(boxes, qboxes, clockwise=True):
# now check complete overlap. # now check complete overlap.
# box overlap qbox: # box overlap qbox:
box_overlap_qbox = True box_overlap_qbox = True
for l in range(4): # point l in qboxes for box_l in range(4): # point l in qboxes
for k in range(4): # corner k in boxes for k in range(4): # corner k in boxes
vec = boxes[i, k] - boxes[i, (k + 1) % 4] vec = boxes[i, k] - boxes[i, (k + 1) % 4]
if clockwise: if clockwise:
vec = -vec vec = -vec
cross = vec[1] * ( cross = vec[1] * (
boxes[i, k, 0] - qboxes[j, l, 0]) boxes[i, k, 0] - qboxes[j, box_l, 0])
cross -= vec[0] * ( cross -= vec[0] * (
boxes[i, k, 1] - qboxes[j, l, 1]) boxes[i, k, 1] - qboxes[j, box_l, 1])
if cross >= 0: if cross >= 0:
box_overlap_qbox = False box_overlap_qbox = False
break break
...@@ -88,15 +87,15 @@ def box_collision_test(boxes, qboxes, clockwise=True): ...@@ -88,15 +87,15 @@ def box_collision_test(boxes, qboxes, clockwise=True):
if box_overlap_qbox is False: if box_overlap_qbox is False:
qbox_overlap_box = True qbox_overlap_box = True
for l in range(4): # point l in boxes for box_l in range(4): # point box_l in boxes
for k in range(4): # corner k in qboxes for k in range(4): # corner k in qboxes
vec = qboxes[j, k] - qboxes[j, (k + 1) % 4] vec = qboxes[j, k] - qboxes[j, (k + 1) % 4]
if clockwise: if clockwise:
vec = -vec vec = -vec
cross = vec[1] * ( cross = vec[1] * (
qboxes[j, k, 0] - boxes[i, l, 0]) qboxes[j, k, 0] - boxes[i, box_l, 0])
cross -= vec[0] * ( cross -= vec[0] * (
qboxes[j, k, 1] - boxes[i, l, 1]) qboxes[j, k, 1] - boxes[i, box_l, 1])
if cross >= 0: # if cross >= 0: #
qbox_overlap_box = False qbox_overlap_box = False
break break
...@@ -264,8 +263,8 @@ def noise_per_object_v3_(gt_boxes, ...@@ -264,8 +263,8 @@ def noise_per_object_v3_(gt_boxes,
center_noise_std=1.0, center_noise_std=1.0,
global_random_rot_range=np.pi / 4, global_random_rot_range=np.pi / 4,
num_try=100): num_try=100):
"""random rotate or remove each groundtrutn independently. """random rotate or remove each groundtrutn independently. use kitti viewer
use kitti viewer to test this function points_transform_ to test this function points_transform_
Args: Args:
gt_boxes: [N, 7], gt box in lidar.points_transform_ gt_boxes: [N, 7], gt box in lidar.points_transform_
......
import copy import copy
import numpy as np
import os import os
import pickle import pickle
import numpy as np
from mmdet3d.core.bbox import box_np_ops from mmdet3d.core.bbox import box_np_ops
from mmdet3d.datasets.pipelines import data_augment_utils from mmdet3d.datasets.pipelines import data_augment_utils
from ..registry import OBJECTSAMPLERS from ..registry import OBJECTSAMPLERS
......
...@@ -7,7 +7,7 @@ from mmdet.datasets.pipelines import LoadAnnotations ...@@ -7,7 +7,7 @@ from mmdet.datasets.pipelines import LoadAnnotations
@PIPELINES.register_module() @PIPELINES.register_module()
class LoadMultiViewImageFromFiles(object): class LoadMultiViewImageFromFiles(object):
""" Load multi channel images from a list of separate channel files. """Load multi channel images from a list of separate channel files.
Expects results['img_filename'] to be a list of filenames Expects results['img_filename'] to be a list of filenames
""" """
...@@ -43,7 +43,7 @@ class LoadMultiViewImageFromFiles(object): ...@@ -43,7 +43,7 @@ class LoadMultiViewImageFromFiles(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class LoadPointsFromMultiSweeps(object): class LoadPointsFromMultiSweeps(object):
"""Load points from multiple sweeps """Load points from multiple sweeps.
This is usually used for nuScenes dataset to utilize previous sweeps. This is usually used for nuScenes dataset to utilize previous sweeps.
...@@ -143,7 +143,7 @@ class PointSegClassMapping(object): ...@@ -143,7 +143,7 @@ class PointSegClassMapping(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class NormalizePointsColor(object): class NormalizePointsColor(object):
"""Normalize color of points """Normalize color of points.
Normalize color of the points. Normalize color of the points.
......
import mmcv
import warnings import warnings
from copy import deepcopy from copy import deepcopy
import mmcv
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import Compose from mmdet.datasets.pipelines import Compose
@PIPELINES.register_module() @PIPELINES.register_module()
class MultiScaleFlipAug3D(object): class MultiScaleFlipAug3D(object):
"""Test-time augmentation with multiple scales and flipping """Test-time augmentation with multiple scales and flipping.
Args: Args:
transforms (list[dict]): Transforms to apply in each augmentation. transforms (list[dict]): Transforms to apply in each augmentation.
......
...@@ -91,7 +91,7 @@ class RandomFlip3D(RandomFlip): ...@@ -91,7 +91,7 @@ class RandomFlip3D(RandomFlip):
@PIPELINES.register_module() @PIPELINES.register_module()
class ObjectSample(object): class ObjectSample(object):
"""Sample GT objects to the data """Sample GT objects to the data.
Args: Args:
db_sampler (dict): Config dict of the database sampler. db_sampler (dict): Config dict of the database sampler.
...@@ -168,7 +168,7 @@ class ObjectSample(object): ...@@ -168,7 +168,7 @@ class ObjectSample(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class ObjectNoise(object): class ObjectNoise(object):
"""Apply noise to each GT objects in the scene """Apply noise to each GT objects in the scene.
Args: Args:
translation_std (list, optional): Standard deviation of the translation_std (list, optional): Standard deviation of the
...@@ -221,7 +221,7 @@ class ObjectNoise(object): ...@@ -221,7 +221,7 @@ class ObjectNoise(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class GlobalRotScaleTrans(object): class GlobalRotScaleTrans(object):
"""Apply global rotation, scaling and translation to a 3D scene """Apply global rotation, scaling and translation to a 3D scene.
Args: Args:
rot_range (list[float]): Range of rotation angle. rot_range (list[float]): Range of rotation angle.
...@@ -374,7 +374,7 @@ class PointsRangeFilter(object): ...@@ -374,7 +374,7 @@ class PointsRangeFilter(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class ObjectNameFilter(object): class ObjectNameFilter(object):
"""Filter GT objects by their names """Filter GT objects by their names.
Args: Args:
classes (list[str]): list of class names to be kept for training classes (list[str]): list of class names to be kept for training
......
import os.path as osp
import numpy as np import numpy as np
from os import path as osp
from mmdet3d.core import show_result from mmdet3d.core import show_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
...@@ -10,7 +9,7 @@ from .custom_3d import Custom3DDataset ...@@ -10,7 +9,7 @@ from .custom_3d import Custom3DDataset
@DATASETS.register_module() @DATASETS.register_module()
class ScanNetDataset(Custom3DDataset): class ScanNetDataset(Custom3DDataset):
"""ScanNet Dataset """ScanNet Dataset.
This class serves as the API for experiments on the ScanNet Dataset. This class serves as the API for experiments on the ScanNet Dataset.
......
import os.path as osp
import numpy as np import numpy as np
from os import path as osp
from mmdet3d.core import show_result from mmdet3d.core import show_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
...@@ -10,7 +9,7 @@ from .custom_3d import Custom3DDataset ...@@ -10,7 +9,7 @@ from .custom_3d import Custom3DDataset
@DATASETS.register_module() @DATASETS.register_module()
class SUNRGBDDataset(Custom3DDataset): class SUNRGBDDataset(Custom3DDataset):
"""SUNRGBD Dataset """SUNRGBD Dataset.
This class serves as the API for experiments on the SUNRGBD Dataset. This class serves as the API for experiments on the SUNRGBD Dataset.
......
import torch import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
from torch import nn as nn
from mmdet3d.ops import PointFPModule, PointSAModule from mmdet3d.ops import PointFPModule, PointSAModule
from mmdet.models import BACKBONES from mmdet.models import BACKBONES
......
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
from torch import nn as nn
from mmdet.models import BACKBONES from mmdet.models import BACKBONES
@BACKBONES.register_module() @BACKBONES.register_module()
class SECOND(nn.Module): class SECOND(nn.Module):
"""Backbone network for SECOND/PointPillars/PartA2/MVXNet """Backbone network for SECOND/PointPillars/PartA2/MVXNet.
Args: Args:
in_channels (int): Input channels in_channels (int): Input channels
......
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from mmcv.cnn import bias_init_with_prob, normal_init from mmcv.cnn import bias_init_with_prob, normal_init
from torch import nn as nn
from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period, from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period,
xywhr2xyxyr) xywhr2xyxyr)
...@@ -244,7 +244,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -244,7 +244,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
@staticmethod @staticmethod
def add_sin_difference(boxes1, boxes2): def add_sin_difference(boxes1, boxes2):
"""Convert the rotation difference to difference in sine function """Convert the rotation difference to difference in sine function.
Args: Args:
boxes1 (torch.Tensor): shape (NxC), where C>=7 and boxes1 (torch.Tensor): shape (NxC), where C>=7 and
......
import torch import torch
import torch.nn.functional as F from torch.nn import functional as F
from mmdet3d.core.bbox import bbox_overlaps_nearest_3d from mmdet3d.core.bbox import bbox_overlaps_nearest_3d
from mmdet.models import HEADS from mmdet.models import HEADS
...@@ -9,7 +9,7 @@ from .train_mixins import get_direction_target ...@@ -9,7 +9,7 @@ from .train_mixins import get_direction_target
@HEADS.register_module() @HEADS.register_module()
class FreeAnchor3DHead(Anchor3DHead): class FreeAnchor3DHead(Anchor3DHead):
"""`FreeAnchor <https://arxiv.org/abs/1909.02466>`_ head for 3D detection """`FreeAnchor <https://arxiv.org/abs/1909.02466>`_ head for 3D detection.
Note: Note:
This implementation is directly modified from the `mmdet implementation This implementation is directly modified from the `mmdet implementation
...@@ -237,7 +237,7 @@ class FreeAnchor3DHead(Anchor3DHead): ...@@ -237,7 +237,7 @@ class FreeAnchor3DHead(Anchor3DHead):
return losses return losses
def positive_bag_loss(self, matched_cls_prob, matched_box_prob): def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
"""Generate positive bag loss """Generate positive bag loss.
Args: Args:
matched_cls_prob (torch.Tensor): Classification probability matched_cls_prob (torch.Tensor): Classification probability
...@@ -259,7 +259,7 @@ class FreeAnchor3DHead(Anchor3DHead): ...@@ -259,7 +259,7 @@ class FreeAnchor3DHead(Anchor3DHead):
bag_prob, torch.ones_like(bag_prob), reduction='none') bag_prob, torch.ones_like(bag_prob), reduction='none')
def negative_bag_loss(self, cls_prob, box_prob): def negative_bag_loss(self, cls_prob, box_prob):
"""Generate negative bag loss """Generate negative bag loss.
Args: Args:
cls_prob (torch.Tensor): Classification probability cls_prob (torch.Tensor): Classification probability
......
...@@ -11,7 +11,7 @@ from .anchor3d_head import Anchor3DHead ...@@ -11,7 +11,7 @@ from .anchor3d_head import Anchor3DHead
@HEADS.register_module() @HEADS.register_module()
class PartA2RPNHead(Anchor3DHead): class PartA2RPNHead(Anchor3DHead):
"""RPN head for PartA2 """RPN head for PartA2.
Note: Note:
The main difference between the PartA2 RPN head and the Anchor3DHead The main difference between the PartA2 RPN head and the Anchor3DHead
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment