Commit aedfa2e4 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Add an option to skip the last striding in mobilenet. The modified network has...

Add an option to skip the last striding in mobilenet. The modified network has nominal output stride 16 instead of 32.

PiperOrigin-RevId: 191932855
parent decbad8a
...@@ -22,6 +22,24 @@ from nets import mobilenet_v1 ...@@ -22,6 +22,24 @@ from nets import mobilenet_v1
slim = tf.contrib.slim slim = tf.contrib.slim
_MOBILENET_V1_100_CONV_NO_LAST_STRIDE_DEFS = [
mobilenet_v1.Conv(kernel=[3, 3], stride=2, depth=32),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=64),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=2, depth=128),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=128),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=2, depth=256),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=256),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=2, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=1024),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=1024)
]
class FasterRCNNMobilenetV1FeatureExtractor( class FasterRCNNMobilenetV1FeatureExtractor(
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor): faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
"""Faster R-CNN Mobilenet V1 feature extractor implementation.""" """Faster R-CNN Mobilenet V1 feature extractor implementation."""
...@@ -33,7 +51,8 @@ class FasterRCNNMobilenetV1FeatureExtractor( ...@@ -33,7 +51,8 @@ class FasterRCNNMobilenetV1FeatureExtractor(
reuse_weights=None, reuse_weights=None,
weight_decay=0.0, weight_decay=0.0,
depth_multiplier=1.0, depth_multiplier=1.0,
min_depth=16): min_depth=16,
skip_last_stride=False):
"""Constructor. """Constructor.
Args: Args:
...@@ -44,6 +63,7 @@ class FasterRCNNMobilenetV1FeatureExtractor( ...@@ -44,6 +63,7 @@ class FasterRCNNMobilenetV1FeatureExtractor(
weight_decay: See base class. weight_decay: See base class.
depth_multiplier: float depth multiplier for feature extractor. depth_multiplier: float depth multiplier for feature extractor.
min_depth: minimum feature extractor depth. min_depth: minimum feature extractor depth.
skip_last_stride: Skip the last stride if True.
Raises: Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16. ValueError: If `first_stage_features_stride` is not 8 or 16.
...@@ -52,6 +72,7 @@ class FasterRCNNMobilenetV1FeatureExtractor( ...@@ -52,6 +72,7 @@ class FasterRCNNMobilenetV1FeatureExtractor(
raise ValueError('`first_stage_features_stride` must be 8 or 16.') raise ValueError('`first_stage_features_stride` must be 8 or 16.')
self._depth_multiplier = depth_multiplier self._depth_multiplier = depth_multiplier
self._min_depth = min_depth self._min_depth = min_depth
self._skip_last_stride = skip_last_stride
super(FasterRCNNMobilenetV1FeatureExtractor, self).__init__( super(FasterRCNNMobilenetV1FeatureExtractor, self).__init__(
is_training, first_stage_features_stride, batch_norm_trainable, is_training, first_stage_features_stride, batch_norm_trainable,
reuse_weights, weight_decay) reuse_weights, weight_decay)
...@@ -103,12 +124,16 @@ class FasterRCNNMobilenetV1FeatureExtractor( ...@@ -103,12 +124,16 @@ class FasterRCNNMobilenetV1FeatureExtractor(
weight_decay=self._weight_decay)): weight_decay=self._weight_decay)):
with tf.variable_scope('MobilenetV1', with tf.variable_scope('MobilenetV1',
reuse=self._reuse_weights) as scope: reuse=self._reuse_weights) as scope:
params = {}
if self._skip_last_stride:
params['conv_defs'] = _MOBILENET_V1_100_CONV_NO_LAST_STRIDE_DEFS
_, activations = mobilenet_v1.mobilenet_v1_base( _, activations = mobilenet_v1.mobilenet_v1_base(
preprocessed_inputs, preprocessed_inputs,
final_endpoint='Conv2d_11_pointwise', final_endpoint='Conv2d_11_pointwise',
min_depth=self._min_depth, min_depth=self._min_depth,
depth_multiplier=self._depth_multiplier, depth_multiplier=self._depth_multiplier,
scope=scope) scope=scope,
**params)
return activations['Conv2d_11_pointwise'], activations return activations['Conv2d_11_pointwise'], activations
def _extract_box_classifier_features(self, proposal_feature_maps, scope): def _extract_box_classifier_features(self, proposal_feature_maps, scope):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment