Commit c2fe651f authored by zhangshilong's avatar zhangshilong Committed by ChaimZhu
Browse files

refactor directory

parent bc5806ba
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from mmcv.runner import BaseModule, auto_fp16 from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import DGCNNFAModule, DGCNNGFModule from mmdet3d.models.layers import DGCNNFAModule, DGCNNGFModule
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
......
...@@ -4,7 +4,7 @@ from mmcv.cnn import ConvModule ...@@ -4,7 +4,7 @@ from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16 from mmcv.runner import auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import build_sa_module from mmdet3d.models.layers.pointnet_modules import build_sa_module
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from .base_pointnet import BasePointNet from .base_pointnet import BasePointNet
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from mmcv.runner import auto_fp16 from mmcv.runner import auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PointFPModule, build_sa_module from mmdet3d.models.layers import PointFPModule, build_sa_module
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from .base_pointnet import BasePointNet from .base_pointnet import BasePointNet
......
...@@ -117,7 +117,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor): ...@@ -117,7 +117,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
imgs = [_img[[2, 1, 0], ...] for _img in imgs] imgs = [_img[[2, 1, 0], ...] for _img in imgs]
# Normalization. # Normalization.
if self._enable_normalize: if self._enable_normalize:
imgs = [(_img - self.mean) / self.std for _img in imgs] imgs = [(_img.float() - self.mean) / self.std for _img in imgs]
# Pad and stack Tensor. # Pad and stack Tensor.
batch_imgs = stack_batch(imgs, self.pad_size_divisor, batch_imgs = stack_batch(imgs, self.pad_size_divisor,
self.pad_value) self.pad_value)
......
...@@ -8,8 +8,9 @@ from mmcv.runner import BaseModule, auto_fp16 ...@@ -8,8 +8,9 @@ from mmcv.runner import BaseModule, auto_fp16
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.core.utils.typing import ConfigType, SampleList
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils.typing import ConfigType
class Base3DDecodeHead(BaseModule, metaclass=ABCMeta): class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
......
...@@ -4,7 +4,7 @@ from typing import Tuple ...@@ -4,7 +4,7 @@ from typing import Tuple
from mmcv.cnn.bricks import ConvModule from mmcv.cnn.bricks import ConvModule
from torch import Tensor from torch import Tensor
from mmdet3d.ops import DGCNNFPModule from mmdet3d.models.layers import DGCNNFPModule
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from .decode_head import Base3DDecodeHead from .decode_head import Base3DDecodeHead
......
...@@ -4,8 +4,8 @@ from typing import Tuple ...@@ -4,8 +4,8 @@ from typing import Tuple
from mmcv.cnn.bricks import ConvModule from mmcv.cnn.bricks import ConvModule
from torch import Tensor from torch import Tensor
from mmdet3d.core.utils import ConfigType
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.utils.typing import ConfigType
from .pointnet2_head import PointNet2Head from .pointnet2_head import PointNet2Head
......
...@@ -5,9 +5,9 @@ from mmcv.cnn.bricks import ConvModule ...@@ -5,9 +5,9 @@ from mmcv.cnn.bricks import ConvModule
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.core.utils.typing import ConfigType from mmdet3d.models.layers import PointFPModule
from mmdet3d.ops import PointFPModule
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.utils.typing import ConfigType
from .decode_head import Base3DDecodeHead from .decode_head import Base3DDecodeHead
......
...@@ -7,11 +7,12 @@ import torch ...@@ -7,11 +7,12 @@ import torch
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.core import PseudoSampler, merge_aug_bboxes_3d from mmdet3d.models.task_modules import PseudoSampler
from mmdet3d.core.utils import ConfigType, InstanceList, OptConfigType from mmdet3d.models.test_time_augs import merge_aug_bboxes_3d
from mmdet3d.core.utils.typing import OptInstanceList
from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import multi_apply from mmdet3d.utils.typing import (ConfigType, InstanceList, OptConfigType,
OptInstanceList)
from mmdet.models.utils import multi_apply
from .base_3d_dense_head import Base3DDenseHead from .base_3d_dense_head import Base3DDenseHead
from .train_mixins import AnchorTrainMixin from .train_mixins import AnchorTrainMixin
......
...@@ -7,9 +7,9 @@ from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init ...@@ -7,9 +7,9 @@ from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.core.utils import ConfigType, InstanceList, OptConfigType
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet3d.utils import ConfigType, InstanceList, OptConfigType
from mmdet.models.utils import multi_apply
from .base_mono3d_dense_head import BaseMono3DDenseHead from .base_mono3d_dense_head import BaseMono3DDenseHead
......
...@@ -10,9 +10,11 @@ from mmengine.data import InstanceData ...@@ -10,9 +10,11 @@ from mmengine.data import InstanceData
from mmengine.model import BaseModule from mmengine.model import BaseModule
from torch import Tensor from torch import Tensor
from mmdet3d.core import box3d_multiclass_nms, limit_period, xywhr2xyxyr from mmdet3d.models.layers import box3d_multiclass_nms
from mmdet3d.core.utils import InstanceList, OptMultiConfig, SampleList from mmdet3d.structures import limit_period, xywhr2xyxyr
from mmdet.core.utils import select_single_mlvl from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils.typing import InstanceList, OptMultiConfig
from mmdet.models.utils import select_single_mlvl
class Base3DDenseHead(BaseModule, metaclass=ABCMeta): class Base3DDenseHead(BaseModule, metaclass=ABCMeta):
...@@ -170,7 +172,7 @@ class Base3DDenseHead(BaseModule, metaclass=ABCMeta): ...@@ -170,7 +172,7 @@ class Base3DDenseHead(BaseModule, metaclass=ABCMeta):
(num_instances, ) (num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape - labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ). (num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, - bbox_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where contains a tensor with shape (num_instances, C), where
C >= 7. C >= 7.
""" """
...@@ -220,7 +222,7 @@ class Base3DDenseHead(BaseModule, metaclass=ABCMeta): ...@@ -220,7 +222,7 @@ class Base3DDenseHead(BaseModule, metaclass=ABCMeta):
(num_instances, ) (num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape - labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ). (num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, - bbox_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where contains a tensor with shape (num_instances, C), where
C >= 7. C >= 7.
""" """
......
...@@ -6,7 +6,8 @@ from mmcv.runner import BaseModule ...@@ -6,7 +6,8 @@ from mmcv.runner import BaseModule
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
from torch import Tensor from torch import Tensor
from mmdet3d.core.utils import InstanceList, OptMultiConfig, SampleList from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import InstanceList, OptMultiConfig
class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta): class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta):
......
...@@ -8,13 +8,13 @@ from mmcv.runner import BaseModule, force_fp32 ...@@ -8,13 +8,13 @@ from mmcv.runner import BaseModule, force_fp32
from mmengine import InstanceData from mmengine import InstanceData
from torch import Tensor, nn from torch import Tensor, nn
from mmdet3d.core import (Det3DDataSample, circle_nms, draw_heatmap_gaussian, from mmdet3d.models.utils import (clip_sigmoid, draw_heatmap_gaussian,
gaussian_radius, xywhr2xyxyr) gaussian_radius)
from mmdet3d.core.post_processing import nms_bev
from mmdet3d.models import builder
from mmdet3d.models.utils import clip_sigmoid
from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import multi_apply from mmdet3d.structures import Det3DDataSample, xywhr2xyxyr
from mmdet.models.utils import multi_apply
from .. import builder
from ..layers import circle_nms, nms_bev
@MODELS.register_module() @MODELS.register_module()
......
...@@ -8,13 +8,12 @@ from mmengine.data import InstanceData ...@@ -8,13 +8,12 @@ from mmengine.data import InstanceData
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.core import (box3d_multiclass_nms, limit_period, points_img2cam, from mmdet3d.models.layers import box3d_multiclass_nms
xywhr2xyxyr)
from mmdet3d.core.utils import (ConfigType, InstanceList, OptConfigType,
OptInstanceList)
from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import multi_apply from mmdet3d.structures import limit_period, points_img2cam, xywhr2xyxyr
from mmdet.core.utils import select_single_mlvl from mmdet3d.utils import (ConfigType, InstanceList, OptConfigType,
OptInstanceList)
from mmdet.models.utils import multi_apply, select_single_mlvl
from .anchor_free_mono3d_head import AnchorFreeMono3DHead from .anchor_free_mono3d_head import AnchorFreeMono3DHead
RangeType = Sequence[Tuple[int, int]] RangeType = Sequence[Tuple[int, int]]
......
...@@ -5,9 +5,9 @@ import torch ...@@ -5,9 +5,9 @@ import torch
from torch import Tensor from torch import Tensor
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core.bbox import bbox_overlaps_nearest_3d
from mmdet3d.core.utils import InstanceList, OptInstanceList
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures import bbox_overlaps_nearest_3d
from mmdet3d.utils import InstanceList, OptInstanceList
from .anchor3d_head import Anchor3DHead from .anchor3d_head import Anchor3DHead
from .train_mixins import get_direction_target from .train_mixins import get_direction_target
......
...@@ -15,10 +15,11 @@ from torch import Tensor ...@@ -15,10 +15,11 @@ from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core.post_processing import aligned_3d_nms from mmdet3d.models.layers import aligned_3d_nms
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import build_bbox_coder, multi_apply from mmdet3d.structures import BaseInstance3DBoxes, Det3DDataSample
from ...core import BaseInstance3DBoxes, Det3DDataSample, SampleList from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet.models.utils import multi_apply
from .base_conv_bbox_head import BaseConvBboxHead from .base_conv_bbox_head import BaseConvBboxHead
EPS = 1e-6 EPS = 1e-6
...@@ -204,7 +205,7 @@ class GroupFree3DHead(BaseModule): ...@@ -204,7 +205,7 @@ class GroupFree3DHead(BaseModule):
assert self.embed_dims == decoder_cross_posembeds['num_pos_feats'] assert self.embed_dims == decoder_cross_posembeds['num_pos_feats']
# bbox_coder # bbox_coder
self.bbox_coder = build_bbox_coder(bbox_coder) self.bbox_coder = TASK_UTILS.build(bbox_coder)
self.num_sizes = self.bbox_coder.num_sizes self.num_sizes = self.bbox_coder.num_sizes
self.num_dir_bins = self.bbox_coder.num_dir_bins self.num_dir_bins = self.bbox_coder.num_dir_bins
......
...@@ -8,15 +8,15 @@ from mmengine.data import InstanceData ...@@ -8,15 +8,15 @@ from mmengine.data import InstanceData
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.core import Det3DDataSample from mmdet3d.models.layers import EdgeFusionModule
from mmdet3d.core.bbox.builder import build_bbox_coder from mmdet3d.models.task_modules.builder import build_bbox_coder
from mmdet3d.core.utils import get_ellip_gaussian_2D
from mmdet3d.models.model_utils import EdgeFusionModule
from mmdet3d.models.utils import (filter_outside_objs, get_edge_indices, from mmdet3d.models.utils import (filter_outside_objs, get_edge_indices,
get_keypoints, handle_proj_objs) get_ellip_gaussian_2D, get_keypoints,
handle_proj_objs)
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet3d.structures import Det3DDataSample
from mmdet.models.utils import gaussian_radius, gen_gaussian_target from mmdet.models.utils import (gaussian_radius, gen_gaussian_target,
multi_apply)
from mmdet.models.utils.gaussian_target import (get_local_maximum, from mmdet.models.utils.gaussian_target import (get_local_maximum,
get_topk_from_heatmap, get_topk_from_heatmap,
transpose_and_gather_feat) transpose_and_gather_feat)
......
...@@ -7,10 +7,11 @@ from mmcv import ConfigDict ...@@ -7,10 +7,11 @@ from mmcv import ConfigDict
from mmengine.data import InstanceData from mmengine.data import InstanceData
from torch import Tensor from torch import Tensor
from mmdet3d.core import limit_period, xywhr2xyxyr from mmdet3d.models.layers import nms_bev, nms_normal_bev
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet3d.core.utils import InstanceList, SampleList
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures import limit_period, xywhr2xyxyr
from mmdet3d.utils.typing import InstanceList
from ...structures.det3d_data_sample import SampleList
from .anchor3d_head import Anchor3DHead from .anchor3d_head import Anchor3DHead
......
...@@ -9,12 +9,13 @@ from torch import Tensor ...@@ -9,12 +9,13 @@ from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core import box3d_multiclass_nms, xywhr2xyxyr from mmdet3d.models.layers import box3d_multiclass_nms
from mmdet3d.core.bbox import points_cam2img, points_img2cam
from mmdet3d.core.utils import (ConfigType, InstanceList, OptConfigType,
OptInstanceList)
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet.core import distance2bbox, multi_apply from mmdet3d.structures import points_cam2img, points_img2cam, xywhr2xyxyr
from mmdet3d.utils.typing import (ConfigType, InstanceList, OptConfigType,
OptInstanceList)
from mmdet.models.utils import multi_apply
from mmdet.structures.bbox import distance2bbox
from .fcos_mono3d_head import FCOSMono3DHead from .fcos_mono3d_head import FCOSMono3DHead
...@@ -1138,7 +1139,7 @@ class PGDHead(FCOSMono3DHead): ...@@ -1138,7 +1139,7 @@ class PGDHead(FCOSMono3DHead):
points (list[Tensor]): Points of each fpn level, each has shape points (list[Tensor]): Points of each fpn level, each has shape
(num_points, 2). (num_points, 2).
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instance_3d. It usually includes ``bboxes_3d``、 gt_instance_3d. It usually includes ``bbox_3d``、
``labels_3d``、``depths``、``centers_2d`` and attributes. ``labels_3d``、``depths``、``centers_2d`` and attributes.
batch_gt_instances (list[:obj:`InstanceData`]): Batch of batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes``、``labels``. gt_instance. It usually includes ``bboxes``、``labels``.
......
...@@ -3,13 +3,13 @@ import torch ...@@ -3,13 +3,13 @@ import torch
from mmcv.runner import BaseModule, force_fp32 from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn from torch import nn as nn
from mmdet3d.core import xywhr2xyxyr
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
LiDARInstance3DBoxes)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import build_loss
from mmdet3d.registry import MODELS from mmdet3d.models.layers import nms_bev, nms_normal_bev
from mmdet.core import build_bbox_coder, multi_apply from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures import xywhr2xyxyr
from mmdet3d.structures.bbox_3d import (DepthInstance3DBoxes,
LiDARInstance3DBoxes)
from mmdet.models.utils import multi_apply
@MODELS.register_module() @MODELS.register_module()
...@@ -54,7 +54,7 @@ class PointRPNHead(BaseModule): ...@@ -54,7 +54,7 @@ class PointRPNHead(BaseModule):
self.cls_loss = build_loss(cls_loss) self.cls_loss = build_loss(cls_loss)
# build box coder # build box coder
self.bbox_coder = build_bbox_coder(bbox_coder) self.bbox_coder = TASK_UTILS.build(bbox_coder)
# build pred conv # build pred conv
self.cls_layers = self._make_fc_layers( self.cls_layers = self._make_fc_layers(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment