Unverified Commit cbf194fa authored by xiliu8006's avatar xiliu8006 Committed by GitHub
Browse files

[Enhance] GroupFree3d inherits BaseModule From MMCV (#704)

parent 26c18075
......@@ -2,10 +2,10 @@ import copy
import numpy as np
import torch
from mmcv import ConfigDict
from mmcv.cnn import ConvModule
from mmcv.cnn import ConvModule, xavier_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer)
from mmcv.runner import force_fp32
from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn
from torch.nn import functional as F
......@@ -19,7 +19,7 @@ from .base_conv_bbox_head import BaseConvBboxHead
EPS = 1e-6
class PointsObjClsModule(nn.Module):
class PointsObjClsModule(BaseModule):
"""object candidate point prediction from seed point features.
Args:
......@@ -39,8 +39,9 @@ class PointsObjClsModule(nn.Module):
num_convs=3,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU')):
super().__init__()
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
conv_channels = [in_channel for _ in range(num_convs - 1)]
conv_channels.append(1)
......@@ -104,7 +105,7 @@ class GeneralSamplingModule(nn.Module):
@HEADS.register_module()
class GroupFree3DHead(nn.Module):
class GroupFree3DHead(BaseModule):
r"""Bbox head of `Group-Free 3D <https://arxiv.org/abs/2104.00678>`_.
Args:
......@@ -162,8 +163,9 @@ class GroupFree3DHead(nn.Module):
size_class_loss=None,
size_res_loss=None,
size_reg_loss=None,
semantic_loss=None):
super(GroupFree3DHead, self).__init__()
semantic_loss=None,
init_cfg=None):
super(GroupFree3DHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......@@ -251,15 +253,13 @@ class GroupFree3DHead(nn.Module):
# initialize transformer
for m in self.decoder_layers.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)
xavier_init(m, distribution='uniform')
for m in self.decoder_self_posembeds.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)
xavier_init(m, distribution='uniform')
for m in self.decoder_cross_posembeds.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)
xavier_init(m, distribution='uniform')
def _get_cls_out_channels(self):
"""Return the channel number of classification outputs."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment