Commit 3b13794f authored by Pengchong Jin, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 306768885
parent 795a3f7d
...@@ -14,8 +14,16 @@ ...@@ -14,8 +14,16 @@
# ============================================================================== # ==============================================================================
"""Base config template.""" """Base config template."""
# pylint: disable=line-too-long
BACKBONES = [
'resnet',
]
MULTILEVEL_FEATURES = [
'fpn',
]
# pylint: disable=line-too-long
# For ResNet, this freezes the variables of the first conv1 and conv2_x # For ResNet, this freezes the variables of the first conv1 and conv2_x
# layers [1], which leads to higher training speed and slightly better testing # layers [1], which leads to higher training speed and slightly better testing
# accuracy. The intuition is that the low-level architecture (e.g., ResNet-50) # accuracy. The intuition is that the low-level architecture (e.g., ResNet-50)
...@@ -24,7 +32,6 @@ ...@@ -24,7 +32,6 @@
# Note that we need the trailing `/` to avoid an incorrect match. # Note that we need the trailing `/` to avoid an incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198 # [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/' RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$' REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
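As a quick check of the two regexes above, the sketch below runs them against a few sample variable names (the exact names depend on how the layers are built, so the samples are assumptions): only the stem conv and the conv2_x convolutions and batch norms (conv2d through conv2d_10 and their batch_normalization counterparts) match the frozen prefix, and only kernels/weights match the regularization regex.

import re

FROZEN = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
REGULARIZED = r'.*(kernel|weight):0$'

# Hypothetical TF2 variable names, for illustration only.
samples = [
    'resnet50/conv2d/kernel:0',                # stem conv: frozen, regularized
    'resnet50/conv2d_10/kernel:0',             # last conv2_x conv: frozen, regularized
    'resnet50/conv2d_11/kernel:0',             # conv3_x: trainable, regularized
    'resnet50/batch_normalization_2/gamma:0',  # conv2_x BN: frozen, not regularized
]
for name in samples:
    print(name, bool(re.match(FROZEN, name)), bool(re.match(REGULARIZED, name)))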
BASE_CFG = { BASE_CFG = {
...@@ -41,6 +48,7 @@ BASE_CFG = { ...@@ -41,6 +48,7 @@ BASE_CFG = {
'optimizer': { 'optimizer': {
'type': 'momentum', 'type': 'momentum',
'momentum': 0.9, 'momentum': 0.9,
'nesterov': True, # `False` is better for TPU v3-128.
}, },
'learning_rate': { 'learning_rate': {
'type': 'step', 'type': 'step',
...@@ -49,21 +57,25 @@ BASE_CFG = { ...@@ -49,21 +57,25 @@ BASE_CFG = {
'init_learning_rate': 0.08, 'init_learning_rate': 0.08,
'learning_rate_levels': [0.008, 0.0008], 'learning_rate_levels': [0.008, 0.0008],
'learning_rate_steps': [15000, 20000], 'learning_rate_steps': [15000, 20000],
'total_steps': 22500,
}, },
'checkpoint': { 'checkpoint': {
'path': '', 'path': '',
'prefix': '', 'prefix': '',
}, },
'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX, # One can use 'RESNET_FROZEN_VAR_PREFIX' to speed up ResNet training
# when loading from the checkpoint.
'frozen_variable_prefix': '',
'train_file_pattern': '', 'train_file_pattern': '',
'train_dataset_type': 'tfrecord', 'train_dataset_type': 'tfrecord',
# TODO(b/142174042): Support transpose_input option.
'transpose_input': False, 'transpose_input': False,
'regularization_variable_regex': REGULARIZATION_VAR_REGEX, 'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
'l2_weight_decay': 0.0001, 'l2_weight_decay': 0.0001,
'gradient_clip_norm': 0.0, 'gradient_clip_norm': 0.0,
'input_sharding': False,
}, },
'eval': { 'eval': {
'input_sharding': True,
'batch_size': 8, 'batch_size': 8,
'eval_samples': 5000, 'eval_samples': 5000,
'min_eval_interval': 180, 'min_eval_interval': 180,
...@@ -74,38 +86,42 @@ BASE_CFG = { ...@@ -74,38 +86,42 @@ BASE_CFG = {
'val_json_file': '', 'val_json_file': '',
'eval_file_pattern': '', 'eval_file_pattern': '',
'eval_dataset_type': 'tfrecord', 'eval_dataset_type': 'tfrecord',
# When visualizing images, set evaluation batch size to 40 to avoid
# potential OOM.
'num_images_to_visualize': 0,
}, },
'predict': { 'predict': {
'batch_size': 8, 'batch_size': 8,
}, },
'anchor': { 'architecture': {
'backbone': 'resnet',
'min_level': 3, 'min_level': 3,
'max_level': 7, 'max_level': 7,
'multilevel_features': 'fpn',
'use_bfloat16': True,
# Note that `num_classes` is the total number of classes including
# one background class whose index is 0.
'num_classes': 91,
},
'anchor': {
'num_scales': 3, 'num_scales': 3,
'aspect_ratios': [1.0, 2.0, 0.5], 'aspect_ratios': [1.0, 2.0, 0.5],
'anchor_size': 4.0, 'anchor_size': 4.0,
}, },
'norm_activation': {
'activation': 'relu',
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
'resnet': { 'resnet': {
'resnet_depth': 50, 'resnet_depth': 50,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
}, },
'fpn': { 'fpn': {
'min_level': 3,
'max_level': 7,
'fpn_feat_dims': 256, 'fpn_feat_dims': 256,
'use_separable_conv': False, 'use_separable_conv': False,
'use_batch_norm': True, 'use_batch_norm': True,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
}, },
'postprocess': { 'postprocess': {
'use_batched_nms': False, 'use_batched_nms': False,
...@@ -116,5 +132,4 @@ BASE_CFG = { ...@@ -116,5 +132,4 @@ BASE_CFG = {
}, },
'enable_summary': False, 'enable_summary': False,
} }
# pylint: enable=line-too-long # pylint: enable=line-too-long
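Because `frozen_variable_prefix` now defaults to an empty string, backbone freezing has to be requested explicitly. A minimal sketch of doing that through the ParamsDict API the derived configs use (the call site is illustrative; only the keys and constants come from this file):

from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import base_config

params = params_dict.ParamsDict(base_config.BASE_CFG)
# Re-enable freezing of the ResNet conv1/conv2_x variables, e.g. when
# fine-tuning from a classification checkpoint.
params.override({
    'train': {
        'frozen_variable_prefix': base_config.RESNET_FROZEN_VAR_PREFIX,
    },
}, is_strict=False)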
...@@ -28,13 +28,12 @@ MASKRCNN_CFG.override({ ...@@ -28,13 +28,12 @@ MASKRCNN_CFG.override({
}, },
'architecture': { 'architecture': {
'parser': 'maskrcnn_parser', 'parser': 'maskrcnn_parser',
'backbone': 'resnet', 'min_level': 2,
'multilevel_features': 'fpn', 'max_level': 6,
'use_bfloat16': True,
'include_mask': True, 'include_mask': True,
'mask_target_size': 28,
}, },
'maskrcnn_parser': { 'maskrcnn_parser': {
'use_bfloat16': True,
'output_size': [1024, 1024], 'output_size': [1024, 1024],
'num_channels': 3, 'num_channels': 3,
'rpn_match_threshold': 0.7, 'rpn_match_threshold': 0.7,
...@@ -46,74 +45,32 @@ MASKRCNN_CFG.override({ ...@@ -46,74 +45,32 @@ MASKRCNN_CFG.override({
'aug_scale_max': 1.0, 'aug_scale_max': 1.0,
'skip_crowd_during_training': True, 'skip_crowd_during_training': True,
'max_num_instances': 100, 'max_num_instances': 100,
'include_mask': True,
'mask_crop_size': 112, 'mask_crop_size': 112,
}, },
'anchor': { 'anchor': {
'min_level': 2,
'max_level': 6,
'num_scales': 1, 'num_scales': 1,
'anchor_size': 8, 'anchor_size': 8,
}, },
'fpn': {
'min_level': 2,
'max_level': 6,
},
'nasfpn': {
'min_level': 2,
'max_level': 6,
},
# tunable_nasfpn:strip_begin
'tunable_nasfpn_v1': {
'min_level': 2,
'max_level': 6,
},
# tunable_nasfpn:strip_end
'rpn_head': { 'rpn_head': {
'min_level': 2,
'max_level': 6,
'anchors_per_location': 3, 'anchors_per_location': 3,
'num_convs': 2, 'num_convs': 2,
'num_filters': 256, 'num_filters': 256,
'use_separable_conv': False, 'use_separable_conv': False,
'use_batch_norm': False, 'use_batch_norm': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
}, },
'frcnn_head': { 'frcnn_head': {
# Note that `num_classes` is the total number of classes including
# one background classes whose index is 0.
'num_classes': 91,
'num_convs': 0, 'num_convs': 0,
'num_filters': 256, 'num_filters': 256,
'use_separable_conv': False, 'use_separable_conv': False,
'num_fcs': 2, 'num_fcs': 2,
'fc_dims': 1024, 'fc_dims': 1024,
'use_batch_norm': False, 'use_batch_norm': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
}, },
'mrcnn_head': { 'mrcnn_head': {
'num_classes': 91,
'mask_target_size': 28,
'num_convs': 4, 'num_convs': 4,
'num_filters': 256, 'num_filters': 256,
'use_separable_conv': False, 'use_separable_conv': False,
'use_batch_norm': False, 'use_batch_norm': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
}, },
'rpn_score_loss': { 'rpn_score_loss': {
'rpn_batch_size_per_im': 256, 'rpn_batch_size_per_im': 256,
...@@ -147,23 +104,10 @@ MASKRCNN_CFG.override({ ...@@ -147,23 +104,10 @@ MASKRCNN_CFG.override({
}, },
'mask_sampling': { 'mask_sampling': {
'num_mask_samples_per_image': 128, # Typically = `num_samples_per_image` * `fg_fraction`. 'num_mask_samples_per_image': 128, # Typically = `num_samples_per_image` * `fg_fraction`.
'mask_target_size': 28,
},
'postprocess': {
'use_batched_nms': False,
'max_total_size': 100,
'nms_iou_threshold': 0.5,
'score_threshold': 0.05,
'pre_nms_num_boxes': 1000,
}, },
}, is_strict=False) }, is_strict=False)
MASKRCNN_RESTRICTIONS = [ MASKRCNN_RESTRICTIONS = [
'architecture.use_bfloat16 == maskrcnn_parser.use_bfloat16',
'architecture.include_mask == maskrcnn_parser.include_mask',
'anchor.min_level == rpn_head.min_level',
'anchor.max_level == rpn_head.max_level',
'mrcnn_head.mask_target_size == mask_sampling.mask_target_size',
] ]
# pylint: enable=line-too-long # pylint: enable=line-too-long
...@@ -14,84 +14,18 @@ ...@@ -14,84 +14,18 @@
# ============================================================================== # ==============================================================================
"""Config template to train Retinanet.""" """Config template to train Retinanet."""
# pylint: disable=line-too-long from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import base_config
# For ResNet-50, this freezes the variables of the first conv1 and conv2_x
# layers [1], which leads to higher training speed and slightly better testing
# accuracy. The intuition is that the low-level architecture (e.g., ResNet-50)
# is able to capture low-level features such as edges; therefore, it does not
# need to be fine-tuned for the detection task.
# Note that we need to trailing `/` to avoid the incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
# pylint: disable=line-too-long # pylint: disable=line-too-long
RETINANET_CFG = { RETINANET_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
RETINANET_CFG.override({
'type': 'retinanet', 'type': 'retinanet',
'model_dir': '',
'use_tpu': True,
'strategy_type': 'tpu',
'train': {
'batch_size': 64,
'iterations_per_loop': 500,
'total_steps': 22500,
'optimizer': {
'type': 'momentum',
'momentum': 0.9,
'nesterov': True, # `False` is better for TPU v3-128.
},
'learning_rate': {
'type': 'step',
'warmup_learning_rate': 0.0067,
'warmup_steps': 500,
'init_learning_rate': 0.08,
'learning_rate_levels': [0.008, 0.0008],
'learning_rate_steps': [15000, 20000],
},
'checkpoint': {
'path': '',
'prefix': '',
},
'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
'train_file_pattern': '',
# TODO(b/142174042): Support transpose_input option.
'transpose_input': False,
'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
'l2_weight_decay': 0.0001,
'input_sharding': False,
},
'eval': {
'batch_size': 8,
'min_eval_interval': 180,
'eval_timeout': None,
'eval_samples': 5000,
'type': 'box',
'val_json_file': '',
'eval_file_pattern': '',
'input_sharding': True,
# When visualizing images, set evaluation batch size to 40 to avoid
# potential OOM.
'num_images_to_visualize': 0,
},
'predict': {
'predict_batch_size': 8,
},
'architecture': { 'architecture': {
'parser': 'retinanet_parser', 'parser': 'retinanet_parser',
'backbone': 'resnet',
'multilevel_features': 'fpn',
'use_bfloat16': False,
},
'anchor': {
'min_level': 3,
'max_level': 7,
'num_scales': 3,
'aspect_ratios': [1.0, 2.0, 0.5],
'anchor_size': 4.0,
}, },
'retinanet_parser': { 'retinanet_parser': {
'use_bfloat16': False,
'output_size': [640, 640], 'output_size': [640, 640],
'num_channels': 3, 'num_channels': 3,
'match_threshold': 0.5, 'match_threshold': 0.5,
...@@ -104,68 +38,22 @@ RETINANET_CFG = { ...@@ -104,68 +38,22 @@ RETINANET_CFG = {
'skip_crowd_during_training': True, 'skip_crowd_during_training': True,
'max_num_instances': 100, 'max_num_instances': 100,
}, },
'resnet': {
'resnet_depth': 50,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
},
},
'fpn': {
'min_level': 3,
'max_level': 7,
'fpn_feat_dims': 256,
'use_separable_conv': False,
'use_batch_norm': True,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
},
},
'retinanet_head': { 'retinanet_head': {
'min_level': 3,
'max_level': 7,
# Note that `num_classes` is the total number of classes including
# one background classes whose index is 0.
'num_classes': 91,
'anchors_per_location': 9, 'anchors_per_location': 9,
'retinanet_head_num_convs': 4, 'num_convs': 4,
'retinanet_head_num_filters': 256, 'num_filters': 256,
'use_separable_conv': False, 'use_separable_conv': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
},
}, },
'retinanet_loss': { 'retinanet_loss': {
'num_classes': 91,
'focal_loss_alpha': 0.25, 'focal_loss_alpha': 0.25,
'focal_loss_gamma': 1.5, 'focal_loss_gamma': 1.5,
'huber_loss_delta': 0.1, 'huber_loss_delta': 0.1,
'box_loss_weight': 50, 'box_loss_weight': 50,
}, },
'postprocess': { 'enable_summary': True,
'use_batched_nms': False, }, is_strict=False)
'min_level': 3,
'max_level': 7,
'max_total_size': 100,
'nms_iou_threshold': 0.5,
'score_threshold': 0.05,
'pre_nms_num_boxes': 5000,
},
'enable_summary': False,
}
RETINANET_RESTRICTIONS = [ RETINANET_RESTRICTIONS = [
'architecture.use_bfloat16 == retinanet_parser.use_bfloat16',
'anchor.min_level == retinanet_head.min_level',
'anchor.max_level == retinanet_head.max_level',
'anchor.min_level == postprocess.min_level',
'anchor.max_level == postprocess.max_level',
'retinanet_head.num_classes == retinanet_loss.num_classes',
] ]
# pylint: enable=line-too-long # pylint: enable=line-too-long
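Since the pyramid levels, `num_classes`, and `use_bfloat16` now live under `architecture`, one override reaches the parser, FPN, and head instead of being repeated per component, which is also why most of the old restrictions above could be deleted. A small sketch, assuming this file is importable as `retinanet_config` (the override values are illustrative):

from official.vision.detection.configs import retinanet_config

params = retinanet_config.RETINANET_CFG  # a ParamsDict built from BASE_CFG
# retinanet_parser, fpn, and retinanet_head all read
# params.architecture.{min_level, max_level, num_classes} after this change.
params.override({
    'architecture': {
        'min_level': 3,
        'max_level': 6,
        'num_classes': 21,  # e.g. a 20-class dataset plus background
    },
}, is_strict=False)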
...@@ -29,8 +29,8 @@ def parser_generator(params, mode): ...@@ -29,8 +29,8 @@ def parser_generator(params, mode):
parser_params = params.retinanet_parser parser_params = params.retinanet_parser
parser_fn = retinanet_parser.Parser( parser_fn = retinanet_parser.Parser(
output_size=parser_params.output_size, output_size=parser_params.output_size,
min_level=anchor_params.min_level, min_level=params.architecture.min_level,
max_level=anchor_params.max_level, max_level=params.architecture.max_level,
num_scales=anchor_params.num_scales, num_scales=anchor_params.num_scales,
aspect_ratios=anchor_params.aspect_ratios, aspect_ratios=anchor_params.aspect_ratios,
anchor_size=anchor_params.anchor_size, anchor_size=anchor_params.anchor_size,
...@@ -43,15 +43,15 @@ def parser_generator(params, mode): ...@@ -43,15 +43,15 @@ def parser_generator(params, mode):
autoaugment_policy_name=parser_params.autoaugment_policy_name, autoaugment_policy_name=parser_params.autoaugment_policy_name,
skip_crowd_during_training=parser_params.skip_crowd_during_training, skip_crowd_during_training=parser_params.skip_crowd_during_training,
max_num_instances=parser_params.max_num_instances, max_num_instances=parser_params.max_num_instances,
use_bfloat16=parser_params.use_bfloat16, use_bfloat16=params.architecture.use_bfloat16,
mode=mode) mode=mode)
elif params.architecture.parser == 'maskrcnn_parser': elif params.architecture.parser == 'maskrcnn_parser':
anchor_params = params.anchor anchor_params = params.anchor
parser_params = params.maskrcnn_parser parser_params = params.maskrcnn_parser
parser_fn = maskrcnn_parser.Parser( parser_fn = maskrcnn_parser.Parser(
output_size=parser_params.output_size, output_size=parser_params.output_size,
min_level=anchor_params.min_level, min_level=params.architecture.min_level,
max_level=anchor_params.max_level, max_level=params.architecture.max_level,
num_scales=anchor_params.num_scales, num_scales=anchor_params.num_scales,
aspect_ratios=anchor_params.aspect_ratios, aspect_ratios=anchor_params.aspect_ratios,
anchor_size=anchor_params.anchor_size, anchor_size=anchor_params.anchor_size,
...@@ -64,17 +64,17 @@ def parser_generator(params, mode): ...@@ -64,17 +64,17 @@ def parser_generator(params, mode):
aug_scale_max=parser_params.aug_scale_max, aug_scale_max=parser_params.aug_scale_max,
skip_crowd_during_training=parser_params.skip_crowd_during_training, skip_crowd_during_training=parser_params.skip_crowd_during_training,
max_num_instances=parser_params.max_num_instances, max_num_instances=parser_params.max_num_instances,
include_mask=parser_params.include_mask, include_mask=params.architecture.include_mask,
mask_crop_size=parser_params.mask_crop_size, mask_crop_size=parser_params.mask_crop_size,
use_bfloat16=parser_params.use_bfloat16, use_bfloat16=params.architecture.use_bfloat16,
mode=mode) mode=mode)
elif params.architecture.parser == 'shapemask_parser': elif params.architecture.parser == 'shapemask_parser':
anchor_params = params.anchor anchor_params = params.anchor
parser_params = params.shapemask_parser parser_params = params.shapemask_parser
parser_fn = shapemask_parser.Parser( parser_fn = shapemask_parser.Parser(
output_size=parser_params.output_size, output_size=parser_params.output_size,
min_level=anchor_params.min_level, min_level=params.architecture.min_level,
max_level=anchor_params.max_level, max_level=params.architecture.max_level,
num_scales=anchor_params.num_scales, num_scales=anchor_params.num_scales,
aspect_ratios=anchor_params.aspect_ratios, aspect_ratios=anchor_params.aspect_ratios,
anchor_size=anchor_params.anchor_size, anchor_size=anchor_params.anchor_size,
...@@ -93,7 +93,7 @@ def parser_generator(params, mode): ...@@ -93,7 +93,7 @@ def parser_generator(params, mode):
aug_scale_max=parser_params.aug_scale_max, aug_scale_max=parser_params.aug_scale_max,
skip_crowd_during_training=parser_params.skip_crowd_during_training, skip_crowd_during_training=parser_params.skip_crowd_during_training,
max_num_instances=parser_params.max_num_instances, max_num_instances=parser_params.max_num_instances,
use_bfloat16=parser_params.use_bfloat16, use_bfloat16=params.architecture.use_bfloat16,
mask_train_class=parser_params.mask_train_class, mask_train_class=parser_params.mask_train_class,
mode=mode) mode=mode)
else: else:
......
...@@ -25,16 +25,12 @@ from official.vision.detection.modeling.architecture import nn_ops ...@@ -25,16 +25,12 @@ from official.vision.detection.modeling.architecture import nn_ops
from official.vision.detection.modeling.architecture import resnet from official.vision.detection.modeling.architecture import resnet
def batch_norm_relu_generator(params): def norm_activation_generator(params):
return nn_ops.norm_activation_builder(
def _batch_norm_op(**kwargs): momentum=params.batch_norm_momentum,
return nn_ops.BatchNormRelu( epsilon=params.batch_norm_epsilon,
momentum=params.batch_norm_momentum, trainable=params.batch_norm_trainable,
epsilon=params.batch_norm_epsilon, activation=params.activation)
trainable=params.batch_norm_trainable,
**kwargs)
return _batch_norm_op
def backbone_generator(params): def backbone_generator(params):
...@@ -43,10 +39,12 @@ def backbone_generator(params): ...@@ -43,10 +39,12 @@ def backbone_generator(params):
resnet_params = params.resnet resnet_params = params.resnet
backbone_fn = resnet.Resnet( backbone_fn = resnet.Resnet(
resnet_depth=resnet_params.resnet_depth, resnet_depth=resnet_params.resnet_depth,
batch_norm_relu=batch_norm_relu_generator(resnet_params.batch_norm)) activation=params.norm_activation.activation,
norm_activation=norm_activation_generator(
params.norm_activation))
else: else:
raise ValueError('Backbone model %s is not supported.' % raise ValueError('Backbone model `{}` is not supported.'
params.architecture.backbone) .format(params.architecture.backbone))
return backbone_fn return backbone_fn
...@@ -56,81 +54,75 @@ def multilevel_features_generator(params): ...@@ -56,81 +54,75 @@ def multilevel_features_generator(params):
if params.architecture.multilevel_features == 'fpn': if params.architecture.multilevel_features == 'fpn':
fpn_params = params.fpn fpn_params = params.fpn
fpn_fn = fpn.Fpn( fpn_fn = fpn.Fpn(
min_level=fpn_params.min_level, min_level=params.architecture.min_level,
max_level=fpn_params.max_level, max_level=params.architecture.max_level,
fpn_feat_dims=fpn_params.fpn_feat_dims, fpn_feat_dims=fpn_params.fpn_feat_dims,
use_separable_conv=fpn_params.use_separable_conv, use_separable_conv=fpn_params.use_separable_conv,
activation=params.norm_activation.activation,
use_batch_norm=fpn_params.use_batch_norm, use_batch_norm=fpn_params.use_batch_norm,
batch_norm_relu=batch_norm_relu_generator(fpn_params.batch_norm)) norm_activation=norm_activation_generator(
params.norm_activation))
elif params.architecture.multilevel_features == 'identity': elif params.architecture.multilevel_features == 'identity':
fpn_fn = identity.Identity() fpn_fn = identity.Identity()
else: else:
raise ValueError('The multi-level feature model %s is not supported.' raise ValueError('The multi-level feature model `{}` is not supported.'
% params.architecture.multilevel_features) .format(params.architecture.multilevel_features))
return fpn_fn return fpn_fn
def retinanet_head_generator(params): def retinanet_head_generator(params):
"""Generator function for RetinaNet head architecture.""" """Generator function for RetinaNet head architecture."""
head_params = params.retinanet_head
return heads.RetinanetHead( return heads.RetinanetHead(
params.min_level, params.architecture.min_level,
params.max_level, params.architecture.max_level,
params.num_classes, params.architecture.num_classes,
params.anchors_per_location, head_params.anchors_per_location,
params.retinanet_head_num_convs, head_params.num_convs,
params.retinanet_head_num_filters, head_params.num_filters,
params.use_separable_conv, head_params.use_separable_conv,
batch_norm_relu=batch_norm_relu_generator(params.batch_norm)) norm_activation=norm_activation_generator(params.norm_activation))
def rpn_head_generator(params): def rpn_head_generator(params):
"""Generator function for RPN head architecture.""" """Generator function for RPN head architecture."""
head_params = params.rpn_head
return heads.RpnHead(params.min_level, return heads.RpnHead(
params.max_level, params.architecture.min_level,
params.anchors_per_location, params.architecture.max_level,
params.num_convs, head_params.anchors_per_location,
params.num_filters, head_params.num_convs,
params.use_separable_conv, head_params.num_filters,
params.use_batch_norm, head_params.use_separable_conv,
batch_norm_relu=batch_norm_relu_generator( params.norm_activation.activation,
params.batch_norm)) head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def fast_rcnn_head_generator(params): def fast_rcnn_head_generator(params):
"""Generator function for Fast R-CNN head architecture.""" """Generator function for Fast R-CNN head architecture."""
return heads.FastrcnnHead(params.num_classes, head_params = params.frcnn_head
params.num_convs, return heads.FastrcnnHead(
params.num_filters, params.architecture.num_classes,
params.use_separable_conv, head_params.num_convs,
params.num_fcs, head_params.num_filters,
params.fc_dims, head_params.use_separable_conv,
params.use_batch_norm, head_params.num_fcs,
batch_norm_relu=batch_norm_relu_generator( head_params.fc_dims,
params.batch_norm)) params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def mask_rcnn_head_generator(params): def mask_rcnn_head_generator(params):
"""Generator function for Mask R-CNN head architecture.""" """Generator function for Mask R-CNN head architecture."""
return heads.MaskrcnnHead(params.num_classes, head_params = params.mrcnn_head
params.mask_target_size, return heads.MaskrcnnHead(
params.num_convs, params.architecture.num_classes,
params.num_filters, params.architecture.mask_target_size,
params.use_separable_conv, head_params.num_convs,
params.use_batch_norm, head_params.num_filters,
batch_norm_relu=batch_norm_relu_generator( head_params.use_separable_conv,
params.batch_norm)) params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def shapeprior_head_generator(params):
"""Generator function for Shapemask head architecture."""
raise NotImplementedError('Unimplemented')
def coarsemask_head_generator(params):
"""Generator function for Shapemask head architecture."""
raise NotImplementedError('Unimplemented')
def finemask_head_generator(params):
"""Generator function for Shapemask head architecture."""
raise NotImplementedError('Unimplemented')
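For orientation, a rough sketch of how these generators are consumed when a model is assembled; the module path for this factory and the wiring are assumptions, while the generator names and the params layout come from the code above:

from official.vision.detection.configs import retinanet_config
# Assumed location of this factory module.
from official.vision.detection.modeling.architecture import factory

params = retinanet_config.RETINANET_CFG
backbone_fn = factory.backbone_generator(params)             # resnet.Resnet(...)
features_fn = factory.multilevel_features_generator(params)  # fpn.Fpn(...) or identity.Identity()
head_fn = factory.retinanet_head_generator(params)           # heads.RetinanetHead(...)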
...@@ -41,8 +41,10 @@ class Fpn(object): ...@@ -41,8 +41,10 @@ class Fpn(object):
max_level=7, max_level=7,
fpn_feat_dims=256, fpn_feat_dims=256,
use_separable_conv=False, use_separable_conv=False,
activation='relu',
use_batch_norm=True, use_batch_norm=True,
batch_norm_relu=nn_ops.BatchNormRelu): norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""FPN initialization function. """FPN initialization function.
Args: Args:
...@@ -52,8 +54,8 @@ class Fpn(object): ...@@ -52,8 +54,8 @@ class Fpn(object):
use_separable_conv: `bool`, if True use separable convolution for use_separable_conv: `bool`, if True use separable convolution for
convolution in FPN layers. convolution in FPN layers.
use_batch_norm: 'bool', indicating whether batchnorm layers are added. use_batch_norm: 'bool', indicating whether batchnorm layers are added.
batch_norm_relu: an operation that includes a batch normalization layer norm_activation: an operation that includes a normalization layer
followed by a relu layer(optional). followed by an optional activation layer.
""" """
self._min_level = min_level self._min_level = min_level
self._max_level = max_level self._max_level = max_level
...@@ -63,17 +65,23 @@ class Fpn(object): ...@@ -63,17 +65,23 @@ class Fpn(object):
tf.keras.layers.SeparableConv2D, depth_multiplier=1) tf.keras.layers.SeparableConv2D, depth_multiplier=1)
else: else:
self._conv2d_op = tf.keras.layers.Conv2D self._conv2d_op = tf.keras.layers.Conv2D
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm self._use_batch_norm = use_batch_norm
self._batch_norm_relu = batch_norm_relu self._norm_activation = norm_activation
self._batch_norm_relus = {} self._norm_activations = {}
self._lateral_conv2d_op = {} self._lateral_conv2d_op = {}
self._post_hoc_conv2d_op = {} self._post_hoc_conv2d_op = {}
self._coarse_conv2d_op = {} self._coarse_conv2d_op = {}
for level in range(self._min_level, self._max_level + 1): for level in range(self._min_level, self._max_level + 1):
if self._use_batch_norm: if self._use_batch_norm:
self._batch_norm_relus[level] = batch_norm_relu( self._norm_activations[level] = norm_activation(
relu=False, name='p%d-bn' % level) use_activation=False, name='p%d-bn' % level)
self._lateral_conv2d_op[level] = self._conv2d_op( self._lateral_conv2d_op[level] = self._conv2d_op(
filters=self._fpn_feat_dims, filters=self._fpn_feat_dims,
kernel_size=(1, 1), kernel_size=(1, 1),
...@@ -133,11 +141,11 @@ class Fpn(object): ...@@ -133,11 +141,11 @@ class Fpn(object):
for level in range(backbone_max_level + 1, self._max_level + 1): for level in range(backbone_max_level + 1, self._max_level + 1):
feats_in = feats[level - 1] feats_in = feats[level - 1]
if level > backbone_max_level + 1: if level > backbone_max_level + 1:
feats_in = tf.nn.relu(feats_in) feats_in = self._activation_op(feats_in)
feats[level] = self._coarse_conv2d_op[level](feats_in) feats[level] = self._coarse_conv2d_op[level](feats_in)
if self._use_batch_norm: if self._use_batch_norm:
# Adds batch_norm layer. # Adds batch_norm layer.
for level in range(self._min_level, self._max_level + 1): for level in range(self._min_level, self._max_level + 1):
feats[level] = self._batch_norm_relus[level]( feats[level] = self._norm_activations[level](
feats[level], is_training=is_training) feats[level], is_training=is_training)
return feats return feats
...@@ -39,8 +39,10 @@ class RpnHead(tf.keras.layers.Layer): ...@@ -39,8 +39,10 @@ class RpnHead(tf.keras.layers.Layer):
num_convs=2, num_convs=2,
num_filters=256, num_filters=256,
use_separable_conv=False, use_separable_conv=False,
activation='relu',
use_batch_norm=True, use_batch_norm=True,
batch_norm_relu=nn_ops.BatchNormRelu): norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""Initialize params to build Region Proposal Network head. """Initialize params to build Region Proposal Network head.
Args: Args:
...@@ -55,12 +57,18 @@ class RpnHead(tf.keras.layers.Layer): ...@@ -55,12 +57,18 @@ class RpnHead(tf.keras.layers.Layer):
use_separable_conv: `bool`, indicating whether the separable conv layers use_separable_conv: `bool`, indicating whether the separable conv layers
are used. are used.
use_batch_norm: 'bool', indicating whether batchnorm layers are added. use_batch_norm: 'bool', indicating whether batchnorm layers are added.
batch_norm_relu: an operation that includes a batch normalization layer norm_activation: an operation that includes a normalization layer
followed by a relu layer(optional). followed by an optional activation layer.
""" """
self._min_level = min_level self._min_level = min_level
self._max_level = max_level self._max_level = max_level
self._anchors_per_location = anchors_per_location self._anchors_per_location = anchors_per_location
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm self._use_batch_norm = use_batch_norm
if use_separable_conv: if use_separable_conv:
...@@ -78,7 +86,7 @@ class RpnHead(tf.keras.layers.Layer): ...@@ -78,7 +86,7 @@ class RpnHead(tf.keras.layers.Layer):
num_filters, num_filters,
kernel_size=(3, 3), kernel_size=(3, 3),
strides=(1, 1), strides=(1, 1),
activation=(None if self._use_batch_norm else tf.nn.relu), activation=(None if self._use_batch_norm else self._activation_op),
padding='same', padding='same',
name='rpn') name='rpn')
self._rpn_class_conv = self._conv2d_op( self._rpn_class_conv = self._conv2d_op(
...@@ -94,10 +102,10 @@ class RpnHead(tf.keras.layers.Layer): ...@@ -94,10 +102,10 @@ class RpnHead(tf.keras.layers.Layer):
padding='valid', padding='valid',
name='rpn-box') name='rpn-box')
self._batch_norm_relus = {} self._norm_activations = {}
if self._use_batch_norm: if self._use_batch_norm:
for level in range(self._min_level, self._max_level + 1): for level in range(self._min_level, self._max_level + 1):
self._batch_norm_relus[level] = batch_norm_relu(name='rpn-l%d-bn' % self._norm_activations[level] = norm_activation(name='rpn-l%d-bn' %
level) level)
def _shared_rpn_heads(self, features, anchors_per_location, level, def _shared_rpn_heads(self, features, anchors_per_location, level,
...@@ -106,7 +114,7 @@ class RpnHead(tf.keras.layers.Layer): ...@@ -106,7 +114,7 @@ class RpnHead(tf.keras.layers.Layer):
features = self._rpn_conv(features) features = self._rpn_conv(features)
if self._use_batch_norm: if self._use_batch_norm:
# The batch normalization layers are not shared between levels. # The batch normalization layers are not shared between levels.
features = self._batch_norm_relus[level]( features = self._norm_activations[level](
features, is_training=is_training) features, is_training=is_training)
# Proposal classification scores # Proposal classification scores
scores = self._rpn_class_conv(features) scores = self._rpn_class_conv(features)
...@@ -139,8 +147,10 @@ class FastrcnnHead(tf.keras.layers.Layer): ...@@ -139,8 +147,10 @@ class FastrcnnHead(tf.keras.layers.Layer):
use_separable_conv=False, use_separable_conv=False,
num_fcs=2, num_fcs=2,
fc_dims=1024, fc_dims=1024,
activation='relu',
use_batch_norm=True, use_batch_norm=True,
batch_norm_relu=nn_ops.BatchNormRelu): norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""Initialize params to build Fast R-CNN box head. """Initialize params to build Fast R-CNN box head.
Args: Args:
...@@ -156,8 +166,8 @@ class FastrcnnHead(tf.keras.layers.Layer): ...@@ -156,8 +166,8 @@ class FastrcnnHead(tf.keras.layers.Layer):
fc_dims: `int` number that represents the number of dimension of the FC fc_dims: `int` number that represents the number of dimension of the FC
layers. layers.
use_batch_norm: 'bool', indicating whether batchnorm layers are added. use_batch_norm: 'bool', indicating whether batchnorm layers are added.
batch_norm_relu: an operation that includes a batch normalization layer norm_activation: an operation that includes a normalization layer
followed by a relu layer(optional). followed by an optional activation layer.
""" """
self._num_classes = num_classes self._num_classes = num_classes
...@@ -177,9 +187,14 @@ class FastrcnnHead(tf.keras.layers.Layer): ...@@ -177,9 +187,14 @@ class FastrcnnHead(tf.keras.layers.Layer):
self._num_fcs = num_fcs self._num_fcs = num_fcs
self._fc_dims = fc_dims self._fc_dims = fc_dims
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm self._use_batch_norm = use_batch_norm
self._batch_norm_relu = batch_norm_relu self._norm_activation = norm_activation
self._conv_ops = [] self._conv_ops = []
self._conv_bn_ops = [] self._conv_bn_ops = []
...@@ -191,10 +206,10 @@ class FastrcnnHead(tf.keras.layers.Layer): ...@@ -191,10 +206,10 @@ class FastrcnnHead(tf.keras.layers.Layer):
strides=(1, 1), strides=(1, 1),
padding='same', padding='same',
dilation_rate=(1, 1), dilation_rate=(1, 1),
activation=(None if self._use_batch_norm else tf.nn.relu), activation=(None if self._use_batch_norm else self._activation_op),
name='conv_{}'.format(i))) name='conv_{}'.format(i)))
if self._use_batch_norm: if self._use_batch_norm:
self._conv_bn_ops.append(self._batch_norm_relu()) self._conv_bn_ops.append(self._norm_activation())
self._fc_ops = [] self._fc_ops = []
self._fc_bn_ops = [] self._fc_bn_ops = []
...@@ -202,10 +217,10 @@ class FastrcnnHead(tf.keras.layers.Layer): ...@@ -202,10 +217,10 @@ class FastrcnnHead(tf.keras.layers.Layer):
self._fc_ops.append( self._fc_ops.append(
tf.keras.layers.Dense( tf.keras.layers.Dense(
units=self._fc_dims, units=self._fc_dims,
activation=(None if self._use_batch_norm else tf.nn.relu), activation=(None if self._use_batch_norm else self._activation_op),
name='fc{}'.format(i))) name='fc{}'.format(i)))
if self._use_batch_norm: if self._use_batch_norm:
self._fc_bn_ops.append(self._batch_norm_relu(fused=False)) self._fc_bn_ops.append(self._norm_activation(fused=False))
self._class_predict = tf.keras.layers.Dense( self._class_predict = tf.keras.layers.Dense(
self._num_classes, self._num_classes,
...@@ -266,8 +281,10 @@ class MaskrcnnHead(tf.keras.layers.Layer): ...@@ -266,8 +281,10 @@ class MaskrcnnHead(tf.keras.layers.Layer):
num_convs=4, num_convs=4,
num_filters=256, num_filters=256,
use_separable_conv=False, use_separable_conv=False,
activation='relu',
use_batch_norm=True, use_batch_norm=True,
batch_norm_relu=nn_ops.BatchNormRelu): norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""Initialize params to build Fast R-CNN head. """Initialize params to build Fast R-CNN head.
Args: Args:
...@@ -280,8 +297,8 @@ class MaskrcnnHead(tf.keras.layers.Layer): ...@@ -280,8 +297,8 @@ class MaskrcnnHead(tf.keras.layers.Layer):
use_separable_conv: `bool`, indicating whether the separable conv layers use_separable_conv: `bool`, indicating whether the separable conv layers
are used. are used.
use_batch_norm: 'bool', indicating whether batchnorm layers are added. use_batch_norm: 'bool', indicating whether batchnorm layers are added.
batch_norm_relu: an operation that includes a batch normalization layer norm_activation: an operation that includes a normalization layer
followed by a relu layer(optional). followed by an optional activation layer.
""" """
self._num_classes = num_classes self._num_classes = num_classes
self._mask_target_size = mask_target_size self._mask_target_size = mask_target_size
...@@ -299,9 +316,14 @@ class MaskrcnnHead(tf.keras.layers.Layer): ...@@ -299,9 +316,14 @@ class MaskrcnnHead(tf.keras.layers.Layer):
kernel_initializer=tf.keras.initializers.VarianceScaling( kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'), scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer()) bias_initializer=tf.zeros_initializer())
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm self._use_batch_norm = use_batch_norm
self._batch_norm_relu = batch_norm_relu self._norm_activation = norm_activation
self._conv2d_ops = [] self._conv2d_ops = []
for i in range(self._num_convs): for i in range(self._num_convs):
self._conv2d_ops.append( self._conv2d_ops.append(
...@@ -311,14 +333,14 @@ class MaskrcnnHead(tf.keras.layers.Layer): ...@@ -311,14 +333,14 @@ class MaskrcnnHead(tf.keras.layers.Layer):
strides=(1, 1), strides=(1, 1),
padding='same', padding='same',
dilation_rate=(1, 1), dilation_rate=(1, 1),
activation=(None if self._use_batch_norm else tf.nn.relu), activation=(None if self._use_batch_norm else self._activation_op),
name='mask-conv-l%d' % i)) name='mask-conv-l%d' % i))
self._mask_conv_transpose = tf.keras.layers.Conv2DTranspose( self._mask_conv_transpose = tf.keras.layers.Conv2DTranspose(
self._num_filters, self._num_filters,
kernel_size=(2, 2), kernel_size=(2, 2),
strides=(2, 2), strides=(2, 2),
padding='valid', padding='valid',
activation=(None if self._use_batch_norm else tf.nn.relu), activation=(None if self._use_batch_norm else self._activation_op),
kernel_initializer=tf.keras.initializers.VarianceScaling( kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'), scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer(), bias_initializer=tf.zeros_initializer(),
...@@ -353,11 +375,11 @@ class MaskrcnnHead(tf.keras.layers.Layer): ...@@ -353,11 +375,11 @@ class MaskrcnnHead(tf.keras.layers.Layer):
for i in range(self._num_convs): for i in range(self._num_convs):
net = self._conv2d_ops[i](net) net = self._conv2d_ops[i](net)
if self._use_batch_norm: if self._use_batch_norm:
net = self._batch_norm_relu()(net, is_training=is_training) net = self._norm_activation()(net, is_training=is_training)
net = self._mask_conv_transpose(net) net = self._mask_conv_transpose(net)
if self._use_batch_norm: if self._use_batch_norm:
net = self._batch_norm_relu()(net, is_training=is_training) net = self._norm_activation()(net, is_training=is_training)
mask_outputs = self._conv2d_op( mask_outputs = self._conv2d_op(
self._num_classes, self._num_classes,
...@@ -398,7 +420,8 @@ class RetinanetHead(object): ...@@ -398,7 +420,8 @@ class RetinanetHead(object):
num_convs=4, num_convs=4,
num_filters=256, num_filters=256,
use_separable_conv=False, use_separable_conv=False,
batch_norm_relu=nn_ops.BatchNormRelu): norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""Initialize params to build RetinaNet head. """Initialize params to build RetinaNet head.
Args: Args:
...@@ -411,8 +434,8 @@ class RetinanetHead(object): ...@@ -411,8 +434,8 @@ class RetinanetHead(object):
num_filters: `int` number of filters used in the head architecture. num_filters: `int` number of filters used in the head architecture.
use_separable_conv: `bool` to indicate whether to use separable use_separable_conv: `bool` to indicate whether to use separable
convolution. convolution.
batch_norm_relu: an operation that includes a batch normalization layer norm_activation: an operation that includes a normalization layer
followed by a relu layer(optional). followed by an optional activation layer.
""" """
self._min_level = min_level self._min_level = min_level
self._max_level = max_level self._max_level = max_level
...@@ -423,13 +446,12 @@ class RetinanetHead(object): ...@@ -423,13 +446,12 @@ class RetinanetHead(object):
self._num_convs = num_convs self._num_convs = num_convs
self._num_filters = num_filters self._num_filters = num_filters
self._use_separable_conv = use_separable_conv self._use_separable_conv = use_separable_conv
with tf.name_scope('class_net') as scope_name: with tf.name_scope('class_net') as scope_name:
self._class_name_scope = tf.name_scope(scope_name) self._class_name_scope = tf.name_scope(scope_name)
with tf.name_scope('box_net') as scope_name: with tf.name_scope('box_net') as scope_name:
self._box_name_scope = tf.name_scope(scope_name) self._box_name_scope = tf.name_scope(scope_name)
self._build_class_net_layers(batch_norm_relu) self._build_class_net_layers(norm_activation)
self._build_box_net_layers(batch_norm_relu) self._build_box_net_layers(norm_activation)
def _class_net_batch_norm_name(self, i, level): def _class_net_batch_norm_name(self, i, level):
return 'class-%d-%d' % (i, level) return 'class-%d-%d' % (i, level)
...@@ -437,7 +459,7 @@ class RetinanetHead(object): ...@@ -437,7 +459,7 @@ class RetinanetHead(object):
def _box_net_batch_norm_name(self, i, level): def _box_net_batch_norm_name(self, i, level):
return 'box-%d-%d' % (i, level) return 'box-%d-%d' % (i, level)
def _build_class_net_layers(self, batch_norm_relu): def _build_class_net_layers(self, norm_activation):
"""Build re-usable layers for class prediction network.""" """Build re-usable layers for class prediction network."""
if self._use_separable_conv: if self._use_separable_conv:
self._class_predict = tf.keras.layers.SeparableConv2D( self._class_predict = tf.keras.layers.SeparableConv2D(
...@@ -455,7 +477,7 @@ class RetinanetHead(object): ...@@ -455,7 +477,7 @@ class RetinanetHead(object):
padding='same', padding='same',
name='class-predict') name='class-predict')
self._class_conv = [] self._class_conv = []
self._class_batch_norm_relu = {} self._class_norm_activation = {}
for i in range(self._num_convs): for i in range(self._num_convs):
if self._use_separable_conv: if self._use_separable_conv:
self._class_conv.append( self._class_conv.append(
...@@ -479,9 +501,9 @@ class RetinanetHead(object): ...@@ -479,9 +501,9 @@ class RetinanetHead(object):
name='class-' + str(i))) name='class-' + str(i)))
for level in range(self._min_level, self._max_level + 1): for level in range(self._min_level, self._max_level + 1):
name = self._class_net_batch_norm_name(i, level) name = self._class_net_batch_norm_name(i, level)
self._class_batch_norm_relu[name] = batch_norm_relu(name=name) self._class_norm_activation[name] = norm_activation(name=name)
def _build_box_net_layers(self, batch_norm_relu): def _build_box_net_layers(self, norm_activation):
"""Build re-usable layers for box prediction network.""" """Build re-usable layers for box prediction network."""
if self._use_separable_conv: if self._use_separable_conv:
self._box_predict = tf.keras.layers.SeparableConv2D( self._box_predict = tf.keras.layers.SeparableConv2D(
...@@ -499,7 +521,7 @@ class RetinanetHead(object): ...@@ -499,7 +521,7 @@ class RetinanetHead(object):
padding='same', padding='same',
name='box-predict') name='box-predict')
self._box_conv = [] self._box_conv = []
self._box_batch_norm_relu = {} self._box_norm_activation = {}
for i in range(self._num_convs): for i in range(self._num_convs):
if self._use_separable_conv: if self._use_separable_conv:
self._box_conv.append( self._box_conv.append(
...@@ -523,13 +545,13 @@ class RetinanetHead(object): ...@@ -523,13 +545,13 @@ class RetinanetHead(object):
name='box-' + str(i))) name='box-' + str(i)))
for level in range(self._min_level, self._max_level + 1): for level in range(self._min_level, self._max_level + 1):
name = self._box_net_batch_norm_name(i, level) name = self._box_net_batch_norm_name(i, level)
self._box_batch_norm_relu[name] = batch_norm_relu(name=name) self._box_norm_activation[name] = norm_activation(name=name)
def __call__(self, fpn_features, is_training=None): def __call__(self, fpn_features, is_training=None):
"""Returns outputs of RetinaNet head.""" """Returns outputs of RetinaNet head."""
class_outputs = {} class_outputs = {}
box_outputs = {} box_outputs = {}
with backend.get_graph().as_default(), tf.name_scope('retinanet'): with backend.get_graph().as_default(), tf.name_scope('retinanet_head'):
for level in range(self._min_level, self._max_level + 1): for level in range(self._min_level, self._max_level + 1):
features = fpn_features[level] features = fpn_features[level]
...@@ -548,7 +570,7 @@ class RetinanetHead(object): ...@@ -548,7 +570,7 @@ class RetinanetHead(object):
# each level has its own batch normalization to capture the statistical # each level has its own batch normalization to capture the statistical
# difference among different levels. # difference among different levels.
name = self._class_net_batch_norm_name(i, level) name = self._class_net_batch_norm_name(i, level)
features = self._class_batch_norm_relu[name]( features = self._class_norm_activation[name](
features, is_training=is_training) features, is_training=is_training)
classes = self._class_predict(features) classes = self._class_predict(features)
...@@ -563,7 +585,7 @@ class RetinanetHead(object): ...@@ -563,7 +585,7 @@ class RetinanetHead(object):
# each level has its own batch normalization to capture the statistical # each level has its own batch normalization to capture the statistical
# difference among different levels. # difference among different levels.
name = self._box_net_batch_norm_name(i, level) name = self._box_net_batch_norm_name(i, level)
features = self._box_batch_norm_relu[name]( features = self._box_norm_activation[name](
features, is_training=is_training) features, is_training=is_training)
boxes = self._box_predict(features) boxes = self._box_predict(features)
...@@ -953,13 +975,13 @@ class ShapemaskCoarsemaskHead(object): ...@@ -953,13 +975,13 @@ class ShapemaskCoarsemaskHead(object):
def coarsemask_decoder_net(self, def coarsemask_decoder_net(self,
images, images,
is_training=None, is_training=None,
batch_norm_relu=nn_ops.BatchNormRelu): norm_activation=nn_ops.norm_activation_builder()):
"""Coarse mask decoder network architecture. """Coarse mask decoder network architecture.
Args: Args:
images: A tensor of size [batch, height_in, width_in, channels_in]. images: A tensor of size [batch, height_in, width_in, channels_in].
is_training: Whether batch_norm layers are in training mode. is_training: Whether batch_norm layers are in training mode.
batch_norm_relu: an operation that includes a batch normalization layer norm_activation: an operation that includes a batch normalization layer
followed by a relu layer(optional). followed by a relu layer(optional).
Returns: Returns:
images: A feature tensor of size [batch, output_size, output_size, images: A feature tensor of size [batch, output_size, output_size,
...@@ -975,7 +997,7 @@ class ShapemaskCoarsemaskHead(object): ...@@ -975,7 +997,7 @@ class ShapemaskCoarsemaskHead(object):
padding='same', padding='same',
name='coarse-class-%d' % i)( name='coarse-class-%d' % i)(
images) images)
images = batch_norm_relu(name='coarse-class-%d-bn' % i)( images = norm_activation(name='coarse-class-%d-bn' % i)(
images, is_training=is_training) images, is_training=is_training)
return images return images
...@@ -991,7 +1013,7 @@ class ShapemaskFinemaskHead(object): ...@@ -991,7 +1013,7 @@ class ShapemaskFinemaskHead(object):
num_convs, num_convs,
coarse_mask_thr, coarse_mask_thr,
gt_upsample_scale, gt_upsample_scale,
batch_norm_relu=nn_ops.BatchNormRelu): norm_activation=nn_ops.norm_activation_builder()):
"""Initialize params to build ShapeMask coarse and fine prediction head. """Initialize params to build ShapeMask coarse and fine prediction head.
Args: Args:
...@@ -1002,7 +1024,7 @@ class ShapemaskFinemaskHead(object): ...@@ -1002,7 +1024,7 @@ class ShapemaskFinemaskHead(object):
layer. layer.
coarse_mask_thr: the threshold for suppressing noisy coarse prediction. coarse_mask_thr: the threshold for suppressing noisy coarse prediction.
gt_upsample_scale: scale for upsampling groundtruths. gt_upsample_scale: scale for upsampling groundtruths.
batch_norm_relu: an operation that includes a batch normalization layer norm_activation: an operation that includes a batch normalization layer
followed by a relu layer(optional). followed by a relu layer(optional).
""" """
self._mask_num_classes = num_classes self._mask_num_classes = num_classes
...@@ -1038,7 +1060,7 @@ class ShapemaskFinemaskHead(object): ...@@ -1038,7 +1060,7 @@ class ShapemaskFinemaskHead(object):
activation=None, activation=None,
padding='same', padding='same',
name='fine-class-%d' % i)) name='fine-class-%d' % i))
self._fine_class_bn.append(batch_norm_relu(name='fine-class-%d-bn' % i)) self._fine_class_bn.append(norm_activation(name='fine-class-%d-bn' % i))
def __call__(self, prior_conditioned_features, class_probs, is_training=None): def __call__(self, prior_conditioned_features, class_probs, is_training=None):
"""Generate instance masks from FPN features and detection priors. """Generate instance masks from FPN features and detection priors.
......
...@@ -18,20 +18,21 @@ from __future__ import absolute_import ...@@ -18,20 +18,21 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools
from absl import logging from absl import logging
import tensorflow.compat.v2 as tf import tensorflow.compat.v2 as tf
from tensorflow.python.keras import backend
class BatchNormRelu(tf.keras.layers.Layer): class NormActivation(tf.keras.layers.Layer):
"""Combined Batch Normalization and ReLU layers.""" """Combined Normalization and Activation layers."""
def __init__(self, def __init__(self,
momentum=0.997, momentum=0.997,
epsilon=1e-4, epsilon=1e-4,
trainable=True, trainable=True,
relu=True,
init_zero=False, init_zero=False,
use_activation=True,
activation='relu',
fused=True, fused=True,
name=None): name=None):
"""A class to construct layers for a batch normalization followed by a ReLU. """A class to construct layers for a batch normalization followed by a ReLU.
...@@ -39,22 +40,24 @@ class BatchNormRelu(tf.keras.layers.Layer): ...@@ -39,22 +40,24 @@ class BatchNormRelu(tf.keras.layers.Layer):
Args: Args:
momentum: momentum for the moving average. momentum: momentum for the moving average.
epsilon: small float added to variance to avoid dividing by zero. epsilon: small float added to variance to avoid dividing by zero.
trainable: `boolean`, if True also add variables to the graph collection trainable: `bool`, if True also add variables to the graph collection
GraphKeys.TRAINABLE_VARIABLES. If False, freeze batch normalization GraphKeys.TRAINABLE_VARIABLES. If False, freeze batch normalization
layer. layer.
relu: `bool` if False, omits the ReLU operation.
init_zero: `bool` if True, initializes scale parameter of batch init_zero: `bool` if True, initializes scale parameter of batch
normalization with 0. If False, initialize it with 1. normalization with 0. If False, initialize it with 1.
fused: `bool` fused option in batch normalization. fused: `bool` fused option in batch normalization.
use_activation: `bool`, whether to add the optional activation layer after
the batch normalization layer.
activation: 'string', the type of the activation layer. Currently supports
`relu` and `swish`.
name: `str` name for the operation. name: `str` name for the operation.
""" """
super(BatchNormRelu, self).__init__(trainable=trainable) super(NormActivation, self).__init__(trainable=trainable)
self._use_relu = relu
if init_zero: if init_zero:
gamma_initializer = tf.keras.initializers.Zeros() gamma_initializer = tf.keras.initializers.Zeros()
else: else:
gamma_initializer = tf.keras.initializers.Ones() gamma_initializer = tf.keras.initializers.Ones()
self._batch_norm_op = tf.keras.layers.BatchNormalization( self._normalization_op = tf.keras.layers.BatchNormalization(
momentum=momentum, momentum=momentum,
epsilon=epsilon, epsilon=epsilon,
center=True, center=True,
...@@ -63,9 +66,16 @@ class BatchNormRelu(tf.keras.layers.Layer): ...@@ -63,9 +66,16 @@ class BatchNormRelu(tf.keras.layers.Layer):
fused=fused, fused=fused,
gamma_initializer=gamma_initializer, gamma_initializer=gamma_initializer,
name=name) name=name)
self._use_activation = use_activation
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
def __call__(self, inputs, is_training=None): def __call__(self, inputs, is_training=None):
"""Builds layers for a batch normalization followed by a ReLU. """Builds the normalization layer followed by an optional activation layer.
Args: Args:
inputs: `Tensor` of shape `[batch, channels, ...]`. inputs: `Tensor` of shape `[batch, channels, ...]`.
...@@ -78,9 +88,22 @@ class BatchNormRelu(tf.keras.layers.Layer): ...@@ -78,9 +88,22 @@ class BatchNormRelu(tf.keras.layers.Layer):
# from keras.Model.training # from keras.Model.training
if is_training and self.trainable: if is_training and self.trainable:
is_training = True is_training = True
inputs = self._batch_norm_op(inputs, training=is_training) inputs = self._normalization_op(inputs, training=is_training)
if self._use_relu: if self._use_activation:
inputs = tf.nn.relu(inputs) inputs = self._activation_op(inputs)
return inputs return inputs
def norm_activation_builder(momentum=0.997,
epsilon=1e-4,
trainable=True,
activation='relu',
**kwargs):
return functools.partial(
NormActivation,
momentum=momentum,
epsilon=epsilon,
trainable=trainable,
activation=activation,
**kwargs)
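The builder returns a functools.partial over NormActivation: call sites bind the shared hyper-parameters once and then create one op per layer, optionally dropping the activation (as the FPN per-level batch norms and the ResNet shortcut and residual-output paths do). A minimal usage sketch:

import tensorflow.compat.v2 as tf

from official.vision.detection.modeling.architecture import nn_ops

norm_activation = nn_ops.norm_activation_builder(
    momentum=0.997, epsilon=1e-4, trainable=True, activation='swish')

# Per-layer instances: `use_activation=False` gives a plain batch norm,
# `init_zero=True` zero-initializes gamma (used on residual-branch outputs).
p3_bn = norm_activation(use_activation=False, name='p3-bn')
block_bn = norm_activation(init_zero=True, use_activation=False)

features = tf.zeros([2, 8, 8, 16])
outputs = block_bn(features, is_training=True)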
...@@ -34,21 +34,27 @@ class Resnet(object): ...@@ -34,21 +34,27 @@ class Resnet(object):
def __init__(self, def __init__(self,
resnet_depth, resnet_depth,
batch_norm_relu=nn_ops.BatchNormRelu, activation='relu',
norm_activation=nn_ops.norm_activation_builder(
activation='relu'),
data_format='channels_last'): data_format='channels_last'):
"""ResNet initialization function. """ResNet initialization function.
Args: Args:
resnet_depth: `int` depth of ResNet backbone model. resnet_depth: `int` depth of ResNet backbone model.
batch_norm_relu: an operation that includes a batch normalization layer norm_activation: an operation that includes a normalization layer
followed by a relu layer(optional). followed by an optional activation layer.
data_format: `str` either "channels_first" for `[batch, channels, height, data_format: `str` either "channels_first" for `[batch, channels, height,
width]` or "channels_last for `[batch, height, width, channels]`. width]` or "channels_last for `[batch, height, width, channels]`.
""" """
self._resnet_depth = resnet_depth self._resnet_depth = resnet_depth
if activation == 'relu':
self._batch_norm_relu = batch_norm_relu self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._norm_activation = norm_activation
self._data_format = data_format self._data_format = data_format
model_params = { model_params = {
...@@ -170,19 +176,19 @@ class Resnet(object): ...@@ -170,19 +176,19 @@ class Resnet(object):
# Projection shortcut in first layer to match filters and strides # Projection shortcut in first layer to match filters and strides
shortcut = self.conv2d_fixed_padding( shortcut = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=strides) inputs=inputs, filters=filters, kernel_size=1, strides=strides)
shortcut = self._batch_norm_relu(relu=False)( shortcut = self._norm_activation(use_activation=False)(
shortcut, is_training=is_training) shortcut, is_training=is_training)
inputs = self.conv2d_fixed_padding( inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides) inputs=inputs, filters=filters, kernel_size=3, strides=strides)
inputs = self._batch_norm_relu()(inputs, is_training=is_training) inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = self.conv2d_fixed_padding( inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=1) inputs=inputs, filters=filters, kernel_size=3, strides=1)
inputs = self._batch_norm_relu()( inputs = self._norm_activation(use_activation=False, init_zero=True)(
inputs, relu=False, init_zero=True, is_training=is_training) inputs, is_training=is_training)
return tf.nn.relu(inputs + shortcut) return self._activation_op(inputs + shortcut)
def bottleneck_block(self, def bottleneck_block(self,
inputs, inputs,
...@@ -214,24 +220,23 @@ class Resnet(object): ...@@ -214,24 +220,23 @@ class Resnet(object):
filters_out = 4 * filters filters_out = 4 * filters
shortcut = self.conv2d_fixed_padding( shortcut = self.conv2d_fixed_padding(
inputs=inputs, filters=filters_out, kernel_size=1, strides=strides) inputs=inputs, filters=filters_out, kernel_size=1, strides=strides)
shortcut = self._batch_norm_relu(relu=False)( shortcut = self._norm_activation(use_activation=False)(
shortcut, is_training=is_training) shortcut, is_training=is_training)
inputs = self.conv2d_fixed_padding( inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=1) inputs=inputs, filters=filters, kernel_size=1, strides=1)
inputs = self._batch_norm_relu()(inputs, is_training=is_training) inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = self.conv2d_fixed_padding( inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides) inputs=inputs, filters=filters, kernel_size=3, strides=strides)
inputs = self._batch_norm_relu()(inputs, is_training=is_training) inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = self.conv2d_fixed_padding( inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=4 * filters, kernel_size=1, strides=1) inputs=inputs, filters=4 * filters, kernel_size=1, strides=1)
inputs = self._batch_norm_relu( inputs = self._norm_activation(use_activation=False, init_zero=True)(
relu=False, init_zero=True)( inputs, is_training=is_training)
inputs, is_training=is_training)
return tf.nn.relu(inputs + shortcut) return self._activation_op(inputs + shortcut)
def block_group(self, inputs, filters, block_fn, blocks, strides, name, def block_group(self, inputs, filters, block_fn, blocks, strides, name,
is_training): is_training):
...@@ -279,7 +284,7 @@ class Resnet(object): ...@@ -279,7 +284,7 @@ class Resnet(object):
inputs = self.conv2d_fixed_padding( inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=64, kernel_size=7, strides=2) inputs=inputs, filters=64, kernel_size=7, strides=2)
inputs = tf.identity(inputs, 'initial_conv') inputs = tf.identity(inputs, 'initial_conv')
inputs = self._batch_norm_relu()(inputs, is_training=is_training) inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = tf.keras.layers.MaxPool2D( inputs = tf.keras.layers.MaxPool2D(
pool_size=3, strides=2, padding='SAME', pool_size=3, strides=2, padding='SAME',
......
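A hedged construction sketch for the refactored backbone (import paths are assumed; the swish choice simply exercises the new `activation` argument):

from official.vision.detection.modeling.architecture import nn_ops, resnet  # assumed paths

backbone = resnet.Resnet(
    resnet_depth=50,
    activation='swish',  # must agree with the builder below so blocks and shortcuts match
    norm_activation=nn_ops.norm_activation_builder(activation='swish'),
    data_format='channels_last')
# features = backbone(images, is_training=True)  # assuming the class stays callable as before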
...@@ -43,7 +43,8 @@ def _make_filter_trainable_variables_fn(frozen_variable_prefix): ...@@ -43,7 +43,8 @@ def _make_filter_trainable_variables_fn(frozen_variable_prefix):
# the frozen variables' names. # the frozen variables' names.
filtered_variables = [ filtered_variables = [
v for v in variables v for v in variables
if not re.match(frozen_variable_prefix, v.name) if not frozen_variable_prefix or
not re.match(frozen_variable_prefix, v.name)
] ]
return filtered_variables return filtered_variables
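The added `not frozen_variable_prefix or` guard matters because `re.match('', name)` matches every name, so an empty prefix would otherwise freeze everything instead of nothing. A small self-contained sketch of the filter, with a made-up prefix and variable names:

import re

frozen_variable_prefix = r'resnet\d+/conv2d/'  # illustrative pattern, not the real config value
names = [
    'resnet50/conv2d/kernel:0',
    'resnet50/conv2d_4/kernel:0',
    'fpn/post_hoc_d3/kernel:0',
]
trainable = [n for n in names
             if not frozen_variable_prefix or not re.match(frozen_variable_prefix, n)]
print(trainable)  # ['resnet50/conv2d_4/kernel:0', 'fpn/post_hoc_d3/kernel:0']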
...@@ -66,7 +67,7 @@ class Model(object): ...@@ -66,7 +67,7 @@ class Model(object):
# Optimization. # Optimization.
self._optimizer_fn = optimizers.OptimizerFactory(params.train.optimizer) self._optimizer_fn = optimizers.OptimizerFactory(params.train.optimizer)
self._learning_rate = learning_rates.learning_rate_generator( self._learning_rate = learning_rates.learning_rate_generator(
params.train.learning_rate) params.train.total_steps, params.train.learning_rate)
self._frozen_variable_prefix = params.train.frozen_variable_prefix self._frozen_variable_prefix = params.train.frozen_variable_prefix
self._regularization_var_regex = params.train.regularization_variable_regex self._regularization_var_regex = params.train.regularization_variable_regex
......
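A short wiring sketch reflecting the new `learning_rate_generator` signature (module paths are assumed; `params` is the usual ParamsDict built from the training config and assumed to be in scope):

from official.vision.detection.modeling import learning_rates, optimizers  # assumed paths

# params: ParamsDict built from the config, assumed in scope.
learning_rate = learning_rates.learning_rate_generator(
    params.train.total_steps, params.train.learning_rate)
optimizer = optimizers.OptimizerFactory(params.train.optimizer)(
    learning_rate)  # factory assumed to be callable with the schedule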
...@@ -28,9 +28,10 @@ from official.modeling.hyperparams import params_dict ...@@ -28,9 +28,10 @@ from official.modeling.hyperparams import params_dict
class StepLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule): class StepLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Class to generate learning rate tensor.""" """Class to generate learning rate tensor."""
def __init__(self, params): def __init__(self, total_steps, params):
"""Creates the step learning rate tensor with linear warmup.""" """Creates the step learning rate tensor with linear warmup."""
super(StepLearningRateWithLinearWarmup, self).__init__() super(StepLearningRateWithLinearWarmup, self).__init__()
self._total_steps = total_steps
assert isinstance(params, (dict, params_dict.ParamsDict)) assert isinstance(params, (dict, params_dict.ParamsDict))
if isinstance(params, dict): if isinstance(params, dict):
params = params_dict.ParamsDict(params) params = params_dict.ParamsDict(params)
...@@ -59,9 +60,10 @@ class StepLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRat ...@@ -59,9 +60,10 @@ class StepLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRat
class CosineLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule): class CosineLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Class to generate learning rate tensor.""" """Class to generate learning rate tensor."""
def __init__(self, params): def __init__(self, total_steps, params):
"""Creates the consine learning rate tensor with linear warmup.""" """Creates the consine learning rate tensor with linear warmup."""
super(CosineLearningRateWithLinearWarmup, self).__init__() super(CosineLearningRateWithLinearWarmup, self).__init__()
self._total_steps = total_steps
assert isinstance(params, (dict, params_dict.ParamsDict)) assert isinstance(params, (dict, params_dict.ParamsDict))
if isinstance(params, dict): if isinstance(params, dict):
params = params_dict.ParamsDict(params) params = params_dict.ParamsDict(params)
...@@ -72,7 +74,7 @@ class CosineLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningR ...@@ -72,7 +74,7 @@ class CosineLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningR
warmup_lr = self._params.warmup_learning_rate warmup_lr = self._params.warmup_learning_rate
warmup_steps = self._params.warmup_steps warmup_steps = self._params.warmup_steps
init_lr = self._params.init_learning_rate init_lr = self._params.init_learning_rate
total_steps = self._params.total_steps total_steps = self._total_steps
linear_warmup = ( linear_warmup = (
warmup_lr + global_step / warmup_steps * (init_lr - warmup_lr)) warmup_lr + global_step / warmup_steps * (init_lr - warmup_lr))
cosine_learning_rate = ( cosine_learning_rate = (
...@@ -86,11 +88,11 @@ class CosineLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningR ...@@ -86,11 +88,11 @@ class CosineLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningR
return {'_params': self._params.as_dict()} return {'_params': self._params.as_dict()}
def learning_rate_generator(params): def learning_rate_generator(total_steps, params):
"""The learning rate function generator.""" """The learning rate function generator."""
if params.type == 'step': if params.type == 'step':
return StepLearningRateWithLinearWarmup(params) return StepLearningRateWithLinearWarmup(total_steps, params)
elif params.type == 'cosine': elif params.type == 'cosine':
return CosineLearningRateWithLinearWarmup(params) return CosineLearningRateWithLinearWarmup(total_steps, params)
else: else:
raise ValueError('Unsupported learning rate type: {}.'.format(params.type)) raise ValueError('Unsupported learning rate type: {}.'.format(params.type))
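For intuition, a self-contained sketch of the cosine-with-linear-warmup shape these classes produce, with `total_steps` now injected rather than read from the params (the in-file arithmetic may differ in minor details):

import math

def cosine_lr_with_warmup(step, init_lr, warmup_lr, warmup_steps, total_steps):
  """Linear warmup to init_lr, then cosine decay toward zero at total_steps."""
  if step < warmup_steps:
    return warmup_lr + step / warmup_steps * (init_lr - warmup_lr)
  progress = (step - warmup_steps) / (total_steps - warmup_steps)
  return init_lr * 0.5 * (1.0 + math.cos(math.pi * progress))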
...@@ -371,8 +371,8 @@ class MaskrcnnLoss(object): ...@@ -371,8 +371,8 @@ class MaskrcnnLoss(object):
class RetinanetClassLoss(object): class RetinanetClassLoss(object):
"""RetinaNet class loss.""" """RetinaNet class loss."""
def __init__(self, params): def __init__(self, params, num_classes):
self._num_classes = params.num_classes self._num_classes = num_classes
self._focal_loss_alpha = params.focal_loss_alpha self._focal_loss_alpha = params.focal_loss_alpha
self._focal_loss_gamma = params.focal_loss_gamma self._focal_loss_gamma = params.focal_loss_gamma
......
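The alpha/gamma fields feed the standard sigmoid focal loss; a minimal standalone sketch for intuition (not the exact implementation in losses.py, and the default values here are only illustrative):

import tensorflow as tf

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=1.5):
  """-alpha_t * (1 - p_t)**gamma * log(p_t), computed per element."""
  ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)
  p = tf.sigmoid(logits)
  p_t = targets * p + (1.0 - targets) * (1.0 - p)
  alpha_t = targets * alpha + (1.0 - targets) * (1.0 - alpha)
  return alpha_t * tf.pow(1.0 - p_t, gamma) * ce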
...@@ -49,14 +49,16 @@ class MaskrcnnModel(base_model.Model): ...@@ -49,14 +49,16 @@ class MaskrcnnModel(base_model.Model):
# Architecture generators. # Architecture generators.
self._backbone_fn = factory.backbone_generator(params) self._backbone_fn = factory.backbone_generator(params)
self._fpn_fn = factory.multilevel_features_generator(params) self._fpn_fn = factory.multilevel_features_generator(params)
self._rpn_head_fn = factory.rpn_head_generator(params.rpn_head) self._rpn_head_fn = factory.rpn_head_generator(params)
self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal) self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
self._sample_rois_fn = sampling_ops.ROISampler(params.roi_sampling) self._sample_rois_fn = sampling_ops.ROISampler(params.roi_sampling)
self._sample_masks_fn = sampling_ops.MaskSampler(params.mask_sampling) self._sample_masks_fn = sampling_ops.MaskSampler(
params.architecture.mask_target_size,
params.mask_sampling.num_mask_samples_per_image)
self._frcnn_head_fn = factory.fast_rcnn_head_generator(params.frcnn_head) self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
if self._include_mask: if self._include_mask:
self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params.mrcnn_head) self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)
# Loss function. # Loss function.
self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss) self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
...@@ -91,8 +93,8 @@ class MaskrcnnModel(base_model.Model): ...@@ -91,8 +93,8 @@ class MaskrcnnModel(base_model.Model):
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
rpn_box_outputs), rpn_box_outputs),
}) })
input_anchor = anchor.Anchor(self._params.anchor.min_level, input_anchor = anchor.Anchor(self._params.architecture.min_level,
self._params.anchor.max_level, self._params.architecture.max_level,
self._params.anchor.num_scales, self._params.anchor.num_scales,
self._params.anchor.aspect_ratios, self._params.anchor.aspect_ratios,
self._params.anchor.anchor_size, self._params.anchor.anchor_size,
......
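After this change the level settings come from params.architecture rather than per-head sub-configs; a hedged sketch of the anchor construction (the truncated trailing argument is represented by the hypothetical `image_size` stand-in):

input_anchor = anchor.Anchor(
    params.architecture.min_level,  # was params.anchor.min_level
    params.architecture.max_level,  # was params.anchor.max_level
    params.anchor.num_scales,
    params.anchor.aspect_ratios,
    params.anchor.anchor_size,
    image_size)  # stand-in for whatever the truncated call above passes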
...@@ -30,15 +30,10 @@ class OptimizerFactory(object): ...@@ -30,15 +30,10 @@ class OptimizerFactory(object):
def __init__(self, params): def __init__(self, params):
"""Creates optimized based on the specified flags.""" """Creates optimized based on the specified flags."""
if params.type == 'momentum': if params.type == 'momentum':
nesterov = False
try:
nesterov = params.nesterov
except AttributeError:
pass
self._optimizer = functools.partial( self._optimizer = functools.partial(
tf.keras.optimizers.SGD, tf.keras.optimizers.SGD,
momentum=params.momentum, momentum=params.momentum,
nesterov=nesterov) nesterov=params.nesterov)
elif params.type == 'adam': elif params.type == 'adam':
self._optimizer = tf.keras.optimizers.Adam self._optimizer = tf.keras.optimizers.Adam
elif params.type == 'adadelta': elif params.type == 'adadelta':
......
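With `nesterov` now read directly from the optimizer config, the defensive try/except is gone; a minimal self-contained sketch of the resulting factory behavior (a plain namespace stands in for the ParamsDict):

import functools
import types
import tensorflow as tf

params = types.SimpleNamespace(type='momentum', momentum=0.9, nesterov=True)
if params.type == 'momentum':
  optimizer_builder = functools.partial(
      tf.keras.optimizers.SGD, momentum=params.momentum, nesterov=params.nesterov)
optimizer = optimizer_builder(learning_rate=0.08)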
...@@ -44,16 +44,19 @@ class RetinanetModel(base_model.Model): ...@@ -44,16 +44,19 @@ class RetinanetModel(base_model.Model):
# Architecture generators. # Architecture generators.
self._backbone_fn = factory.backbone_generator(params) self._backbone_fn = factory.backbone_generator(params)
self._fpn_fn = factory.multilevel_features_generator(params) self._fpn_fn = factory.multilevel_features_generator(params)
self._head_fn = factory.retinanet_head_generator(params.retinanet_head) self._head_fn = factory.retinanet_head_generator(params)
# Loss function. # Loss function.
self._cls_loss_fn = losses.RetinanetClassLoss(params.retinanet_loss) self._cls_loss_fn = losses.RetinanetClassLoss(
params.retinanet_loss, params.architecture.num_classes)
self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss) self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
self._box_loss_weight = params.retinanet_loss.box_loss_weight self._box_loss_weight = params.retinanet_loss.box_loss_weight
self._keras_model = None self._keras_model = None
# Predict function. # Predict function.
self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator( self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
params.architecture.min_level,
params.architecture.max_level,
params.postprocess) params.postprocess)
self._transpose_input = params.train.transpose_input self._transpose_input = params.train.transpose_input
......
...@@ -294,10 +294,10 @@ def _generate_detections_batched(boxes, ...@@ -294,10 +294,10 @@ def _generate_detections_batched(boxes,
class MultilevelDetectionGenerator(object): class MultilevelDetectionGenerator(object):
"""Generates detected boxes with scores and classes for one-stage detector.""" """Generates detected boxes with scores and classes for one-stage detector."""
def __init__(self, params): def __init__(self, min_level, max_level, params):
self._min_level = min_level
self._max_level = max_level
self._generate_detections = generate_detections_factory(params) self._generate_detections = generate_detections_factory(params)
self._min_level = params.min_level
self._max_level = params.max_level
def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape): def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
# Collects outputs from all levels into a list. # Collects outputs from all levels into a list.
......
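A hedged instantiation sketch for the updated constructor (the level values are common FPN defaults used only for illustration, and `params` is assumed in scope):

generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
    min_level=3, max_level=7, params=params.postprocess)
# detections = generate_detections_fn(box_outputs, class_outputs, anchor_boxes, image_shape)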
...@@ -346,9 +346,9 @@ class ROISampler(object): ...@@ -346,9 +346,9 @@ class ROISampler(object):
class MaskSampler(object): class MaskSampler(object):
"""Samples and creates mask training targets.""" """Samples and creates mask training targets."""
def __init__(self, params): def __init__(self, mask_target_size, num_mask_samples_per_image):
self._num_mask_samples_per_image = params.num_mask_samples_per_image self._mask_target_size = mask_target_size
self._mask_target_size = params.mask_target_size self._num_mask_samples_per_image = num_mask_samples_per_image
def __call__(self, def __call__(self,
candidate_rois, candidate_rois,
......
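And the matching sketch for the decoupled MaskSampler, which now takes plain values instead of a ParamsDict (the numbers are common Mask R-CNN defaults, used here only for illustration):

sample_masks_fn = sampling_ops.MaskSampler(
    mask_target_size=28, num_mask_samples_per_image=128)

Taking explicit arguments keeps the sampler usable in isolation (e.g. in unit tests) without constructing a full config.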