Commit 3956d90e authored by Zhichao Lu, committed by pkulzc

Use mobilenet_v1_arg_scope to pass the weight_decay and batch norm params.

PiperOrigin-RevId: 190682119
parent d724a08b
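The change swaps the file-local _batch_norm_arg_scope helper for the arg scope shipped with the slim MobileNet V1 implementation, so weight decay and the batch-norm parameters are configured through a single library call instead of a hand-rolled scope. A minimal sketch of the new pattern, assuming the slim nets package is importable and using illustrative values for is_training and weight_decay (the extractor itself passes self._train_batch_norm and self._weight_decay):

import tensorflow as tf
from nets import mobilenet_v1

slim = tf.contrib.slim

# Illustrative values only; the feature extractor supplies its own
# self._train_batch_norm and self._weight_decay.
with slim.arg_scope(
    mobilenet_v1.mobilenet_v1_arg_scope(is_training=False,
                                        weight_decay=0.00004)):
  images = tf.random_uniform([1, 224, 224, 3])
  _, endpoints = mobilenet_v1.mobilenet_v1_base(
      images, final_endpoint='Conv2d_11_pointwise')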
@@ -22,30 +22,6 @@ from nets import mobilenet_v1
 slim = tf.contrib.slim
 
 
-def _batch_norm_arg_scope(list_ops,
-                          use_batch_norm=True,
-                          batch_norm_decay=0.9997,
-                          batch_norm_epsilon=0.001,
-                          batch_norm_scale=False,
-                          train_batch_norm=False):
-  """Slim arg scope for Mobilenet V1 batch norm."""
-  if use_batch_norm:
-    batch_norm_params = {
-        'is_training': train_batch_norm,
-        'scale': batch_norm_scale,
-        'decay': batch_norm_decay,
-        'epsilon': batch_norm_epsilon
-    }
-    normalizer_fn = slim.batch_norm
-  else:
-    normalizer_fn = None
-    batch_norm_params = None
-
-  return slim.arg_scope(list_ops,
-                        normalizer_fn=normalizer_fn,
-                        normalizer_params=batch_norm_params)
-
-
 class FasterRCNNMobilenetV1FeatureExtractor(
     faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
   """Faster R-CNN Mobilenet V1 feature extractor implementation."""
@@ -121,18 +97,19 @@ class FasterRCNNMobilenetV1FeatureExtractor(
         ['image size must at least be 33 in both height and width.'])
 
     with tf.control_dependencies([shape_assert]):
-      with tf.variable_scope('MobilenetV1',
-                             reuse=self._reuse_weights) as scope:
-        with _batch_norm_arg_scope([slim.conv2d, slim.separable_conv2d],
-                                   batch_norm_scale=True,
-                                   train_batch_norm=self._train_batch_norm):
+      with slim.arg_scope(
+          mobilenet_v1.mobilenet_v1_arg_scope(
+              is_training=self._train_batch_norm,
+              weight_decay=self._weight_decay)):
+        with tf.variable_scope('MobilenetV1',
+                               reuse=self._reuse_weights) as scope:
           _, activations = mobilenet_v1.mobilenet_v1_base(
               preprocessed_inputs,
-              final_endpoint='Conv2d_13_pointwise',
+              final_endpoint='Conv2d_11_pointwise',
               min_depth=self._min_depth,
               depth_multiplier=self._depth_multiplier,
               scope=scope)
-    return activations['Conv2d_13_pointwise'], activations
+    return activations['Conv2d_11_pointwise'], activations
 
   def _extract_box_classifier_features(self, proposal_feature_maps, scope):
     """Extracts second stage box classifier features.
@@ -152,9 +129,10 @@ class FasterRCNNMobilenetV1FeatureExtractor(
     depth = lambda d: max(int(d * 1.0), 16)
     with tf.variable_scope('MobilenetV1', reuse=self._reuse_weights):
-      with _batch_norm_arg_scope([slim.conv2d, slim.separable_conv2d],
-                                 batch_norm_scale=True,
-                                 train_batch_norm=self._train_batch_norm):
+      with slim.arg_scope(
+          mobilenet_v1.mobilenet_v1_arg_scope(
+              is_training=self._train_batch_norm,
+              weight_decay=self._weight_decay)):
         with slim.arg_scope(
             [slim.conv2d, slim.separable_conv2d], padding='SAME'):
           net = slim.separable_conv2d(
@@ -44,7 +44,7 @@ class FasterRcnnMobilenetV1FeatureExtractorTest(tf.test.TestCase):
     with self.test_session() as sess:
       sess.run(init_op)
       features_shape_out = sess.run(features_shape)
-      self.assertAllEqual(features_shape_out, [4, 7, 7, 1024])
+      self.assertAllEqual(features_shape_out, [4, 14, 14, 512])
 
   def test_extract_proposal_features_stride_eight(self):
     feature_extractor = self._build_feature_extractor(
@@ -59,7 +59,7 @@ class FasterRcnnMobilenetV1FeatureExtractorTest(tf.test.TestCase):
     with self.test_session() as sess:
       sess.run(init_op)
       features_shape_out = sess.run(features_shape)
-      self.assertAllEqual(features_shape_out, [4, 7, 7, 1024])
+      self.assertAllEqual(features_shape_out, [4, 14, 14, 512])
 
   def test_extract_proposal_features_half_size_input(self):
     feature_extractor = self._build_feature_extractor(
@@ -74,7 +74,7 @@ class FasterRcnnMobilenetV1FeatureExtractorTest(tf.test.TestCase):
     with self.test_session() as sess:
       sess.run(init_op)
       features_shape_out = sess.run(features_shape)
-      self.assertAllEqual(features_shape_out, [1, 4, 4, 1024])
+      self.assertAllEqual(features_shape_out, [1, 7, 7, 512])
 
   def test_extract_proposal_features_dies_on_invalid_stride(self):
     with self.assertRaises(ValueError):