Unverified Commit 84b132b0 authored by VVsssssk's avatar VVsssssk Committed by GitHub
Browse files

fix basemodule and init_weights (#1714)

parent 86f6183d
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
import copy import copy
import torch import torch
from mmcv.cnn import (ConvModule, build_activation_layer, build_norm_layer, from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
constant_init)
from mmcv.ops import assign_score_withk as assign_score_cuda from mmcv.ops import assign_score_withk as assign_score_cuda
from mmengine.model.utils import constant_init
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
......
...@@ -4,7 +4,7 @@ from typing import List ...@@ -4,7 +4,7 @@ from typing import List
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.ops import three_interpolate, three_nn from mmcv.ops import three_interpolate, three_nn
from mmcv.runner import BaseModule, force_fp32 from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
...@@ -37,7 +37,6 @@ class PointFPModule(BaseModule): ...@@ -37,7 +37,6 @@ class PointFPModule(BaseModule):
conv_cfg=dict(type='Conv2d'), conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg)) norm_cfg=norm_cfg))
@force_fp32()
def forward(self, target: torch.Tensor, source: torch.Tensor, def forward(self, target: torch.Tensor, source: torch.Tensor,
target_feats: torch.Tensor, target_feats: torch.Tensor,
source_feats: torch.Tensor) -> torch.Tensor: source_feats: torch.Tensor) -> torch.Tensor:
......
...@@ -8,7 +8,7 @@ if IS_SPCONV2_AVAILABLE: ...@@ -8,7 +8,7 @@ if IS_SPCONV2_AVAILABLE:
else: else:
from mmcv.ops import SparseConvTensor, SparseSequential from mmcv.ops import SparseConvTensor, SparseSequential
from mmcv.runner import BaseModule, auto_fp16 from mmengine.model import BaseModule
from mmdet3d.models.layers import SparseBasicBlock, make_sparse_convmodule from mmdet3d.models.layers import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.models.layers.sparse_block import replace_feature from mmdet3d.models.layers.sparse_block import replace_feature
...@@ -102,7 +102,6 @@ class SparseUNet(BaseModule): ...@@ -102,7 +102,6 @@ class SparseUNet(BaseModule):
indice_key='spconv_down2', indice_key='spconv_down2',
conv_type='SparseConv3d') conv_type='SparseConv3d')
@auto_fp16(apply_to=('voxel_features', ))
def forward(self, voxel_features, coors, batch_size): def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseUNet. """Forward of SparseUNet.
......
...@@ -3,7 +3,7 @@ import math ...@@ -3,7 +3,7 @@ import math
import numpy as np import numpy as np
from mmcv.cnn import ConvModule, build_conv_layer from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.models.layers.pointnet_modules import PointFPModule from mmdet3d.models.layers.pointnet_modules import PointFPModule
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
from mmcv.runner import BaseModule, auto_fp16 from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
...@@ -71,7 +71,6 @@ class SECONDFPN(BaseModule): ...@@ -71,7 +71,6 @@ class SECONDFPN(BaseModule):
dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0) dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)
] ]
@auto_fp16()
def forward(self, x): def forward(self, x):
"""Forward function. """Forward function.
......
...@@ -3,8 +3,8 @@ from typing import Dict, List, Optional, Tuple ...@@ -3,8 +3,8 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmengine import InstanceData from mmengine import InstanceData
from mmengine.model import BaseModule
from torch import Tensor from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
......
...@@ -3,8 +3,9 @@ from typing import List ...@@ -3,8 +3,9 @@ from typing import List
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import ConvModule, normal_init from mmcv.cnn import ConvModule
from mmengine.data import InstanceData from mmengine.data import InstanceData
from mmengine.model.utils import normal_init
from torch import Tensor from torch import Tensor
from mmdet3d.models import make_sparse_convmodule from mmdet3d.models import make_sparse_convmodule
...@@ -17,7 +18,7 @@ if IS_SPCONV2_AVAILABLE: ...@@ -17,7 +18,7 @@ if IS_SPCONV2_AVAILABLE:
else: else:
from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import build_loss
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import ConvModule, normal_init from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_conv_layer from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from mmengine.model.utils import normal_init
from torch import nn as nn from torch import nn as nn
from mmdet3d.models.layers import nms_bev, nms_normal_bev from mmdet3d.models.layers import nms_bev, nms_normal_bev
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmcv.runner import BaseModule from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
......
...@@ -4,8 +4,8 @@ from typing import Dict, List, Optional ...@@ -4,8 +4,8 @@ from typing import Dict, List, Optional
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.ops import furthest_point_sample from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule
from mmengine import InstanceData from mmengine import InstanceData
from mmengine.model import BaseModule
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment