"sgl-router/vscode:/vscode.git/clone" did not exist on "08b8c0c3cd636e1e73bba5e667a6aa90df804546"
Unverified Commit fe748d4a authored by pkulzc's avatar pkulzc Committed by GitHub
Browse files

Object detection changes: (#7208)

257914648  by lzc:

    Internal changes

--
257525973  by Zhichao Lu:

    Fixes bug that silently prevents checkpoints from loading when training w/ eager + functions. Also sets up scripts to run training.

--
257296614  by Zhichao Lu:

    Adding detection_features to model outputs

--
257234565  by Zhichao Lu:

    Fix wrong order of `classes_with_max_scores` in class-agnostic NMS caused by
    sorting in partitioned-NMS.

--
257232002  by ronnyvotel:

    Supporting `filter_nonoverlapping` option in np_box_list_ops.clip_to_window().

--
257198282  by Zhichao Lu:

    Adding the focal loss and l1 loss from the Objects as Points paper.

--
257089535  by Zhichao Lu:

    Create Keras based ssd + resnetv1 + fpn.

--
257087407  by Zhichao Lu:

    Make object_detection/data_decoders Python3-compatible.

--
257004582  by Zhichao Lu:

    Updates _decode_raw_data_into_masks_and_boxes to the latest binary masks-to-string encoding fo...
parent 81123ebf
......@@ -21,6 +21,7 @@ Based on PNASNet model: https://arxiv.org/abs/1712.00559
import tensorflow as tf
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.utils import variables_helper
from nets.nasnet import nasnet_utils
from nets.nasnet import pnasnet
......@@ -302,7 +303,7 @@ class FasterRCNNPNASFeatureExtractor(
the model graph.
"""
variables_to_restore = {}
for variable in tf.global_variables():
for variable in variables_helper.get_global_variables_safely():
if variable.op.name.startswith(
first_stage_feature_extractor_scope):
var_name = variable.op.name.replace(
......
......@@ -44,7 +44,8 @@ class FasterRCNNResnetV1FeatureExtractor(
first_stage_features_stride,
batch_norm_trainable=False,
reuse_weights=None,
weight_decay=0.0):
weight_decay=0.0,
activation_fn=tf.nn.relu):
"""Constructor.
Args:
......@@ -55,6 +56,7 @@ class FasterRCNNResnetV1FeatureExtractor(
batch_norm_trainable: See base class.
reuse_weights: See base class.
weight_decay: See base class.
activation_fn: Activaton functon to use in Resnet V1 model.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16.
......@@ -63,9 +65,10 @@ class FasterRCNNResnetV1FeatureExtractor(
raise ValueError('`first_stage_features_stride` must be 8 or 16.')
self._architecture = architecture
self._resnet_model = resnet_model
super(FasterRCNNResnetV1FeatureExtractor, self).__init__(
is_training, first_stage_features_stride, batch_norm_trainable,
reuse_weights, weight_decay)
self._activation_fn = activation_fn
super(FasterRCNNResnetV1FeatureExtractor,
self).__init__(is_training, first_stage_features_stride,
batch_norm_trainable, reuse_weights, weight_decay)
def preprocess(self, resized_inputs):
"""Faster R-CNN Resnet V1 preprocessing.
......@@ -125,6 +128,7 @@ class FasterRCNNResnetV1FeatureExtractor(
resnet_utils.resnet_arg_scope(
batch_norm_epsilon=1e-5,
batch_norm_scale=True,
activation_fn=self._activation_fn,
weight_decay=self._weight_decay)):
with tf.variable_scope(
self._architecture, reuse=self._reuse_weights) as var_scope:
......@@ -159,6 +163,7 @@ class FasterRCNNResnetV1FeatureExtractor(
resnet_utils.resnet_arg_scope(
batch_norm_epsilon=1e-5,
batch_norm_scale=True,
activation_fn=self._activation_fn,
weight_decay=self._weight_decay)):
with slim.arg_scope([slim.batch_norm],
is_training=self._train_batch_norm):
......@@ -182,7 +187,8 @@ class FasterRCNNResnet50FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
first_stage_features_stride,
batch_norm_trainable=False,
reuse_weights=None,
weight_decay=0.0):
weight_decay=0.0,
activation_fn=tf.nn.relu):
"""Constructor.
Args:
......@@ -191,15 +197,16 @@ class FasterRCNNResnet50FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
batch_norm_trainable: See base class.
reuse_weights: See base class.
weight_decay: See base class.
activation_fn: See base class.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16,
or if `architecture` is not supported.
"""
super(FasterRCNNResnet50FeatureExtractor, self).__init__(
'resnet_v1_50', resnet_v1.resnet_v1_50, is_training,
super(FasterRCNNResnet50FeatureExtractor,
self).__init__('resnet_v1_50', resnet_v1.resnet_v1_50, is_training,
first_stage_features_stride, batch_norm_trainable,
reuse_weights, weight_decay)
reuse_weights, weight_decay, activation_fn)
class FasterRCNNResnet101FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
......@@ -210,7 +217,8 @@ class FasterRCNNResnet101FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
first_stage_features_stride,
batch_norm_trainable=False,
reuse_weights=None,
weight_decay=0.0):
weight_decay=0.0,
activation_fn=tf.nn.relu):
"""Constructor.
Args:
......@@ -219,15 +227,16 @@ class FasterRCNNResnet101FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
batch_norm_trainable: See base class.
reuse_weights: See base class.
weight_decay: See base class.
activation_fn: See base class.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16,
or if `architecture` is not supported.
"""
super(FasterRCNNResnet101FeatureExtractor, self).__init__(
'resnet_v1_101', resnet_v1.resnet_v1_101, is_training,
super(FasterRCNNResnet101FeatureExtractor,
self).__init__('resnet_v1_101', resnet_v1.resnet_v1_101, is_training,
first_stage_features_stride, batch_norm_trainable,
reuse_weights, weight_decay)
reuse_weights, weight_decay, activation_fn)
class FasterRCNNResnet152FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
......@@ -238,7 +247,8 @@ class FasterRCNNResnet152FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
first_stage_features_stride,
batch_norm_trainable=False,
reuse_weights=None,
weight_decay=0.0):
weight_decay=0.0,
activation_fn=tf.nn.relu):
"""Constructor.
Args:
......@@ -247,12 +257,13 @@ class FasterRCNNResnet152FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
batch_norm_trainable: See base class.
reuse_weights: See base class.
weight_decay: See base class.
activation_fn: See base class.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16,
or if `architecture` is not supported.
"""
super(FasterRCNNResnet152FeatureExtractor, self).__init__(
'resnet_v1_152', resnet_v1.resnet_v1_152, is_training,
super(FasterRCNNResnet152FeatureExtractor,
self).__init__('resnet_v1_152', resnet_v1.resnet_v1_152, is_training,
first_stage_features_stride, batch_norm_trainable,
reuse_weights, weight_decay)
reuse_weights, weight_decay, activation_fn)
......@@ -79,14 +79,19 @@ def create_conv_block(
"""
layers = []
if use_depthwise:
layers.append(tf.keras.layers.SeparableConv2D(
depth,
[kernel_size, kernel_size],
kwargs = conv_hyperparams.params()
# Both the regularizer and initializer apply to the depthwise layer,
# so we remap the kernel_* to depthwise_* here.
kwargs['depthwise_regularizer'] = kwargs['kernel_regularizer']
kwargs['depthwise_initializer'] = kwargs['kernel_initializer']
layers.append(
tf.keras.layers.SeparableConv2D(
depth, [kernel_size, kernel_size],
depth_multiplier=1,
padding=padding,
strides=stride,
name=layer_name + '_depthwise_conv',
**conv_hyperparams.params()))
**kwargs))
else:
layers.append(tf.keras.layers.Conv2D(
depth,
......
......@@ -160,7 +160,12 @@ class _LayersOverride(object):
"""
if self._conv_hyperparams:
kwargs = self._conv_hyperparams.params(**kwargs)
# Both the regularizer and initializer apply to the depthwise layer in
# MobilenetV1, so we remap the kernel_* to depthwise_* here.
kwargs['depthwise_regularizer'] = kwargs['kernel_regularizer']
kwargs['depthwise_initializer'] = kwargs['kernel_initializer']
else:
kwargs['depthwise_regularizer'] = self.regularizer
kwargs['depthwise_initializer'] = self.initializer
kwargs['padding'] = 'same'
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment