Commit 3b13794f authored by Pengchong Jin's avatar Pengchong Jin Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 306768885
parent 795a3f7d
......@@ -14,8 +14,16 @@
# ==============================================================================
"""Base config template."""
# pylint: disable=line-too-long
BACKBONES = [
'resnet',
]
MULTILEVEL_FEATURES = [
'fpn',
]
# pylint: disable=line-too-long
# For ResNet, this freezes the variables of the first conv1 and conv2_x
# layers [1], which leads to higher training speed and slightly better testing
# accuracy. The intuition is that the low-level architecture (e.g., ResNet-50)
......@@ -24,7 +32,6 @@
# Note that we need to trailing `/` to avoid the incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
BASE_CFG = {
......@@ -41,6 +48,7 @@ BASE_CFG = {
'optimizer': {
'type': 'momentum',
'momentum': 0.9,
'nesterov': True, # `False` is better for TPU v3-128.
},
'learning_rate': {
'type': 'step',
......@@ -49,21 +57,25 @@ BASE_CFG = {
'init_learning_rate': 0.08,
'learning_rate_levels': [0.008, 0.0008],
'learning_rate_steps': [15000, 20000],
'total_steps': 22500,
},
'checkpoint': {
'path': '',
'prefix': '',
},
'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
# One can use 'RESNET_FROZEN_VAR_PREFIX' to speed up ResNet training
# when loading from the checkpoint.
'frozen_variable_prefix': '',
'train_file_pattern': '',
'train_dataset_type': 'tfrecord',
# TODO(b/142174042): Support transpose_input option.
'transpose_input': False,
'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
'l2_weight_decay': 0.0001,
'gradient_clip_norm': 0.0,
'input_sharding': False,
},
'eval': {
'input_sharding': True,
'batch_size': 8,
'eval_samples': 5000,
'min_eval_interval': 180,
......@@ -74,38 +86,42 @@ BASE_CFG = {
'val_json_file': '',
'eval_file_pattern': '',
'eval_dataset_type': 'tfrecord',
# When visualizing images, set evaluation batch size to 40 to avoid
# potential OOM.
'num_images_to_visualize': 0,
},
'predict': {
'batch_size': 8,
},
'anchor': {
'architecture': {
'backbone': 'resnet',
'min_level': 3,
'max_level': 7,
'multilevel_features': 'fpn',
'use_bfloat16': True,
# Note that `num_classes` is the total number of classes including
# one background classes whose index is 0.
'num_classes': 91,
},
'anchor': {
'num_scales': 3,
'aspect_ratios': [1.0, 2.0, 0.5],
'anchor_size': 4.0,
},
'resnet': {
'resnet_depth': 50,
'batch_norm': {
'norm_activation': {
'activation': 'relu',
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
'resnet': {
'resnet_depth': 50,
},
'fpn': {
'min_level': 3,
'max_level': 7,
'fpn_feat_dims': 256,
'use_separable_conv': False,
'use_batch_norm': True,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
},
'postprocess': {
'use_batched_nms': False,
......@@ -116,5 +132,4 @@ BASE_CFG = {
},
'enable_summary': False,
}
# pylint: enable=line-too-long
......@@ -28,13 +28,12 @@ MASKRCNN_CFG.override({
},
'architecture': {
'parser': 'maskrcnn_parser',
'backbone': 'resnet',
'multilevel_features': 'fpn',
'use_bfloat16': True,
'min_level': 2,
'max_level': 6,
'include_mask': True,
'mask_target_size': 28,
},
'maskrcnn_parser': {
'use_bfloat16': True,
'output_size': [1024, 1024],
'num_channels': 3,
'rpn_match_threshold': 0.7,
......@@ -46,74 +45,32 @@ MASKRCNN_CFG.override({
'aug_scale_max': 1.0,
'skip_crowd_during_training': True,
'max_num_instances': 100,
'include_mask': True,
'mask_crop_size': 112,
},
'anchor': {
'min_level': 2,
'max_level': 6,
'num_scales': 1,
'anchor_size': 8,
},
'fpn': {
'min_level': 2,
'max_level': 6,
},
'nasfpn': {
'min_level': 2,
'max_level': 6,
},
# tunable_nasfpn:strip_begin
'tunable_nasfpn_v1': {
'min_level': 2,
'max_level': 6,
},
# tunable_nasfpn:strip_end
'rpn_head': {
'min_level': 2,
'max_level': 6,
'anchors_per_location': 3,
'num_convs': 2,
'num_filters': 256,
'use_separable_conv': False,
'use_batch_norm': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
},
'frcnn_head': {
# Note that `num_classes` is the total number of classes including
# one background classes whose index is 0.
'num_classes': 91,
'num_convs': 0,
'num_filters': 256,
'use_separable_conv': False,
'num_fcs': 2,
'fc_dims': 1024,
'use_batch_norm': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
},
'mrcnn_head': {
'num_classes': 91,
'mask_target_size': 28,
'num_convs': 4,
'num_filters': 256,
'use_separable_conv': False,
'use_batch_norm': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
},
'rpn_score_loss': {
'rpn_batch_size_per_im': 256,
......@@ -147,23 +104,10 @@ MASKRCNN_CFG.override({
},
'mask_sampling': {
'num_mask_samples_per_image': 128, # Typically = `num_samples_per_image` * `fg_fraction`.
'mask_target_size': 28,
},
'postprocess': {
'use_batched_nms': False,
'max_total_size': 100,
'nms_iou_threshold': 0.5,
'score_threshold': 0.05,
'pre_nms_num_boxes': 1000,
},
}, is_strict=False)
MASKRCNN_RESTRICTIONS = [
'architecture.use_bfloat16 == maskrcnn_parser.use_bfloat16',
'architecture.include_mask == maskrcnn_parser.include_mask',
'anchor.min_level == rpn_head.min_level',
'anchor.max_level == rpn_head.max_level',
'mrcnn_head.mask_target_size == mask_sampling.mask_target_size',
]
# pylint: enable=line-too-long
......@@ -14,84 +14,18 @@
# ==============================================================================
"""Config template to train Retinanet."""
# pylint: disable=line-too-long
from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import base_config
# For ResNet-50, this freezes the variables of the first conv1 and conv2_x
# layers [1], which leads to higher training speed and slightly better testing
# accuracy. The intuition is that the low-level architecture (e.g., ResNet-50)
# is able to capture low-level features such as edges; therefore, it does not
# need to be fine-tuned for the detection task.
# Note that we need to trailing `/` to avoid the incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
# pylint: disable=line-too-long
RETINANET_CFG = {
RETINANET_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
RETINANET_CFG.override({
'type': 'retinanet',
'model_dir': '',
'use_tpu': True,
'strategy_type': 'tpu',
'train': {
'batch_size': 64,
'iterations_per_loop': 500,
'total_steps': 22500,
'optimizer': {
'type': 'momentum',
'momentum': 0.9,
'nesterov': True, # `False` is better for TPU v3-128.
},
'learning_rate': {
'type': 'step',
'warmup_learning_rate': 0.0067,
'warmup_steps': 500,
'init_learning_rate': 0.08,
'learning_rate_levels': [0.008, 0.0008],
'learning_rate_steps': [15000, 20000],
},
'checkpoint': {
'path': '',
'prefix': '',
},
'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
'train_file_pattern': '',
# TODO(b/142174042): Support transpose_input option.
'transpose_input': False,
'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
'l2_weight_decay': 0.0001,
'input_sharding': False,
},
'eval': {
'batch_size': 8,
'min_eval_interval': 180,
'eval_timeout': None,
'eval_samples': 5000,
'type': 'box',
'val_json_file': '',
'eval_file_pattern': '',
'input_sharding': True,
# When visualizing images, set evaluation batch size to 40 to avoid
# potential OOM.
'num_images_to_visualize': 0,
},
'predict': {
'predict_batch_size': 8,
},
'architecture': {
'parser': 'retinanet_parser',
'backbone': 'resnet',
'multilevel_features': 'fpn',
'use_bfloat16': False,
},
'anchor': {
'min_level': 3,
'max_level': 7,
'num_scales': 3,
'aspect_ratios': [1.0, 2.0, 0.5],
'anchor_size': 4.0,
},
'retinanet_parser': {
'use_bfloat16': False,
'output_size': [640, 640],
'num_channels': 3,
'match_threshold': 0.5,
......@@ -104,68 +38,22 @@ RETINANET_CFG = {
'skip_crowd_during_training': True,
'max_num_instances': 100,
},
'resnet': {
'resnet_depth': 50,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
},
},
'fpn': {
'min_level': 3,
'max_level': 7,
'fpn_feat_dims': 256,
'use_separable_conv': False,
'use_batch_norm': True,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
},
},
'retinanet_head': {
'min_level': 3,
'max_level': 7,
# Note that `num_classes` is the total number of classes including
# one background classes whose index is 0.
'num_classes': 91,
'anchors_per_location': 9,
'retinanet_head_num_convs': 4,
'retinanet_head_num_filters': 256,
'num_convs': 4,
'num_filters': 256,
'use_separable_conv': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
},
},
'retinanet_loss': {
'num_classes': 91,
'focal_loss_alpha': 0.25,
'focal_loss_gamma': 1.5,
'huber_loss_delta': 0.1,
'box_loss_weight': 50,
},
'postprocess': {
'use_batched_nms': False,
'min_level': 3,
'max_level': 7,
'max_total_size': 100,
'nms_iou_threshold': 0.5,
'score_threshold': 0.05,
'pre_nms_num_boxes': 5000,
},
'enable_summary': False,
}
'enable_summary': True,
}, is_strict=False)
RETINANET_RESTRICTIONS = [
'architecture.use_bfloat16 == retinanet_parser.use_bfloat16',
'anchor.min_level == retinanet_head.min_level',
'anchor.max_level == retinanet_head.max_level',
'anchor.min_level == postprocess.min_level',
'anchor.max_level == postprocess.max_level',
'retinanet_head.num_classes == retinanet_loss.num_classes',
]
# pylint: enable=line-too-long
......@@ -29,8 +29,8 @@ def parser_generator(params, mode):
parser_params = params.retinanet_parser
parser_fn = retinanet_parser.Parser(
output_size=parser_params.output_size,
min_level=anchor_params.min_level,
max_level=anchor_params.max_level,
min_level=params.architecture.min_level,
max_level=params.architecture.max_level,
num_scales=anchor_params.num_scales,
aspect_ratios=anchor_params.aspect_ratios,
anchor_size=anchor_params.anchor_size,
......@@ -43,15 +43,15 @@ def parser_generator(params, mode):
autoaugment_policy_name=parser_params.autoaugment_policy_name,
skip_crowd_during_training=parser_params.skip_crowd_during_training,
max_num_instances=parser_params.max_num_instances,
use_bfloat16=parser_params.use_bfloat16,
use_bfloat16=params.architecture.use_bfloat16,
mode=mode)
elif params.architecture.parser == 'maskrcnn_parser':
anchor_params = params.anchor
parser_params = params.maskrcnn_parser
parser_fn = maskrcnn_parser.Parser(
output_size=parser_params.output_size,
min_level=anchor_params.min_level,
max_level=anchor_params.max_level,
min_level=params.architecture.min_level,
max_level=params.architecture.max_level,
num_scales=anchor_params.num_scales,
aspect_ratios=anchor_params.aspect_ratios,
anchor_size=anchor_params.anchor_size,
......@@ -64,17 +64,17 @@ def parser_generator(params, mode):
aug_scale_max=parser_params.aug_scale_max,
skip_crowd_during_training=parser_params.skip_crowd_during_training,
max_num_instances=parser_params.max_num_instances,
include_mask=parser_params.include_mask,
include_mask=params.architecture.include_mask,
mask_crop_size=parser_params.mask_crop_size,
use_bfloat16=parser_params.use_bfloat16,
use_bfloat16=params.architecture.use_bfloat16,
mode=mode)
elif params.architecture.parser == 'shapemask_parser':
anchor_params = params.anchor
parser_params = params.shapemask_parser
parser_fn = shapemask_parser.Parser(
output_size=parser_params.output_size,
min_level=anchor_params.min_level,
max_level=anchor_params.max_level,
min_level=params.architecture.min_level,
max_level=params.architecture.max_level,
num_scales=anchor_params.num_scales,
aspect_ratios=anchor_params.aspect_ratios,
anchor_size=anchor_params.anchor_size,
......@@ -93,7 +93,7 @@ def parser_generator(params, mode):
aug_scale_max=parser_params.aug_scale_max,
skip_crowd_during_training=parser_params.skip_crowd_during_training,
max_num_instances=parser_params.max_num_instances,
use_bfloat16=parser_params.use_bfloat16,
use_bfloat16=params.architecture.use_bfloat16,
mask_train_class=parser_params.mask_train_class,
mode=mode)
else:
......
......@@ -25,16 +25,12 @@ from official.vision.detection.modeling.architecture import nn_ops
from official.vision.detection.modeling.architecture import resnet
def batch_norm_relu_generator(params):
def _batch_norm_op(**kwargs):
return nn_ops.BatchNormRelu(
def norm_activation_generator(params):
return nn_ops.norm_activation_builder(
momentum=params.batch_norm_momentum,
epsilon=params.batch_norm_epsilon,
trainable=params.batch_norm_trainable,
**kwargs)
return _batch_norm_op
activation=params.activation)
def backbone_generator(params):
......@@ -43,10 +39,12 @@ def backbone_generator(params):
resnet_params = params.resnet
backbone_fn = resnet.Resnet(
resnet_depth=resnet_params.resnet_depth,
batch_norm_relu=batch_norm_relu_generator(resnet_params.batch_norm))
activation=params.norm_activation.activation,
norm_activation=norm_activation_generator(
params.norm_activation))
else:
raise ValueError('Backbone model %s is not supported.' %
params.architecture.backbone)
raise ValueError('Backbone model `{}` is not supported.'
.format(params.architecture.backbone))
return backbone_fn
......@@ -56,81 +54,75 @@ def multilevel_features_generator(params):
if params.architecture.multilevel_features == 'fpn':
fpn_params = params.fpn
fpn_fn = fpn.Fpn(
min_level=fpn_params.min_level,
max_level=fpn_params.max_level,
min_level=params.architecture.min_level,
max_level=params.architecture.max_level,
fpn_feat_dims=fpn_params.fpn_feat_dims,
use_separable_conv=fpn_params.use_separable_conv,
activation=params.norm_activation.activation,
use_batch_norm=fpn_params.use_batch_norm,
batch_norm_relu=batch_norm_relu_generator(fpn_params.batch_norm))
norm_activation=norm_activation_generator(
params.norm_activation))
elif params.architecture.multilevel_features == 'identity':
fpn_fn = identity.Identity()
else:
raise ValueError('The multi-level feature model %s is not supported.'
% params.architecture.multilevel_features)
raise ValueError('The multi-level feature model `{}` is not supported.'
.format(params.architecture.multilevel_features))
return fpn_fn
def retinanet_head_generator(params):
"""Generator function for RetinaNet head architecture."""
head_params = params.retinanet_head
return heads.RetinanetHead(
params.min_level,
params.max_level,
params.num_classes,
params.anchors_per_location,
params.retinanet_head_num_convs,
params.retinanet_head_num_filters,
params.use_separable_conv,
batch_norm_relu=batch_norm_relu_generator(params.batch_norm))
params.architecture.min_level,
params.architecture.max_level,
params.architecture.num_classes,
head_params.anchors_per_location,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
norm_activation=norm_activation_generator(params.norm_activation))
def rpn_head_generator(params):
head_params = params.rpn_head
"""Generator function for RPN head architecture."""
return heads.RpnHead(params.min_level,
params.max_level,
params.anchors_per_location,
params.num_convs,
params.num_filters,
params.use_separable_conv,
params.use_batch_norm,
batch_norm_relu=batch_norm_relu_generator(
params.batch_norm))
return heads.RpnHead(
params.architecture.min_level,
params.architecture.max_level,
head_params.anchors_per_location,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def fast_rcnn_head_generator(params):
"""Generator function for Fast R-CNN head architecture."""
return heads.FastrcnnHead(params.num_classes,
params.num_convs,
params.num_filters,
params.use_separable_conv,
params.num_fcs,
params.fc_dims,
params.use_batch_norm,
batch_norm_relu=batch_norm_relu_generator(
params.batch_norm))
head_params = params.frcnn_head
return heads.FastrcnnHead(
params.architecture.num_classes,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
head_params.num_fcs,
head_params.fc_dims,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def mask_rcnn_head_generator(params):
"""Generator function for Mask R-CNN head architecture."""
return heads.MaskrcnnHead(params.num_classes,
params.mask_target_size,
params.num_convs,
params.num_filters,
params.use_separable_conv,
params.use_batch_norm,
batch_norm_relu=batch_norm_relu_generator(
params.batch_norm))
def shapeprior_head_generator(params):
"""Generator function for Shapemask head architecture."""
raise NotImplementedError('Unimplemented')
def coarsemask_head_generator(params):
"""Generator function for Shapemask head architecture."""
raise NotImplementedError('Unimplemented')
def finemask_head_generator(params):
"""Generator function for Shapemask head architecture."""
raise NotImplementedError('Unimplemented')
head_params = params.mrcnn_head
return heads.MaskrcnnHead(
params.architecture.num_classes,
params.architecture.mask_target_size,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
......@@ -41,8 +41,10 @@ class Fpn(object):
max_level=7,
fpn_feat_dims=256,
use_separable_conv=False,
activation='relu',
use_batch_norm=True,
batch_norm_relu=nn_ops.BatchNormRelu):
norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""FPN initialization function.
Args:
......@@ -52,8 +54,8 @@ class Fpn(object):
use_separable_conv: `bool`, if True use separable convolution for
convolution in FPN layers.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
batch_norm_relu: an operation that includes a batch normalization layer
followed by a relu layer(optional).
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
"""
self._min_level = min_level
self._max_level = max_level
......@@ -63,17 +65,23 @@ class Fpn(object):
tf.keras.layers.SeparableConv2D, depth_multiplier=1)
else:
self._conv2d_op = tf.keras.layers.Conv2D
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
self._batch_norm_relu = batch_norm_relu
self._norm_activation = norm_activation
self._batch_norm_relus = {}
self._norm_activations = {}
self._lateral_conv2d_op = {}
self._post_hoc_conv2d_op = {}
self._coarse_conv2d_op = {}
for level in range(self._min_level, self._max_level + 1):
if self._use_batch_norm:
self._batch_norm_relus[level] = batch_norm_relu(
relu=False, name='p%d-bn' % level)
self._norm_activations[level] = norm_activation(
use_activation=False, name='p%d-bn' % level)
self._lateral_conv2d_op[level] = self._conv2d_op(
filters=self._fpn_feat_dims,
kernel_size=(1, 1),
......@@ -133,11 +141,11 @@ class Fpn(object):
for level in range(backbone_max_level + 1, self._max_level + 1):
feats_in = feats[level - 1]
if level > backbone_max_level + 1:
feats_in = tf.nn.relu(feats_in)
feats_in = self._activation_op(feats_in)
feats[level] = self._coarse_conv2d_op[level](feats_in)
if self._use_batch_norm:
# Adds batch_norm layer.
for level in range(self._min_level, self._max_level + 1):
feats[level] = self._batch_norm_relus[level](
feats[level] = self._norm_activations[level](
feats[level], is_training=is_training)
return feats
......@@ -39,8 +39,10 @@ class RpnHead(tf.keras.layers.Layer):
num_convs=2,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_batch_norm=True,
batch_norm_relu=nn_ops.BatchNormRelu):
norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""Initialize params to build Region Proposal Network head.
Args:
......@@ -55,12 +57,18 @@ class RpnHead(tf.keras.layers.Layer):
use_separable_conv: `bool`, indicating whether the separable conv layers
is used.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
batch_norm_relu: an operation that includes a batch normalization layer
followed by a relu layer(optional).
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
"""
self._min_level = min_level
self._max_level = max_level
self._anchors_per_location = anchors_per_location
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
if use_separable_conv:
......@@ -78,7 +86,7 @@ class RpnHead(tf.keras.layers.Layer):
num_filters,
kernel_size=(3, 3),
strides=(1, 1),
activation=(None if self._use_batch_norm else tf.nn.relu),
activation=(None if self._use_batch_norm else self._activation_op),
padding='same',
name='rpn')
self._rpn_class_conv = self._conv2d_op(
......@@ -94,10 +102,10 @@ class RpnHead(tf.keras.layers.Layer):
padding='valid',
name='rpn-box')
self._batch_norm_relus = {}
self._norm_activations = {}
if self._use_batch_norm:
for level in range(self._min_level, self._max_level + 1):
self._batch_norm_relus[level] = batch_norm_relu(name='rpn-l%d-bn' %
self._norm_activations[level] = norm_activation(name='rpn-l%d-bn' %
level)
def _shared_rpn_heads(self, features, anchors_per_location, level,
......@@ -106,7 +114,7 @@ class RpnHead(tf.keras.layers.Layer):
features = self._rpn_conv(features)
if self._use_batch_norm:
# The batch normalization layers are not shared between levels.
features = self._batch_norm_relus[level](
features = self._norm_activations[level](
features, is_training=is_training)
# Proposal classification scores
scores = self._rpn_class_conv(features)
......@@ -139,8 +147,10 @@ class FastrcnnHead(tf.keras.layers.Layer):
use_separable_conv=False,
num_fcs=2,
fc_dims=1024,
activation='relu',
use_batch_norm=True,
batch_norm_relu=nn_ops.BatchNormRelu):
norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""Initialize params to build Fast R-CNN box head.
Args:
......@@ -156,8 +166,8 @@ class FastrcnnHead(tf.keras.layers.Layer):
fc_dims: `int` number that represents the number of dimension of the FC
layers.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
batch_norm_relu: an operation that includes a batch normalization layer
followed by a relu layer(optional).
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
"""
self._num_classes = num_classes
......@@ -177,9 +187,14 @@ class FastrcnnHead(tf.keras.layers.Layer):
self._num_fcs = num_fcs
self._fc_dims = fc_dims
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
self._batch_norm_relu = batch_norm_relu
self._norm_activation = norm_activation
self._conv_ops = []
self._conv_bn_ops = []
......@@ -191,10 +206,10 @@ class FastrcnnHead(tf.keras.layers.Layer):
strides=(1, 1),
padding='same',
dilation_rate=(1, 1),
activation=(None if self._use_batch_norm else tf.nn.relu),
activation=(None if self._use_batch_norm else self._activation_op),
name='conv_{}'.format(i)))
if self._use_batch_norm:
self._conv_bn_ops.append(self._batch_norm_relu())
self._conv_bn_ops.append(self._norm_activation())
self._fc_ops = []
self._fc_bn_ops = []
......@@ -202,10 +217,10 @@ class FastrcnnHead(tf.keras.layers.Layer):
self._fc_ops.append(
tf.keras.layers.Dense(
units=self._fc_dims,
activation=(None if self._use_batch_norm else tf.nn.relu),
activation=(None if self._use_batch_norm else self._activation_op),
name='fc{}'.format(i)))
if self._use_batch_norm:
self._fc_bn_ops.append(self._batch_norm_relu(fused=False))
self._fc_bn_ops.append(self._norm_activation(fused=False))
self._class_predict = tf.keras.layers.Dense(
self._num_classes,
......@@ -266,8 +281,10 @@ class MaskrcnnHead(tf.keras.layers.Layer):
num_convs=4,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_batch_norm=True,
batch_norm_relu=nn_ops.BatchNormRelu):
norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""Initialize params to build Fast R-CNN head.
Args:
......@@ -280,8 +297,8 @@ class MaskrcnnHead(tf.keras.layers.Layer):
use_separable_conv: `bool`, indicating whether the separable conv layers
is used.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
batch_norm_relu: an operation that includes a batch normalization layer
followed by a relu layer(optional).
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
"""
self._num_classes = num_classes
self._mask_target_size = mask_target_size
......@@ -299,9 +316,14 @@ class MaskrcnnHead(tf.keras.layers.Layer):
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer())
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
self._batch_norm_relu = batch_norm_relu
self._norm_activation = norm_activation
self._conv2d_ops = []
for i in range(self._num_convs):
self._conv2d_ops.append(
......@@ -311,14 +333,14 @@ class MaskrcnnHead(tf.keras.layers.Layer):
strides=(1, 1),
padding='same',
dilation_rate=(1, 1),
activation=(None if self._use_batch_norm else tf.nn.relu),
activation=(None if self._use_batch_norm else self._activation_op),
name='mask-conv-l%d' % i))
self._mask_conv_transpose = tf.keras.layers.Conv2DTranspose(
self._num_filters,
kernel_size=(2, 2),
strides=(2, 2),
padding='valid',
activation=(None if self._use_batch_norm else tf.nn.relu),
activation=(None if self._use_batch_norm else self._activation_op),
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer(),
......@@ -353,11 +375,11 @@ class MaskrcnnHead(tf.keras.layers.Layer):
for i in range(self._num_convs):
net = self._conv2d_ops[i](net)
if self._use_batch_norm:
net = self._batch_norm_relu()(net, is_training=is_training)
net = self._norm_activation()(net, is_training=is_training)
net = self._mask_conv_transpose(net)
if self._use_batch_norm:
net = self._batch_norm_relu()(net, is_training=is_training)
net = self._norm_activation()(net, is_training=is_training)
mask_outputs = self._conv2d_op(
self._num_classes,
......@@ -398,7 +420,8 @@ class RetinanetHead(object):
num_convs=4,
num_filters=256,
use_separable_conv=False,
batch_norm_relu=nn_ops.BatchNormRelu):
norm_activation=nn_ops.norm_activation_builder(
activation='relu')):
"""Initialize params to build RetinaNet head.
Args:
......@@ -411,8 +434,8 @@ class RetinanetHead(object):
num_filters: `int` number of filters used in the head architecture.
use_separable_conv: `bool` to indicate whether to use separable
convoluation.
batch_norm_relu: an operation that includes a batch normalization layer
followed by a relu layer(optional).
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
"""
self._min_level = min_level
self._max_level = max_level
......@@ -423,13 +446,12 @@ class RetinanetHead(object):
self._num_convs = num_convs
self._num_filters = num_filters
self._use_separable_conv = use_separable_conv
with tf.name_scope('class_net') as scope_name:
self._class_name_scope = tf.name_scope(scope_name)
with tf.name_scope('box_net') as scope_name:
self._box_name_scope = tf.name_scope(scope_name)
self._build_class_net_layers(batch_norm_relu)
self._build_box_net_layers(batch_norm_relu)
self._build_class_net_layers(norm_activation)
self._build_box_net_layers(norm_activation)
def _class_net_batch_norm_name(self, i, level):
return 'class-%d-%d' % (i, level)
......@@ -437,7 +459,7 @@ class RetinanetHead(object):
def _box_net_batch_norm_name(self, i, level):
return 'box-%d-%d' % (i, level)
def _build_class_net_layers(self, batch_norm_relu):
def _build_class_net_layers(self, norm_activation):
"""Build re-usable layers for class prediction network."""
if self._use_separable_conv:
self._class_predict = tf.keras.layers.SeparableConv2D(
......@@ -455,7 +477,7 @@ class RetinanetHead(object):
padding='same',
name='class-predict')
self._class_conv = []
self._class_batch_norm_relu = {}
self._class_norm_activation = {}
for i in range(self._num_convs):
if self._use_separable_conv:
self._class_conv.append(
......@@ -479,9 +501,9 @@ class RetinanetHead(object):
name='class-' + str(i)))
for level in range(self._min_level, self._max_level + 1):
name = self._class_net_batch_norm_name(i, level)
self._class_batch_norm_relu[name] = batch_norm_relu(name=name)
self._class_norm_activation[name] = norm_activation(name=name)
def _build_box_net_layers(self, batch_norm_relu):
def _build_box_net_layers(self, norm_activation):
"""Build re-usable layers for box prediction network."""
if self._use_separable_conv:
self._box_predict = tf.keras.layers.SeparableConv2D(
......@@ -499,7 +521,7 @@ class RetinanetHead(object):
padding='same',
name='box-predict')
self._box_conv = []
self._box_batch_norm_relu = {}
self._box_norm_activation = {}
for i in range(self._num_convs):
if self._use_separable_conv:
self._box_conv.append(
......@@ -523,13 +545,13 @@ class RetinanetHead(object):
name='box-' + str(i)))
for level in range(self._min_level, self._max_level + 1):
name = self._box_net_batch_norm_name(i, level)
self._box_batch_norm_relu[name] = batch_norm_relu(name=name)
self._box_norm_activation[name] = norm_activation(name=name)
def __call__(self, fpn_features, is_training=None):
"""Returns outputs of RetinaNet head."""
class_outputs = {}
box_outputs = {}
with backend.get_graph().as_default(), tf.name_scope('retinanet'):
with backend.get_graph().as_default(), tf.name_scope('retinanet_head'):
for level in range(self._min_level, self._max_level + 1):
features = fpn_features[level]
......@@ -548,7 +570,7 @@ class RetinanetHead(object):
# each level has its batch normlization to capture the statistical
# difference among different levels.
name = self._class_net_batch_norm_name(i, level)
features = self._class_batch_norm_relu[name](
features = self._class_norm_activation[name](
features, is_training=is_training)
classes = self._class_predict(features)
......@@ -563,7 +585,7 @@ class RetinanetHead(object):
# each level has its batch normlization to capture the statistical
# difference among different levels.
name = self._box_net_batch_norm_name(i, level)
features = self._box_batch_norm_relu[name](
features = self._box_norm_activation[name](
features, is_training=is_training)
boxes = self._box_predict(features)
......@@ -953,13 +975,13 @@ class ShapemaskCoarsemaskHead(object):
def coarsemask_decoder_net(self,
images,
is_training=None,
batch_norm_relu=nn_ops.BatchNormRelu):
norm_activation=nn_ops.norm_activation_builder()):
"""Coarse mask decoder network architecture.
Args:
images: A tensor of size [batch, height_in, width_in, channels_in].
is_training: Whether batch_norm layers are in training mode.
batch_norm_relu: an operation that includes a batch normalization layer
norm_activation: an operation that includes a batch normalization layer
followed by a relu layer(optional).
Returns:
images: A feature tensor of size [batch, output_size, output_size,
......@@ -975,7 +997,7 @@ class ShapemaskCoarsemaskHead(object):
padding='same',
name='coarse-class-%d' % i)(
images)
images = batch_norm_relu(name='coarse-class-%d-bn' % i)(
images = norm_activation(name='coarse-class-%d-bn' % i)(
images, is_training=is_training)
return images
......@@ -991,7 +1013,7 @@ class ShapemaskFinemaskHead(object):
num_convs,
coarse_mask_thr,
gt_upsample_scale,
batch_norm_relu=nn_ops.BatchNormRelu):
norm_activation=nn_ops.norm_activation_builder()):
"""Initialize params to build ShapeMask coarse and fine prediction head.
Args:
......@@ -1002,7 +1024,7 @@ class ShapemaskFinemaskHead(object):
layer.
coarse_mask_thr: the threshold for suppressing noisy coarse prediction.
gt_upsample_scale: scale for upsampling groundtruths.
batch_norm_relu: an operation that includes a batch normalization layer
norm_activation: an operation that includes a batch normalization layer
followed by a relu layer(optional).
"""
self._mask_num_classes = num_classes
......@@ -1038,7 +1060,7 @@ class ShapemaskFinemaskHead(object):
activation=None,
padding='same',
name='fine-class-%d' % i))
self._fine_class_bn.append(batch_norm_relu(name='fine-class-%d-bn' % i))
self._fine_class_bn.append(norm_activation(name='fine-class-%d-bn' % i))
def __call__(self, prior_conditioned_features, class_probs, is_training=None):
"""Generate instance masks from FPN features and detection priors.
......
......@@ -18,20 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl import logging
import tensorflow.compat.v2 as tf
from tensorflow.python.keras import backend
class BatchNormRelu(tf.keras.layers.Layer):
"""Combined Batch Normalization and ReLU layers."""
class NormActivation(tf.keras.layers.Layer):
"""Combined Normalization and Activation layers."""
def __init__(self,
momentum=0.997,
epsilon=1e-4,
trainable=True,
relu=True,
init_zero=False,
use_activation=True,
activation='relu',
fused=True,
name=None):
"""A class to construct layers for a batch normalization followed by a ReLU.
......@@ -39,22 +40,24 @@ class BatchNormRelu(tf.keras.layers.Layer):
Args:
momentum: momentum for the moving average.
epsilon: small float added to variance to avoid dividing by zero.
trainable: `boolean`, if True also add variables to the graph collection
trainable: `bool`, if True also add variables to the graph collection
GraphKeys.TRAINABLE_VARIABLES. If False, freeze batch normalization
layer.
relu: `bool` if False, omits the ReLU operation.
init_zero: `bool` if True, initializes scale parameter of batch
normalization with 0. If False, initialize it with 1.
fused: `bool` fused option in batch normalziation.
use_actiation: `bool`, whether to add the optional activation layer after
the batch normalization layer.
activation: 'string', the type of the activation layer. Currently support
`relu` and `swish`.
name: `str` name for the operation.
"""
super(BatchNormRelu, self).__init__(trainable=trainable)
self._use_relu = relu
super(NormActivation, self).__init__(trainable=trainable)
if init_zero:
gamma_initializer = tf.keras.initializers.Zeros()
else:
gamma_initializer = tf.keras.initializers.Ones()
self._batch_norm_op = tf.keras.layers.BatchNormalization(
self._normalization_op = tf.keras.layers.BatchNormalization(
momentum=momentum,
epsilon=epsilon,
center=True,
......@@ -63,9 +66,16 @@ class BatchNormRelu(tf.keras.layers.Layer):
fused=fused,
gamma_initializer=gamma_initializer,
name=name)
self._use_activation = use_activation
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
def __call__(self, inputs, is_training=None):
"""Builds layers for a batch normalization followed by a ReLU.
"""Builds the normalization layer followed by an optional activation layer.
Args:
inputs: `Tensor` of shape `[batch, channels, ...]`.
......@@ -78,9 +88,22 @@ class BatchNormRelu(tf.keras.layers.Layer):
# from keras.Model.training
if is_training and self.trainable:
is_training = True
inputs = self._batch_norm_op(inputs, training=is_training)
inputs = self._normalization_op(inputs, training=is_training)
if self._use_relu:
inputs = tf.nn.relu(inputs)
if self._use_activation:
inputs = self._activation_op(inputs)
return inputs
def norm_activation_builder(momentum=0.997,
                            epsilon=1e-4,
                            trainable=True,
                            activation='relu',
                            **kwargs):
  """Builds a constructor for `NormActivation` layers with shared defaults.

  Args:
    momentum: momentum for the batch normalization moving average.
    epsilon: small float added to variance to avoid dividing by zero.
    trainable: `bool`, if False the normalization variables are frozen.
    activation: `str`, the activation applied after normalization. Currently
      supports `relu` and `swish`.
    **kwargs: additional keyword arguments forwarded to `NormActivation`
      (e.g. `init_zero`, `use_activation`, `fused`).

  Returns:
    A callable that, when invoked (optionally with overrides such as `name`),
    constructs a `NormActivation` layer.
  """
  return functools.partial(
      NormActivation,
      momentum=momentum,
      epsilon=epsilon,
      trainable=trainable,
      # Bug fix: forward the caller's `activation` instead of always
      # hard-coding 'relu', which silently discarded e.g. activation='swish'.
      activation=activation,
      **kwargs)
......@@ -34,21 +34,27 @@ class Resnet(object):
def __init__(self,
resnet_depth,
batch_norm_relu=nn_ops.BatchNormRelu,
activation='relu',
norm_activation=nn_ops.norm_activation_builder(
activation='relu'),
data_format='channels_last'):
"""ResNet initialization function.
Args:
resnet_depth: `int` depth of ResNet backbone model.
batch_norm_relu: an operation that includes a batch normalization layer
followed by a relu layer(optional).
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
data_format: `str` either "channels_first" for `[batch, channels, height,
width]` or "channels_last for `[batch, height, width, channels]`.
"""
self._resnet_depth = resnet_depth
self._batch_norm_relu = batch_norm_relu
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._norm_activation = norm_activation
self._data_format = data_format
model_params = {
......@@ -170,19 +176,19 @@ class Resnet(object):
# Projection shortcut in first layer to match filters and strides
shortcut = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=strides)
shortcut = self._batch_norm_relu(relu=False)(
shortcut = self._norm_activation(use_activation=False)(
shortcut, is_training=is_training)
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides)
inputs = self._batch_norm_relu()(inputs, is_training=is_training)
inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=1)
inputs = self._batch_norm_relu()(
inputs, relu=False, init_zero=True, is_training=is_training)
inputs = self._norm_activation(use_activation=False, init_zero=True)(
inputs, is_training=is_training)
return tf.nn.relu(inputs + shortcut)
return self._activation_op(inputs + shortcut)
def bottleneck_block(self,
inputs,
......@@ -214,24 +220,23 @@ class Resnet(object):
filters_out = 4 * filters
shortcut = self.conv2d_fixed_padding(
inputs=inputs, filters=filters_out, kernel_size=1, strides=strides)
shortcut = self._batch_norm_relu(relu=False)(
shortcut = self._norm_activation(use_activation=False)(
shortcut, is_training=is_training)
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=1)
inputs = self._batch_norm_relu()(inputs, is_training=is_training)
inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides)
inputs = self._batch_norm_relu()(inputs, is_training=is_training)
inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=4 * filters, kernel_size=1, strides=1)
inputs = self._batch_norm_relu(
relu=False, init_zero=True)(
inputs = self._norm_activation(use_activation=False, init_zero=True)(
inputs, is_training=is_training)
return tf.nn.relu(inputs + shortcut)
return self._activation_op(inputs + shortcut)
def block_group(self, inputs, filters, block_fn, blocks, strides, name,
is_training):
......@@ -279,7 +284,7 @@ class Resnet(object):
inputs = self.conv2d_fixed_padding(
inputs=inputs, filters=64, kernel_size=7, strides=2)
inputs = tf.identity(inputs, 'initial_conv')
inputs = self._batch_norm_relu()(inputs, is_training=is_training)
inputs = self._norm_activation()(inputs, is_training=is_training)
inputs = tf.keras.layers.MaxPool2D(
pool_size=3, strides=2, padding='SAME',
......
......@@ -43,7 +43,8 @@ def _make_filter_trainable_variables_fn(frozen_variable_prefix):
# the frozen variables' names.
filtered_variables = [
v for v in variables
if not re.match(frozen_variable_prefix, v.name)
if not frozen_variable_prefix or
not re.match(frozen_variable_prefix, v.name)
]
return filtered_variables
......@@ -66,7 +67,7 @@ class Model(object):
# Optimization.
self._optimizer_fn = optimizers.OptimizerFactory(params.train.optimizer)
self._learning_rate = learning_rates.learning_rate_generator(
params.train.learning_rate)
params.train.total_steps, params.train.learning_rate)
self._frozen_variable_prefix = params.train.frozen_variable_prefix
self._regularization_var_regex = params.train.regularization_variable_regex
......
......@@ -28,9 +28,10 @@ from official.modeling.hyperparams import params_dict
class StepLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Class to generate learning rate tensor."""
def __init__(self, params):
def __init__(self, total_steps, params):
"""Creates the step learning rate tensor with linear warmup."""
super(StepLearningRateWithLinearWarmup, self).__init__()
self._total_steps = total_steps
assert isinstance(params, (dict, params_dict.ParamsDict))
if isinstance(params, dict):
params = params_dict.ParamsDict(params)
......@@ -59,9 +60,10 @@ class StepLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRat
class CosineLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Class to generate learning rate tensor."""
def __init__(self, params):
def __init__(self, total_steps, params):
"""Creates the consine learning rate tensor with linear warmup."""
super(CosineLearningRateWithLinearWarmup, self).__init__()
self._total_steps = total_steps
assert isinstance(params, (dict, params_dict.ParamsDict))
if isinstance(params, dict):
params = params_dict.ParamsDict(params)
......@@ -72,7 +74,7 @@ class CosineLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningR
warmup_lr = self._params.warmup_learning_rate
warmup_steps = self._params.warmup_steps
init_lr = self._params.init_learning_rate
total_steps = self._params.total_steps
total_steps = self._total_steps
linear_warmup = (
warmup_lr + global_step / warmup_steps * (init_lr - warmup_lr))
cosine_learning_rate = (
......@@ -86,11 +88,11 @@ class CosineLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningR
return {'_params': self._params.as_dict()}
def learning_rate_generator(total_steps, params):
  """The learning rate function generator.

  Args:
    total_steps: `int`, the total number of training steps; consumed by
      schedules (e.g. cosine decay) that need the training horizon.
    params: a ParamsDict-like object with a `type` field ('step' or 'cosine')
      plus the schedule-specific hyperparameters.

  Returns:
    A learning rate schedule instance for the requested type.

  Raises:
    ValueError: if `params.type` is neither 'step' nor 'cosine'.
  """
  if params.type == 'step':
    return StepLearningRateWithLinearWarmup(total_steps, params)
  elif params.type == 'cosine':
    return CosineLearningRateWithLinearWarmup(total_steps, params)
  else:
    raise ValueError('Unsupported learning rate type: {}.'.format(params.type))
......@@ -371,8 +371,8 @@ class MaskrcnnLoss(object):
class RetinanetClassLoss(object):
"""RetinaNet class loss."""
def __init__(self, params):
self._num_classes = params.num_classes
def __init__(self, params, num_classes):
self._num_classes = num_classes
self._focal_loss_alpha = params.focal_loss_alpha
self._focal_loss_gamma = params.focal_loss_gamma
......
......@@ -49,14 +49,16 @@ class MaskrcnnModel(base_model.Model):
# Architecture generators.
self._backbone_fn = factory.backbone_generator(params)
self._fpn_fn = factory.multilevel_features_generator(params)
self._rpn_head_fn = factory.rpn_head_generator(params.rpn_head)
self._rpn_head_fn = factory.rpn_head_generator(params)
self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
self._sample_rois_fn = sampling_ops.ROISampler(params.roi_sampling)
self._sample_masks_fn = sampling_ops.MaskSampler(params.mask_sampling)
self._sample_masks_fn = sampling_ops.MaskSampler(
params.architecture.mask_target_size,
params.mask_sampling.num_mask_samples_per_image)
self._frcnn_head_fn = factory.fast_rcnn_head_generator(params.frcnn_head)
self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
if self._include_mask:
self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params.mrcnn_head)
self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)
# Loss function.
self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
......@@ -91,8 +93,8 @@ class MaskrcnnModel(base_model.Model):
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
rpn_box_outputs),
})
input_anchor = anchor.Anchor(self._params.anchor.min_level,
self._params.anchor.max_level,
input_anchor = anchor.Anchor(self._params.architecture.min_level,
self._params.architecture.max_level,
self._params.anchor.num_scales,
self._params.anchor.aspect_ratios,
self._params.anchor.anchor_size,
......
......@@ -30,15 +30,10 @@ class OptimizerFactory(object):
def __init__(self, params):
"""Creates optimized based on the specified flags."""
if params.type == 'momentum':
nesterov = False
try:
nesterov = params.nesterov
except AttributeError:
pass
self._optimizer = functools.partial(
tf.keras.optimizers.SGD,
momentum=params.momentum,
nesterov=nesterov)
nesterov=params.nesterov)
elif params.type == 'adam':
self._optimizer = tf.keras.optimizers.Adam
elif params.type == 'adadelta':
......
......@@ -44,16 +44,19 @@ class RetinanetModel(base_model.Model):
# Architecture generators.
self._backbone_fn = factory.backbone_generator(params)
self._fpn_fn = factory.multilevel_features_generator(params)
self._head_fn = factory.retinanet_head_generator(params.retinanet_head)
self._head_fn = factory.retinanet_head_generator(params)
# Loss function.
self._cls_loss_fn = losses.RetinanetClassLoss(params.retinanet_loss)
self._cls_loss_fn = losses.RetinanetClassLoss(
params.retinanet_loss, params.architecture.num_classes)
self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
self._box_loss_weight = params.retinanet_loss.box_loss_weight
self._keras_model = None
# Predict function.
self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
params.architecture.min_level,
params.architecture.max_level,
params.postprocess)
self._transpose_input = params.train.transpose_input
......
......@@ -294,10 +294,10 @@ def _generate_detections_batched(boxes,
class MultilevelDetectionGenerator(object):
"""Generates detected boxes with scores and classes for one-stage detector."""
def __init__(self, min_level, max_level, params):
  """Initializes the multilevel detection generator.

  Args:
    min_level: `int`, the minimum feature pyramid level.
    max_level: `int`, the maximum feature pyramid level; levels are passed
      explicitly (rather than read off `params`) so they come from the
      architecture config.
    params: post-processing parameters forwarded to the detection factory.
  """
  self._min_level = min_level
  self._max_level = max_level
  self._generate_detections = generate_detections_factory(params)
def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
# Collects outputs from all levels into a list.
......
......@@ -346,9 +346,9 @@ class ROISampler(object):
class MaskSampler(object):
"""Samples and creates mask training targets."""
def __init__(self, params):
self._num_mask_samples_per_image = params.num_mask_samples_per_image
self._mask_target_size = params.mask_target_size
def __init__(self, mask_target_size, num_mask_samples_per_image):
self._mask_target_size = mask_target_size
self._num_mask_samples_per_image = num_mask_samples_per_image
def __call__(self,
candidate_rois,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment