Commit aedfa2e4 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Add an option to skip the last striding in mobilenet. The modified network has...

Add an option to skip the last striding in mobilenet. The modified network has nominal output stride 16 instead of 32.

PiperOrigin-RevId: 191932855
parent decbad8a
......@@ -22,6 +22,24 @@ from nets import mobilenet_v1
slim = tf.contrib.slim
_MOBILENET_V1_100_CONV_NO_LAST_STRIDE_DEFS = [
mobilenet_v1.Conv(kernel=[3, 3], stride=2, depth=32),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=64),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=2, depth=128),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=128),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=2, depth=256),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=256),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=2, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=1024),
mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=1024)
]
class FasterRCNNMobilenetV1FeatureExtractor(
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
"""Faster R-CNN Mobilenet V1 feature extractor implementation."""
......@@ -33,7 +51,8 @@ class FasterRCNNMobilenetV1FeatureExtractor(
reuse_weights=None,
weight_decay=0.0,
depth_multiplier=1.0,
min_depth=16):
min_depth=16,
skip_last_stride=False):
"""Constructor.
Args:
......@@ -44,6 +63,7 @@ class FasterRCNNMobilenetV1FeatureExtractor(
weight_decay: See base class.
depth_multiplier: float depth multiplier for feature extractor.
min_depth: minimum feature extractor depth.
skip_last_stride: Skip the last stride if True.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16.
......@@ -52,6 +72,7 @@ class FasterRCNNMobilenetV1FeatureExtractor(
raise ValueError('`first_stage_features_stride` must be 8 or 16.')
self._depth_multiplier = depth_multiplier
self._min_depth = min_depth
self._skip_last_stride = skip_last_stride
super(FasterRCNNMobilenetV1FeatureExtractor, self).__init__(
is_training, first_stage_features_stride, batch_norm_trainable,
reuse_weights, weight_decay)
......@@ -103,12 +124,16 @@ class FasterRCNNMobilenetV1FeatureExtractor(
weight_decay=self._weight_decay)):
with tf.variable_scope('MobilenetV1',
reuse=self._reuse_weights) as scope:
params = {}
if self._skip_last_stride:
params['conv_defs'] = _MOBILENET_V1_100_CONV_NO_LAST_STRIDE_DEFS
_, activations = mobilenet_v1.mobilenet_v1_base(
preprocessed_inputs,
final_endpoint='Conv2d_11_pointwise',
min_depth=self._min_depth,
depth_multiplier=self._depth_multiplier,
scope=scope)
scope=scope,
**params)
return activations['Conv2d_11_pointwise'], activations
def _extract_box_classifier_features(self, proposal_feature_maps, scope):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment