Commit d192f78a authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

For object detection models, removing rpn_head.anchors_per_location from...

For object detection models, removing rpn_head.anchors_per_location from config and using anchor.num_scales * len(anchor.aspect_ratios) to compute anchors_per_location

PiperOrigin-RevId: 324878153
parent 739a7f32
...@@ -52,7 +52,6 @@ MASKRCNN_CFG.override({ ...@@ -52,7 +52,6 @@ MASKRCNN_CFG.override({
'anchor_size': 8, 'anchor_size': 8,
}, },
'rpn_head': { 'rpn_head': {
'anchors_per_location': 3,
'num_convs': 2, 'num_convs': 2,
'num_filters': 256, 'num_filters': 256,
'use_separable_conv': False, 'use_separable_conv': False,
......
...@@ -39,7 +39,6 @@ RETINANET_CFG.override({ ...@@ -39,7 +39,6 @@ RETINANET_CFG.override({
'max_num_instances': 100, 'max_num_instances': 100,
}, },
'retinanet_head': { 'retinanet_head': {
'anchors_per_location': 9,
'num_convs': 4, 'num_convs': 4,
'num_filters': 256, 'num_filters': 256,
'use_separable_conv': False, 'use_separable_conv': False,
......
...@@ -62,7 +62,6 @@ SHAPEMASK_CFG.override({ ...@@ -62,7 +62,6 @@ SHAPEMASK_CFG.override({
'upsample_factor': 4, 'upsample_factor': 4,
}, },
'retinanet_head': { 'retinanet_head': {
'anchors_per_location': 9,
'num_convs': 4, 'num_convs': 4,
'num_filters': 256, 'num_filters': 256,
'use_separable_conv': False, 'use_separable_conv': False,
......
...@@ -77,11 +77,13 @@ def multilevel_features_generator(params): ...@@ -77,11 +77,13 @@ def multilevel_features_generator(params):
def retinanet_head_generator(params): def retinanet_head_generator(params):
"""Generator function for RetinaNet head architecture.""" """Generator function for RetinaNet head architecture."""
head_params = params.retinanet_head head_params = params.retinanet_head
anchors_per_location = params.anchor.num_scales * len(
params.anchor.aspect_ratios)
return heads.RetinanetHead( return heads.RetinanetHead(
params.architecture.min_level, params.architecture.min_level,
params.architecture.max_level, params.architecture.max_level,
params.architecture.num_classes, params.architecture.num_classes,
head_params.anchors_per_location, anchors_per_location,
head_params.num_convs, head_params.num_convs,
head_params.num_filters, head_params.num_filters,
head_params.use_separable_conv, head_params.use_separable_conv,
...@@ -91,10 +93,12 @@ def retinanet_head_generator(params): ...@@ -91,10 +93,12 @@ def retinanet_head_generator(params):
def rpn_head_generator(params): def rpn_head_generator(params):
"""Generator function for RPN head architecture.""" """Generator function for RPN head architecture."""
head_params = params.rpn_head head_params = params.rpn_head
anchors_per_location = params.anchor.num_scales * len(
params.anchor.aspect_ratios)
return heads.RpnHead( return heads.RpnHead(
params.architecture.min_level, params.architecture.min_level,
params.architecture.max_level, params.architecture.max_level,
head_params.anchors_per_location, anchors_per_location,
head_params.num_convs, head_params.num_convs,
head_params.num_filters, head_params.num_filters,
head_params.use_separable_conv, head_params.use_separable_conv,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment