Unverified Commit 26a16123 authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

bugfixed: add **kwargs to each roi_head to support variable argument (#559)

parent 686cf446
...@@ -77,7 +77,8 @@ class Detector3DTemplate(nn.Module): ...@@ -77,7 +77,8 @@ class Detector3DTemplate(nn.Module):
) )
model_info_dict['module_list'].append(backbone_3d_module) model_info_dict['module_list'].append(backbone_3d_module)
model_info_dict['num_point_features'] = backbone_3d_module.num_point_features model_info_dict['num_point_features'] = backbone_3d_module.num_point_features
model_info_dict['backbone_channels'] = backbone_3d_module.backbone_channels model_info_dict['backbone_channels'] = backbone_3d_module.backbone_channels \
if hasattr(backbone_3d_module, 'backbone_channels') else None
return backbone_3d_module, model_info_dict return backbone_3d_module, model_info_dict
def build_map_to_bev_module(self, model_info_dict): def build_map_to_bev_module(self, model_info_dict):
......
...@@ -8,7 +8,7 @@ from .roi_head_template import RoIHeadTemplate ...@@ -8,7 +8,7 @@ from .roi_head_template import RoIHeadTemplate
class PartA2FCHead(RoIHeadTemplate): class PartA2FCHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1): def __init__(self, input_channels, model_cfg, num_class=1, **kwargs):
super().__init__(num_class=num_class, model_cfg=model_cfg) super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg self.model_cfg = model_cfg
......
...@@ -8,7 +8,7 @@ from .roi_head_template import RoIHeadTemplate ...@@ -8,7 +8,7 @@ from .roi_head_template import RoIHeadTemplate
class PointRCNNHead(RoIHeadTemplate): class PointRCNNHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1): def __init__(self, input_channels, model_cfg, num_class=1, **kwargs):
super().__init__(num_class=num_class, model_cfg=model_cfg) super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg self.model_cfg = model_cfg
use_bn = self.model_cfg.USE_BN use_bn = self.model_cfg.USE_BN
......
...@@ -6,7 +6,7 @@ from .roi_head_template import RoIHeadTemplate ...@@ -6,7 +6,7 @@ from .roi_head_template import RoIHeadTemplate
class PVRCNNHead(RoIHeadTemplate): class PVRCNNHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1): def __init__(self, input_channels, model_cfg, num_class=1, **kwargs):
super().__init__(num_class=num_class, model_cfg=model_cfg) super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg self.model_cfg = model_cfg
......
...@@ -9,7 +9,7 @@ from .target_assigner.proposal_target_layer import ProposalTargetLayer ...@@ -9,7 +9,7 @@ from .target_assigner.proposal_target_layer import ProposalTargetLayer
class RoIHeadTemplate(nn.Module): class RoIHeadTemplate(nn.Module):
def __init__(self, num_class, model_cfg): def __init__(self, num_class, model_cfg, **kwargs):
super().__init__() super().__init__()
self.model_cfg = model_cfg self.model_cfg = model_cfg
self.num_class = num_class self.num_class = num_class
......
...@@ -5,7 +5,7 @@ from ...utils import common_utils, loss_utils ...@@ -5,7 +5,7 @@ from ...utils import common_utils, loss_utils
class SECONDHead(RoIHeadTemplate): class SECONDHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1): def __init__(self, input_channels, model_cfg, num_class=1, **kwargs):
super().__init__(num_class=num_class, model_cfg=model_cfg) super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg self.model_cfg = model_cfg
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment