Commit e25c014e authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

use shape utils for assertion in feature extractor.

PiperOrigin-RevId: 192147130
parent a4d9c3a0
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import tensorflow as tf import tensorflow as tf
from object_detection.meta_architectures import faster_rcnn_meta_arch from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.utils import shape_utils
from nets import mobilenet_v1 from nets import mobilenet_v1
slim = tf.contrib.slim slim = tf.contrib.slim
...@@ -112,28 +113,25 @@ class FasterRCNNMobilenetV1FeatureExtractor( ...@@ -112,28 +113,25 @@ class FasterRCNNMobilenetV1FeatureExtractor(
""" """
preprocessed_inputs.get_shape().assert_has_rank(4) preprocessed_inputs.get_shape().assert_has_rank(4)
shape_assert = tf.Assert( preprocessed_inputs = shape_utils.check_min_image_dim(
tf.logical_and(tf.greater_equal(tf.shape(preprocessed_inputs)[1], 33), min_dim=33, image_tensor=preprocessed_inputs)
tf.greater_equal(tf.shape(preprocessed_inputs)[2], 33)),
['image size must at least be 33 in both height and width.']) with slim.arg_scope(
mobilenet_v1.mobilenet_v1_arg_scope(
with tf.control_dependencies([shape_assert]): is_training=self._train_batch_norm,
with slim.arg_scope( weight_decay=self._weight_decay)):
mobilenet_v1.mobilenet_v1_arg_scope( with tf.variable_scope('MobilenetV1',
is_training=self._train_batch_norm, reuse=self._reuse_weights) as scope:
weight_decay=self._weight_decay)): params = {}
with tf.variable_scope('MobilenetV1', if self._skip_last_stride:
reuse=self._reuse_weights) as scope: params['conv_defs'] = _MOBILENET_V1_100_CONV_NO_LAST_STRIDE_DEFS
params = {} _, activations = mobilenet_v1.mobilenet_v1_base(
if self._skip_last_stride: preprocessed_inputs,
params['conv_defs'] = _MOBILENET_V1_100_CONV_NO_LAST_STRIDE_DEFS final_endpoint='Conv2d_11_pointwise',
_, activations = mobilenet_v1.mobilenet_v1_base( min_depth=self._min_depth,
preprocessed_inputs, depth_multiplier=self._depth_multiplier,
final_endpoint='Conv2d_11_pointwise', scope=scope,
min_depth=self._min_depth, **params)
depth_multiplier=self._depth_multiplier,
scope=scope,
**params)
return activations['Conv2d_11_pointwise'], activations return activations['Conv2d_11_pointwise'], activations
def _extract_box_classifier_features(self, proposal_feature_maps, scope): def _extract_box_classifier_features(self, proposal_feature_maps, scope):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment