Unverified Commit 84b132b0 authored by VVsssssk's avatar VVsssssk Committed by GitHub
Browse files

fix basemodule and init_weights (#1714)

parent 86f6183d
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn as nn
from torch.nn import functional as F
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn as nn
......
......@@ -2,9 +2,9 @@
import copy
import torch
from mmcv.cnn import (ConvModule, build_activation_layer, build_norm_layer,
constant_init)
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmcv.ops import assign_score_withk as assign_score_cuda
from mmengine.model.utils import constant_init
from torch import nn as nn
from torch.nn import functional as F
......
......@@ -4,7 +4,7 @@ from typing import List
import torch
from mmcv.cnn import ConvModule
from mmcv.ops import three_interpolate, three_nn
from mmcv.runner import BaseModule, force_fp32
from mmengine.model import BaseModule
from torch import nn as nn
......@@ -37,7 +37,6 @@ class PointFPModule(BaseModule):
conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg))
@force_fp32()
def forward(self, target: torch.Tensor, source: torch.Tensor,
target_feats: torch.Tensor,
source_feats: torch.Tensor) -> torch.Tensor:
......
......@@ -8,7 +8,7 @@ if IS_SPCONV2_AVAILABLE:
else:
from mmcv.ops import SparseConvTensor, SparseSequential
from mmcv.runner import BaseModule, auto_fp16
from mmengine.model import BaseModule
from mmdet3d.models.layers import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.models.layers.sparse_block import replace_feature
......@@ -102,7 +102,6 @@ class SparseUNet(BaseModule):
indice_key='spconv_down2',
conv_type='SparseConv3d')
@auto_fp16(apply_to=('voxel_features', ))
def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseUNet.
......
......@@ -3,7 +3,7 @@ import math
import numpy as np
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn as nn
from mmdet3d.registry import MODELS
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn as nn
from mmdet3d.models.layers.pointnet_modules import PointFPModule
......
......@@ -2,7 +2,7 @@
import numpy as np
import torch
from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
from mmcv.runner import BaseModule, auto_fp16
from mmengine.model import BaseModule
from torch import nn as nn
from mmdet3d.registry import MODELS
......@@ -71,7 +71,6 @@ class SECONDFPN(BaseModule):
dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)
]
@auto_fp16()
def forward(self, x):
"""Forward function.
......
......@@ -3,8 +3,8 @@ from typing import Dict, List, Optional, Tuple
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmengine import InstanceData
from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F
......
......@@ -3,8 +3,9 @@ from typing import List
import numpy as np
import torch
from mmcv.cnn import ConvModule, normal_init
from mmcv.cnn import ConvModule
from mmengine.data import InstanceData
from mmengine.model.utils import normal_init
from torch import Tensor
from mmdet3d.models import make_sparse_convmodule
......@@ -17,7 +18,7 @@ if IS_SPCONV2_AVAILABLE:
else:
from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn as nn
from mmdet3d.models.builder import build_loss
......
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.cnn import ConvModule, normal_init
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from mmengine.model.utils import normal_init
from torch import nn as nn
from mmdet3d.models.layers import nms_bev, nms_normal_bev
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn as nn
from torch.nn import functional as F
......
......@@ -4,8 +4,8 @@ from typing import Dict, List, Optional
import torch
from mmcv.cnn import ConvModule
from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule
from mmengine import InstanceData
from mmengine.model import BaseModule
from torch import nn as nn
from torch.nn import functional as F
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment