Unverified Commit 84b132b0 authored by VVsssssk's avatar VVsssssk Committed by GitHub
Browse files

fix basemodule and init_weights (#1714)

parent 86f6183d
......@@ -2,7 +2,7 @@
import warnings
from abc import ABCMeta
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
class BasePointNet(BaseModule, metaclass=ABCMeta):
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule, auto_fp16
from mmengine.model import BaseModule
from torch import nn as nn
from mmdet3d.models.layers import DGCNNFAModule, DGCNNGFModule
......@@ -71,7 +71,6 @@ class DGCNNBackbone(BaseModule):
self.FA_module = DGCNNFAModule(
mlp_channels=cur_fa_mlps, act_cfg=act_cfg)
@auto_fp16(apply_to=('points', ))
def forward(self, points):
"""Forward pass.
......
......@@ -3,7 +3,7 @@ import warnings
import torch
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn
from mmdet3d.registry import MODELS
......
......@@ -4,7 +4,7 @@ import warnings
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from mmengine.model import BaseModule
from torch import nn as nn
from mmdet3d.models.builder import build_backbone
......@@ -90,7 +90,6 @@ class MultiBackbone(BaseModule):
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
@auto_fp16()
def forward(self, points):
"""Forward pass.
......
......@@ -2,7 +2,7 @@
import warnings
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn as nn
from mmdet3d.registry import MODELS
......
......@@ -3,8 +3,8 @@ from abc import ABCMeta, abstractmethod
from typing import List
import torch
from mmcv.cnn import normal_init
from mmcv.runner import BaseModule, auto_fp16
from mmengine.model import BaseModule
from mmengine.model.utils import normal_init
from torch import Tensor
from torch import nn as nn
......@@ -94,7 +94,6 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
super().init_weights()
normal_init(self.conv_seg, mean=0, std=0.01)
@auto_fp16()
@abstractmethod
def forward(self, feats_dict: dict):
"""Placeholder of forward function."""
......
......@@ -3,7 +3,8 @@ from abc import abstractmethod
from typing import Any, List, Sequence, Tuple, Union
import torch
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
from mmcv.cnn import ConvModule
from mmengine.model.utils import bias_init_with_prob, normal_init
from torch import Tensor
from torch import nn as nn
......
......@@ -4,10 +4,10 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
from mmcv.cnn.utils.weight_init import constant_init
from mmengine.config import ConfigDict
from mmengine.data import InstanceData
from mmengine.model import BaseModule
from mmengine.model.utils import constant_init
from torch import Tensor
from mmdet3d.models.layers import box3d_multiclass_nms
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn as nn
from mmdet3d.registry import MODELS
......
......@@ -2,8 +2,8 @@
from abc import ABCMeta, abstractmethod
from typing import Optional, Tuple
from mmcv.runner import BaseModule
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from torch import Tensor
from mmdet3d.structures.det3d_data_sample import SampleList
......
......@@ -4,8 +4,8 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule, force_fp32
from mmengine import InstanceData
from mmengine.model import BaseModule
from torch import Tensor, nn
from mmdet3d.models.utils import (clip_sigmoid, draw_heatmap_gaussian,
......@@ -608,7 +608,6 @@ class CenterHead(BaseModule):
losses = self.loss_by_feat(outs, batch_gt_instance_3d)
return losses
@force_fp32(apply_to=('preds_dicts'))
def loss_by_feat(self, preds_dicts: Tuple[List[dict]],
batch_gt_instances_3d: List[InstanceData], *args,
**kwargs):
......
......@@ -3,8 +3,9 @@ from typing import List, Optional, Sequence, Tuple
import numpy as np
import torch
from mmcv.cnn import Scale, normal_init
from mmcv.cnn import Scale
from mmengine.data import InstanceData
from mmengine.model.utils import normal_init
from torch import Tensor
from torch import nn as nn
......
......@@ -4,13 +4,14 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from mmcv.cnn import ConvModule, xavier_init
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer)
from mmcv.ops import PointsSampler as Points_Sampler
from mmcv.ops import gather_points
from mmcv.runner import BaseModule
from mmengine import InstanceData
from mmengine.model import BaseModule
from mmengine.model.utils import xavier_init
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F
......
......@@ -2,9 +2,9 @@
from typing import List, Optional, Tuple, Union
import torch
from mmcv.cnn import xavier_init
from mmengine.config import ConfigDict
from mmengine.data import InstanceData
from mmengine.model.utils import xavier_init
from torch import Tensor
from torch import nn as nn
......
......@@ -3,8 +3,9 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
from mmcv.cnn import Scale, bias_init_with_prob, normal_init
from mmcv.cnn import Scale
from mmengine.data import InstanceData
from mmengine.model.utils import bias_init_with_prob, normal_init
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import BaseModule, force_fp32
from mmengine.model import BaseModule
from torch import nn as nn
from mmdet3d.models.builder import build_loss
......@@ -124,7 +124,6 @@ class PointRPNHead(BaseModule):
batch_size, -1, self._get_reg_out_channels())
return point_box_preds, point_cls_preds
@force_fp32(apply_to=('bbox_preds'))
def loss(self,
bbox_preds,
cls_preds,
......
......@@ -4,8 +4,8 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule
from mmengine import ConfigDict, InstanceData
from mmengine.model import BaseModule
from torch import Tensor
from torch.nn import functional as F
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, force_fp32
from mmengine.model import BaseModule
from torch import nn as nn
......@@ -39,7 +39,6 @@ class DGCNNFAModule(BaseModule):
norm_cfg=norm_cfg,
act_cfg=act_cfg))
@force_fp32()
def forward(self, points):
"""forward.
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, force_fp32
from mmengine.model import BaseModule
from torch import nn as nn
......@@ -38,7 +38,6 @@ class DGCNNFPModule(BaseModule):
norm_cfg=norm_cfg,
act_cfg=act_cfg))
@force_fp32()
def forward(self, points):
"""forward.
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn as nn
from torch.nn import functional as F
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment