Commit 3956d90e authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Use mobilenet_v1_arg_scope to pass the weight_decay and batch norm params.

PiperOrigin-RevId: 190682119
parent d724a08b
...@@ -22,30 +22,6 @@ from nets import mobilenet_v1 ...@@ -22,30 +22,6 @@ from nets import mobilenet_v1
slim = tf.contrib.slim slim = tf.contrib.slim
def _batch_norm_arg_scope(list_ops,
use_batch_norm=True,
batch_norm_decay=0.9997,
batch_norm_epsilon=0.001,
batch_norm_scale=False,
train_batch_norm=False):
"""Slim arg scope for Mobilenet V1 batch norm."""
if use_batch_norm:
batch_norm_params = {
'is_training': train_batch_norm,
'scale': batch_norm_scale,
'decay': batch_norm_decay,
'epsilon': batch_norm_epsilon
}
normalizer_fn = slim.batch_norm
else:
normalizer_fn = None
batch_norm_params = None
return slim.arg_scope(list_ops,
normalizer_fn=normalizer_fn,
normalizer_params=batch_norm_params)
class FasterRCNNMobilenetV1FeatureExtractor( class FasterRCNNMobilenetV1FeatureExtractor(
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor): faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
"""Faster R-CNN Mobilenet V1 feature extractor implementation.""" """Faster R-CNN Mobilenet V1 feature extractor implementation."""
...@@ -121,18 +97,19 @@ class FasterRCNNMobilenetV1FeatureExtractor( ...@@ -121,18 +97,19 @@ class FasterRCNNMobilenetV1FeatureExtractor(
['image size must at least be 33 in both height and width.']) ['image size must at least be 33 in both height and width.'])
with tf.control_dependencies([shape_assert]): with tf.control_dependencies([shape_assert]):
with tf.variable_scope('MobilenetV1', with slim.arg_scope(
reuse=self._reuse_weights) as scope: mobilenet_v1.mobilenet_v1_arg_scope(
with _batch_norm_arg_scope([slim.conv2d, slim.separable_conv2d], is_training=self._train_batch_norm,
batch_norm_scale=True, weight_decay=self._weight_decay)):
train_batch_norm=self._train_batch_norm): with tf.variable_scope('MobilenetV1',
reuse=self._reuse_weights) as scope:
_, activations = mobilenet_v1.mobilenet_v1_base( _, activations = mobilenet_v1.mobilenet_v1_base(
preprocessed_inputs, preprocessed_inputs,
final_endpoint='Conv2d_13_pointwise', final_endpoint='Conv2d_11_pointwise',
min_depth=self._min_depth, min_depth=self._min_depth,
depth_multiplier=self._depth_multiplier, depth_multiplier=self._depth_multiplier,
scope=scope) scope=scope)
return activations['Conv2d_13_pointwise'], activations return activations['Conv2d_11_pointwise'], activations
def _extract_box_classifier_features(self, proposal_feature_maps, scope): def _extract_box_classifier_features(self, proposal_feature_maps, scope):
"""Extracts second stage box classifier features. """Extracts second stage box classifier features.
...@@ -152,9 +129,10 @@ class FasterRCNNMobilenetV1FeatureExtractor( ...@@ -152,9 +129,10 @@ class FasterRCNNMobilenetV1FeatureExtractor(
depth = lambda d: max(int(d * 1.0), 16) depth = lambda d: max(int(d * 1.0), 16)
with tf.variable_scope('MobilenetV1', reuse=self._reuse_weights): with tf.variable_scope('MobilenetV1', reuse=self._reuse_weights):
with _batch_norm_arg_scope([slim.conv2d, slim.separable_conv2d], with slim.arg_scope(
batch_norm_scale=True, mobilenet_v1.mobilenet_v1_arg_scope(
train_batch_norm=self._train_batch_norm): is_training=self._train_batch_norm,
weight_decay=self._weight_decay)):
with slim.arg_scope( with slim.arg_scope(
[slim.conv2d, slim.separable_conv2d], padding='SAME'): [slim.conv2d, slim.separable_conv2d], padding='SAME'):
net = slim.separable_conv2d( net = slim.separable_conv2d(
......
...@@ -44,7 +44,7 @@ class FasterRcnnMobilenetV1FeatureExtractorTest(tf.test.TestCase): ...@@ -44,7 +44,7 @@ class FasterRcnnMobilenetV1FeatureExtractorTest(tf.test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
sess.run(init_op) sess.run(init_op)
features_shape_out = sess.run(features_shape) features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [4, 7, 7, 1024]) self.assertAllEqual(features_shape_out, [4, 14, 14, 512])
def test_extract_proposal_features_stride_eight(self): def test_extract_proposal_features_stride_eight(self):
feature_extractor = self._build_feature_extractor( feature_extractor = self._build_feature_extractor(
...@@ -59,7 +59,7 @@ class FasterRcnnMobilenetV1FeatureExtractorTest(tf.test.TestCase): ...@@ -59,7 +59,7 @@ class FasterRcnnMobilenetV1FeatureExtractorTest(tf.test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
sess.run(init_op) sess.run(init_op)
features_shape_out = sess.run(features_shape) features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [4, 7, 7, 1024]) self.assertAllEqual(features_shape_out, [4, 14, 14, 512])
def test_extract_proposal_features_half_size_input(self): def test_extract_proposal_features_half_size_input(self):
feature_extractor = self._build_feature_extractor( feature_extractor = self._build_feature_extractor(
...@@ -74,7 +74,7 @@ class FasterRcnnMobilenetV1FeatureExtractorTest(tf.test.TestCase): ...@@ -74,7 +74,7 @@ class FasterRcnnMobilenetV1FeatureExtractorTest(tf.test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
sess.run(init_op) sess.run(init_op)
features_shape_out = sess.run(features_shape) features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [1, 4, 4, 1024]) self.assertAllEqual(features_shape_out, [1, 7, 7, 512])
def test_extract_proposal_features_dies_on_invalid_stride(self): def test_extract_proposal_features_dies_on_invalid_stride(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment