Commit 52496f55 authored by Kai Chen's avatar Kai Chen
Browse files

rename *RoIHead to *BBoxHead

parent bac11303
......@@ -20,7 +20,7 @@ model = dict(
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCRoIHead',
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
......
......@@ -20,7 +20,7 @@ model = dict(
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCRoIHead',
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
......
......@@ -30,7 +30,7 @@ model = dict(
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCRoIHead',
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
......
......@@ -30,7 +30,7 @@ model = dict(
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCRoIHead',
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
......
from .bbox_head import BBoxHead
from .convfc_bbox_head import ConvFCRoIHead, SharedFCRoIHead
from .convfc_bbox_head import ConvFCBBoxHead, SharedFCBBoxHead
__all__ = ['BBoxHead', 'ConvFCRoIHead', 'SharedFCRoIHead']
__all__ = ['BBoxHead', 'ConvFCBBoxHead', 'SharedFCBBoxHead']
......@@ -4,7 +4,7 @@ from .bbox_head import BBoxHead
from ..utils import ConvModule
class ConvFCRoIHead(BBoxHead):
class ConvFCBBoxHead(BBoxHead):
"""More general bbox head, with shared conv and fc layers and two optional
separated branches.
......@@ -24,7 +24,7 @@ class ConvFCRoIHead(BBoxHead):
fc_out_channels=1024,
*args,
**kwargs):
super(ConvFCRoIHead, self).__init__(*args, **kwargs)
super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
assert (num_shared_convs + num_shared_fcs + num_cls_convs + num_cls_fcs
+ num_reg_convs + num_reg_fcs > 0)
if num_cls_convs > 0 or num_reg_convs > 0:
......@@ -116,7 +116,7 @@ class ConvFCRoIHead(BBoxHead):
return branch_convs, branch_fcs, last_layer_dim
def init_weights(self):
super(ConvFCRoIHead, self).init_weights()
super(ConvFCBBoxHead, self).init_weights()
for module_list in [self.shared_fcs, self.cls_fcs, self.reg_fcs]:
for m in module_list.modules():
if isinstance(m, nn.Linear):
......@@ -162,11 +162,11 @@ class ConvFCRoIHead(BBoxHead):
return cls_score, bbox_pred
class SharedFCRoIHead(ConvFCRoIHead):
class SharedFCBBoxHead(ConvFCBBoxHead):
def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs):
assert num_fcs >= 1
super(SharedFCRoIHead, self).__init__(
super(SharedFCBBoxHead, self).__init__(
num_shared_convs=0,
num_shared_fcs=num_fcs,
num_cls_convs=0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment