Unverified Commit 84b132b0 authored by VVsssssk's avatar VVsssssk Committed by GitHub
Browse files

fix basemodule and init_weights (#1714)

parent 86f6183d
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import warnings import warnings
from abc import ABCMeta from abc import ABCMeta
from mmcv.runner import BaseModule from mmengine.model import BaseModule
class BasePointNet(BaseModule, metaclass=ABCMeta): class BasePointNet(BaseModule, metaclass=ABCMeta):
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule, auto_fp16 from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.models.layers import DGCNNFAModule, DGCNNGFModule from mmdet3d.models.layers import DGCNNFAModule, DGCNNGFModule
...@@ -71,7 +71,6 @@ class DGCNNBackbone(BaseModule): ...@@ -71,7 +71,6 @@ class DGCNNBackbone(BaseModule):
self.FA_module = DGCNNFAModule( self.FA_module = DGCNNFAModule(
mlp_channels=cur_fa_mlps, act_cfg=act_cfg) mlp_channels=cur_fa_mlps, act_cfg=act_cfg)
@auto_fp16(apply_to=('points', ))
def forward(self, points): def forward(self, points):
"""Forward pass. """Forward pass.
......
...@@ -3,7 +3,7 @@ import warnings ...@@ -3,7 +3,7 @@ import warnings
import torch import torch
from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from torch import nn from torch import nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
......
...@@ -4,7 +4,7 @@ import warnings ...@@ -4,7 +4,7 @@ import warnings
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16 from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.models.builder import build_backbone from mmdet3d.models.builder import build_backbone
...@@ -90,7 +90,6 @@ class MultiBackbone(BaseModule): ...@@ -90,7 +90,6 @@ class MultiBackbone(BaseModule):
'please use "init_cfg" instead') 'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
@auto_fp16()
def forward(self, points): def forward(self, points):
"""Forward pass. """Forward pass.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import warnings import warnings
from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
......
...@@ -3,8 +3,8 @@ from abc import ABCMeta, abstractmethod ...@@ -3,8 +3,8 @@ from abc import ABCMeta, abstractmethod
from typing import List from typing import List
import torch import torch
from mmcv.cnn import normal_init from mmengine.model import BaseModule
from mmcv.runner import BaseModule, auto_fp16 from mmengine.model.utils import normal_init
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
...@@ -94,7 +94,6 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta): ...@@ -94,7 +94,6 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
super().init_weights() super().init_weights()
normal_init(self.conv_seg, mean=0, std=0.01) normal_init(self.conv_seg, mean=0, std=0.01)
@auto_fp16()
@abstractmethod @abstractmethod
def forward(self, feats_dict: dict): def forward(self, feats_dict: dict):
"""Placeholder of forward function.""" """Placeholder of forward function."""
......
...@@ -3,7 +3,8 @@ from abc import abstractmethod ...@@ -3,7 +3,8 @@ from abc import abstractmethod
from typing import Any, List, Sequence, Tuple, Union from typing import Any, List, Sequence, Tuple, Union
import torch import torch
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init from mmcv.cnn import ConvModule
from mmengine.model.utils import bias_init_with_prob, normal_init
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
......
...@@ -4,10 +4,10 @@ from typing import List, Optional, Tuple ...@@ -4,10 +4,10 @@ from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn.utils.weight_init import constant_init
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
from mmengine.data import InstanceData from mmengine.data import InstanceData
from mmengine.model import BaseModule from mmengine.model import BaseModule
from mmengine.model.utils import constant_init
from torch import Tensor from torch import Tensor
from mmdet3d.models.layers import box3d_multiclass_nms from mmdet3d.models.layers import box3d_multiclass_nms
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_conv_layer from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Optional, Tuple from typing import Optional, Tuple
from mmcv.runner import BaseModule
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from torch import Tensor from torch import Tensor
from mmdet3d.structures.det3d_data_sample import SampleList from mmdet3d.structures.det3d_data_sample import SampleList
......
...@@ -4,8 +4,8 @@ from typing import Dict, List, Optional, Tuple, Union ...@@ -4,8 +4,8 @@ from typing import Dict, List, Optional, Tuple, Union
import torch import torch
from mmcv.cnn import ConvModule, build_conv_layer from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule, force_fp32
from mmengine import InstanceData from mmengine import InstanceData
from mmengine.model import BaseModule
from torch import Tensor, nn from torch import Tensor, nn
from mmdet3d.models.utils import (clip_sigmoid, draw_heatmap_gaussian, from mmdet3d.models.utils import (clip_sigmoid, draw_heatmap_gaussian,
...@@ -608,7 +608,6 @@ class CenterHead(BaseModule): ...@@ -608,7 +608,6 @@ class CenterHead(BaseModule):
losses = self.loss_by_feat(outs, batch_gt_instance_3d) losses = self.loss_by_feat(outs, batch_gt_instance_3d)
return losses return losses
@force_fp32(apply_to=('preds_dicts'))
def loss_by_feat(self, preds_dicts: Tuple[List[dict]], def loss_by_feat(self, preds_dicts: Tuple[List[dict]],
batch_gt_instances_3d: List[InstanceData], *args, batch_gt_instances_3d: List[InstanceData], *args,
**kwargs): **kwargs):
......
...@@ -3,8 +3,9 @@ from typing import List, Optional, Sequence, Tuple ...@@ -3,8 +3,9 @@ from typing import List, Optional, Sequence, Tuple
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import Scale, normal_init from mmcv.cnn import Scale
from mmengine.data import InstanceData from mmengine.data import InstanceData
from mmengine.model.utils import normal_init
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
......
...@@ -4,13 +4,14 @@ from typing import Dict, List, Optional, Tuple ...@@ -4,13 +4,14 @@ from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import ConvModule, xavier_init from mmcv.cnn import ConvModule
from mmcv.cnn.bricks.transformer import (build_positional_encoding, from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer) build_transformer_layer)
from mmcv.ops import PointsSampler as Points_Sampler from mmcv.ops import PointsSampler as Points_Sampler
from mmcv.ops import gather_points from mmcv.ops import gather_points
from mmcv.runner import BaseModule
from mmengine import InstanceData from mmengine import InstanceData
from mmengine.model import BaseModule
from mmengine.model.utils import xavier_init
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
from mmcv.cnn import xavier_init
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
from mmengine.data import InstanceData from mmengine.data import InstanceData
from mmengine.model.utils import xavier_init
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
......
...@@ -3,8 +3,9 @@ from typing import List, Optional, Tuple ...@@ -3,8 +3,9 @@ from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import Scale, bias_init_with_prob, normal_init from mmcv.cnn import Scale
from mmengine.data import InstanceData from mmengine.data import InstanceData
from mmengine.model.utils import bias_init_with_prob, normal_init
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmcv.runner import BaseModule, force_fp32 from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import build_loss
...@@ -124,7 +124,6 @@ class PointRPNHead(BaseModule): ...@@ -124,7 +124,6 @@ class PointRPNHead(BaseModule):
batch_size, -1, self._get_reg_out_channels()) batch_size, -1, self._get_reg_out_channels())
return point_box_preds, point_cls_preds return point_box_preds, point_cls_preds
@force_fp32(apply_to=('bbox_preds'))
def loss(self, def loss(self,
bbox_preds, bbox_preds,
cls_preds, cls_preds,
......
...@@ -4,8 +4,8 @@ from typing import Dict, List, Optional, Tuple, Union ...@@ -4,8 +4,8 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from mmcv.ops import furthest_point_sample from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule
from mmengine import ConfigDict, InstanceData from mmengine import ConfigDict, InstanceData
from mmengine.model import BaseModule
from torch import Tensor from torch import Tensor
from torch.nn import functional as F from torch.nn import functional as F
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, force_fp32 from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
...@@ -39,7 +39,6 @@ class DGCNNFAModule(BaseModule): ...@@ -39,7 +39,6 @@ class DGCNNFAModule(BaseModule):
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
act_cfg=act_cfg)) act_cfg=act_cfg))
@force_fp32()
def forward(self, points): def forward(self, points):
"""forward. """forward.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, force_fp32 from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
...@@ -38,7 +38,6 @@ class DGCNNFPModule(BaseModule): ...@@ -38,7 +38,6 @@ class DGCNNFPModule(BaseModule):
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
act_cfg=act_cfg)) act_cfg=act_cfg))
@force_fp32()
def forward(self, points): def forward(self, points):
"""forward. """forward.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment