Unverified Commit cbf194fa authored by xiliu8006's avatar xiliu8006 Committed by GitHub
Browse files

[Enhance] GroupFree3d inherits BaseModule From MMCV (#704)

parent 26c18075
...@@ -2,10 +2,10 @@ import copy ...@@ -2,10 +2,10 @@ import copy
import numpy as np import numpy as np
import torch import torch
from mmcv import ConfigDict from mmcv import ConfigDict
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule, xavier_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding, from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer) build_transformer_layer)
from mmcv.runner import force_fp32 from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
...@@ -19,7 +19,7 @@ from .base_conv_bbox_head import BaseConvBboxHead ...@@ -19,7 +19,7 @@ from .base_conv_bbox_head import BaseConvBboxHead
EPS = 1e-6 EPS = 1e-6
class PointsObjClsModule(nn.Module): class PointsObjClsModule(BaseModule):
"""object candidate point prediction from seed point features. """object candidate point prediction from seed point features.
Args: Args:
...@@ -39,8 +39,9 @@ class PointsObjClsModule(nn.Module): ...@@ -39,8 +39,9 @@ class PointsObjClsModule(nn.Module):
num_convs=3, num_convs=3,
conv_cfg=dict(type='Conv1d'), conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU')): act_cfg=dict(type='ReLU'),
super().__init__() init_cfg=None):
super().__init__(init_cfg=init_cfg)
conv_channels = [in_channel for _ in range(num_convs - 1)] conv_channels = [in_channel for _ in range(num_convs - 1)]
conv_channels.append(1) conv_channels.append(1)
...@@ -104,7 +105,7 @@ class GeneralSamplingModule(nn.Module): ...@@ -104,7 +105,7 @@ class GeneralSamplingModule(nn.Module):
@HEADS.register_module() @HEADS.register_module()
class GroupFree3DHead(nn.Module): class GroupFree3DHead(BaseModule):
r"""Bbox head of `Group-Free 3D <https://arxiv.org/abs/2104.00678>`_. r"""Bbox head of `Group-Free 3D <https://arxiv.org/abs/2104.00678>`_.
Args: Args:
...@@ -162,8 +163,9 @@ class GroupFree3DHead(nn.Module): ...@@ -162,8 +163,9 @@ class GroupFree3DHead(nn.Module):
size_class_loss=None, size_class_loss=None,
size_res_loss=None, size_res_loss=None,
size_reg_loss=None, size_reg_loss=None,
semantic_loss=None): semantic_loss=None,
super(GroupFree3DHead, self).__init__() init_cfg=None):
super(GroupFree3DHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes self.num_classes = num_classes
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
...@@ -251,15 +253,13 @@ class GroupFree3DHead(nn.Module): ...@@ -251,15 +253,13 @@ class GroupFree3DHead(nn.Module):
# initialize transformer # initialize transformer
for m in self.decoder_layers.parameters(): for m in self.decoder_layers.parameters():
if m.dim() > 1: if m.dim() > 1:
nn.init.xavier_uniform_(m) xavier_init(m, distribution='uniform')
for m in self.decoder_self_posembeds.parameters(): for m in self.decoder_self_posembeds.parameters():
if m.dim() > 1: if m.dim() > 1:
nn.init.xavier_uniform_(m) xavier_init(m, distribution='uniform')
for m in self.decoder_cross_posembeds.parameters(): for m in self.decoder_cross_posembeds.parameters():
if m.dim() > 1: if m.dim() > 1:
nn.init.xavier_uniform_(m) xavier_init(m, distribution='uniform')
def _get_cls_out_channels(self): def _get_cls_out_channels(self):
"""Return the channel number of classification outputs.""" """Return the channel number of classification outputs."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment