Commit c2fe651f authored by zhangshilong's avatar zhangshilong Committed by ChaimZhu
Browse files

refactor directory

parent bc5806ba
...@@ -7,7 +7,7 @@ from mmcv.ops import QueryAndGroup, gather_points ...@@ -7,7 +7,7 @@ from mmcv.ops import QueryAndGroup, gather_points
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.ops import PAConv from mmdet3d.models.layers import PAConv
from .builder import SA_MODULES from .builder import SA_MODULES
......
...@@ -3,8 +3,8 @@ import torch ...@@ -3,8 +3,8 @@ import torch
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures import AxisAlignedBboxOverlaps3D
from mmdet.models.losses.utils import weighted_loss from mmdet.models.losses.utils import weighted_loss
from ...core.bbox import AxisAlignedBboxOverlaps3D
@weighted_loss @weighted_loss
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PAConv, PAConvCUDA
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet.models.losses.utils import weight_reduce_loss from mmdet.models.losses.utils import weight_reduce_loss
from ..layers import PAConv, PAConvCUDA
def weight_correlation(conv): def weight_correlation(conv):
......
...@@ -4,8 +4,8 @@ from mmcv.ops import points_in_boxes_all, three_interpolate, three_nn ...@@ -4,8 +4,8 @@ from mmcv.ops import points_in_boxes_all, three_interpolate, three_nn
from mmcv.runner import auto_fp16 from mmcv.runner import auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule from mmdet3d.models.layers import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
if IS_SPCONV2_AVAILABLE: if IS_SPCONV2_AVAILABLE:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE
if IS_SPCONV2_AVAILABLE: if IS_SPCONV2_AVAILABLE:
from spconv.pytorch import SparseConvTensor, SparseSequential from spconv.pytorch import SparseConvTensor, SparseSequential
...@@ -10,8 +10,8 @@ else: ...@@ -10,8 +10,8 @@ else:
from mmcv.runner import BaseModule, auto_fp16 from mmcv.runner import BaseModule, auto_fp16
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule from mmdet3d.models.layers import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops.sparse_block import replace_feature from mmdet3d.models.layers.sparse_block import replace_feature
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
......
# Copyright (c) OpenMMLab. All rights reserved.
from .edge_fusion_module import EdgeFusionModule
from .transformer import GroupFree3DMHA
from .vote_module import VoteModule
__all__ = ['VoteModule', 'GroupFree3DMHA', 'EdgeFusionModule']
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PointFPModule from mmdet3d.models.layers.pointnet_modules import PointFPModule
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
......
...@@ -9,13 +9,13 @@ from torch import Tensor ...@@ -9,13 +9,13 @@ from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core import BaseInstance3DBoxes, Det3DDataSample from mmdet3d.models import aligned_3d_nms
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.models.layers.pointnet_modules import build_sa_module
from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.losses import chamfer_distance from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import build_sa_module
from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import multi_apply from mmdet3d.structures import (BaseInstance3DBoxes, DepthInstance3DBoxes,
Det3DDataSample)
from mmdet.models.utils import multi_apply
@MODELS.register_module() @MODELS.register_module()
......
...@@ -7,7 +7,9 @@ from mmcv.cnn import ConvModule, normal_init ...@@ -7,7 +7,9 @@ from mmcv.cnn import ConvModule, normal_init
from mmengine.data import InstanceData from mmengine.data import InstanceData
from torch import Tensor from torch import Tensor
from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE from mmdet3d.models import make_sparse_convmodule
from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE
from mmdet.models.utils import multi_apply
if IS_SPCONV2_AVAILABLE: if IS_SPCONV2_AVAILABLE:
from spconv.pytorch import (SparseConvTensor, SparseMaxPool3d, from spconv.pytorch import (SparseConvTensor, SparseMaxPool3d,
...@@ -18,14 +20,11 @@ else: ...@@ -18,14 +20,11 @@ else:
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.core import build_bbox_coder
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import build_loss
from mmdet3d.ops import make_sparse_convmodule from mmdet3d.models.layers import nms_bev, nms_normal_bev
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import multi_apply from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr)
@MODELS.register_module() @MODELS.register_module()
...@@ -87,7 +86,7 @@ class PartA2BboxHead(BaseModule): ...@@ -87,7 +86,7 @@ class PartA2BboxHead(BaseModule):
super(PartA2BboxHead, self).__init__(init_cfg=init_cfg) super(PartA2BboxHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes self.num_classes = num_classes
self.with_corner_loss = with_corner_loss self.with_corner_loss = with_corner_loss
self.bbox_coder = build_bbox_coder(bbox_coder) self.bbox_coder = TASK_UTILS.build(bbox_coder)
self.loss_bbox = build_loss(loss_bbox) self.loss_bbox = build_loss(loss_bbox)
self.loss_cls = build_loss(loss_cls) self.loss_cls = build_loss(loss_cls)
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
......
...@@ -6,14 +6,12 @@ from mmcv.cnn.bricks import build_conv_layer ...@@ -6,14 +6,12 @@ from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.core import build_bbox_coder from mmdet3d.models.layers import nms_bev, nms_normal_bev
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes, from mmdet3d.models.layers.pointnet_modules import build_sa_module
rotation_3d_in_axis, xywhr2xyxyr) from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes,
from mmdet3d.models.builder import build_loss rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.ops import build_sa_module from mmdet.models.utils import multi_apply
from mmdet3d.registry import MODELS
from mmdet.core import multi_apply
@MODELS.register_module() @MODELS.register_module()
...@@ -100,9 +98,9 @@ class PointRCNNBboxHead(BaseModule): ...@@ -100,9 +98,9 @@ class PointRCNNBboxHead(BaseModule):
self.act_cfg = act_cfg self.act_cfg = act_cfg
self.bias = bias self.bias = bias
self.loss_bbox = build_loss(loss_bbox) self.loss_bbox = MODELS.build(loss_bbox)
self.loss_cls = build_loss(loss_cls) self.loss_cls = MODELS.build(loss_cls)
self.bbox_coder = build_bbox_coder(bbox_coder) self.bbox_coder = TASK_UTILS.build(bbox_coder)
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
self.in_channels = in_channels self.in_channels = in_channels
......
...@@ -5,7 +5,7 @@ from mmengine import InstanceData ...@@ -5,7 +5,7 @@ from mmengine import InstanceData
from torch import Tensor from torch import Tensor
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from ...core import Det3DDataSample from mmdet3d.structures import Det3DDataSample
from .base_3droi_head import Base3DRoIHead from .base_3droi_head import Base3DRoIHead
......
...@@ -4,11 +4,11 @@ from mmcv.runner import BaseModule ...@@ -4,11 +4,11 @@ from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core.bbox.structures import rotation_3d_in_axis
from mmdet3d.core.utils import InstanceList
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import build_loss
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet3d.structures.bbox_3d import rotation_3d_in_axis
from mmdet3d.utils import InstanceList
from mmdet.models.utils import multi_apply
@MODELS.register_module() @MODELS.register_module()
......
...@@ -9,11 +9,10 @@ from mmengine import InstanceData ...@@ -9,11 +9,10 @@ from mmengine import InstanceData
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core import Det3DDataSample from mmdet3d.models.layers import VoteModule, build_sa_module
from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet3d.structures import Det3DDataSample
from mmdet.models.utils import multi_apply
@MODELS.register_module() @MODELS.register_module()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment