Commit 1cf70ed7 authored by syiming

Fix coding style for feature extractor.

parent 7140eede
@@ -20,9 +20,7 @@ import tensorflow.compat.v1 as tf
 from object_detection.meta_architectures import faster_rcnn_meta_arch
 from object_detection.models import feature_map_generators
 from object_detection.models.keras_models import resnet_v1
-from object_detection.models.keras_models import model_utils
-from object_detection.utils import ops
-from object_detection.utils import shape_utils
 
 _RESNET_MODEL_OUTPUT_LAYERS = {
     'resnet_v1_50': ['conv2_block3_out', 'conv3_block4_out',
@@ -35,7 +33,7 @@ _RESNET_MODEL_OUTPUT_LAYERS = {
 class FasterRCNNResnetV1FPNKerasFeatureExtractor(
-    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
+    faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor):
   """Faster RCNN Feature Extractor using Keras-based Resnet V1 FPN features."""
 
   def __init__(self,
@@ -52,30 +50,24 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
                fpn_max_level=7,
                additional_layer_depth=256,
                override_base_feature_extractor_hyperparams=False):
-    # FIXME: fix doc string for fpn min level and fpn max level
     """Constructor.
 
     Args:
       is_training: See base class.
+      resnet_v1_base_model: base resnet v1 network to use. One of
+        the resnet_v1.resnet_v1_{50,101,152} models.
+      resnet_v1_base_model_name: model name under which to construct resnet v1.
       first_stage_features_stride: See base class.
       conv_hyperparameters: a `hyperparams_builder.KerasLayerHyperparams` object
         containing convolution hyperparameters for the layers added on top of
         the base feature extractor.
       min_depth: Minimum number of filters in the convolutional layers.
       depth_multiplier: The depth multiplier to modify the number of filters
        in the convolutional layers.
-      resnet_v1_base_model: base resnet v1 network to use. One of
-        the resnet_v1.resnet_v1_{50,101,152} models.
-      resnet_v1_base_model_name: model name under which to construct resnet v1.
       batch_norm_trainable: See base class.
       weight_decay: See base class.
       fpn_min_level: the highest resolution feature map to use in FPN. The valid
-        values are {2, 3, 4, 5} which map to MobileNet v1 layers
-        {Conv2d_3_pointwise, Conv2d_5_pointwise, Conv2d_11_pointwise,
-        Conv2d_13_pointwise}, respectively.
+        values are {2, 3, 4, 5} which map to Resnet v1 layers.
       fpn_max_level: the smallest resolution feature map to construct or use in
         FPN. FPN constructions uses features maps starting from fpn_min_level
         upto the fpn_max_level. In the case that there are not enough feature
@@ -92,22 +84,24 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
     """
     if first_stage_features_stride != 8 and first_stage_features_stride != 16:
       raise ValueError('`first_stage_features_stride` must be 8 or 16.')
     super(FasterRCNNResnetV1FPNKerasFeatureExtractor, self).__init__(
         is_training=is_training,
         first_stage_features_stride=first_stage_features_stride,
         batch_norm_trainable=batch_norm_trainable,
         weight_decay=weight_decay)
+    self._resnet_v1_base_model = resnet_v1_base_model
+    self._resnet_v1_base_model_name = resnet_v1_base_model_name
     self._conv_hyperparams = conv_hyperparams
     self._min_depth = min_depth
     self._depth_multiplier = depth_multiplier
+    self._fpn_min_level = fpn_min_level
+    self._fpn_max_level = fpn_max_level
     self._additional_layer_depth = additional_layer_depth
     self._freeze_batchnorm = (not batch_norm_trainable)
     self._override_base_feature_extractor_hyperparams = \
         override_base_feature_extractor_hyperparams
-    self._fpn_min_level = fpn_min_level
-    self._fpn_max_level = fpn_max_level
-    self._resnet_v1_base_model = resnet_v1_base_model
-    self._resnet_v1_base_model_name = resnet_v1_base_model_name
     self._resnet_block_names = ['block1', 'block2', 'block3', 'block4']
     self.classification_backbone = None
     self._fpn_features_generator = None
@@ -134,11 +128,11 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
       return resized_inputs - [[channel_means]]
     else:
       return resized_inputs
 
   def get_proposal_feature_extractor_model(self, name=None):
     """Returns a model that extracts first stage RPN features.
 
-    Extracts features using the first half of the Resnet v1 network.
+    Extracts features using the Resnet v1 FPN network.
 
     Args:
       name: A scope name to construct all variables within.
@@ -147,6 +141,9 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
       A Keras model that takes preprocessed_inputs:
         A [batch, height, width, channels] float32 tensor
         representing a batch of images.
+      And returns rpn_feature_map:
+        A list of tensors with shape [batch, height, width, depth]
     """
     with tf.name_scope(name):
       with tf.name_scope('ResnetV1FPN'):
@@ -162,7 +159,7 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
             include_top=False)
         output_layers = _RESNET_MODEL_OUTPUT_LAYERS[self._resnet_v1_base_model_name]
         outputs = [full_resnet_v1_model.get_layer(output_layer_name).output
                    for output_layer_name in output_layers]
         self.classification_backbone = tf.keras.Model(
             inputs=full_resnet_v1_model.inputs,
             outputs=outputs)
@@ -181,7 +178,7 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
                 conv_hyperparams=self._conv_hyperparams,
                 freeze_batchnorm=self._freeze_batchnorm,
                 name='FeatureMaps'))
         feature_block_list = []
         for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
           feature_block_list.append('block{}'.format(level - 1))
@@ -200,7 +197,7 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
   def get_box_classifier_feature_extractor_model(self, name=None):
     """Returns a model that extracts second stage box classifier features.
 
-    TODO: doc
+    Construct two fully connected layer to extract the box classifier features.
 
     Args:
       name: A scope name to construct all variables within.
@@ -210,17 +207,18 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
         A 4-D float tensor with shape
         [batch_size * self.max_num_proposals, crop_height, crop_width, depth]
         representing the feature map cropped to each proposal.
       And returns proposal_classifier_features:
         A 4-D float tensor with shape
-        [batch_size * self.max_num_proposals, height, width, depth]
+        [batch_size * self.max_num_proposals, 1024]
         representing box classifier features for each proposal.
     """
     with tf.name_scope(name):
       with tf.name_scope('ResnetV1FPN'):
         feature_extractor_model = tf.keras.models.Sequential([
             tf.keras.layers.Flatten(),
             tf.keras.layers.Dense(units=1024, activation='relu'),
             tf.keras.layers.Dense(units=1024, activation='relu')
         ])
         return feature_extractor_model
@@ -228,7 +226,7 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
 class FasterRCNNResnet50FPNKerasFeatureExtractor(
     FasterRCNNResnetV1FPNKerasFeatureExtractor):
   """Faster RCNN with Resnet50 FPN feature extractor implementation."""
 
   def __init__(self,
                is_training,
                first_stage_features_stride=16,
@@ -271,6 +269,7 @@ class FasterRCNNResnet50FPNKerasFeatureExtractor(
         additional_layer_depth=additional_layer_depth,
         override_base_feature_extractor_hyperparams=override_base_feature_extractor_hyperparams)
 
 class FasterRCNNResnet101FPNKerasFeatureExtractor(
     FasterRCNNResnetV1FPNKerasFeatureExtractor):
   """Faster RCNN with Resnet101 FPN feature extractor implementation."""
...
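The two snippets below are illustrative sketches, not part of the commit; the concrete values (FPN levels, proposal counts, crop size, feature depth) are assumptions chosen only to make the shapes visible.

First, the loop in the diff builds feature_block_list by mapping each FPN level to a Resnet block name via 'block{level - 1}'. Assuming fpn_min_level=2 and a base FPN max level of 5, it selects the four blocks listed in self._resnet_block_names:

# Assumed example values: FPN levels 2..5 select Resnet blocks 'block1'..'block4'.
fpn_min_level = 2
base_fpn_max_level = 5
feature_block_list = []
for level in range(fpn_min_level, base_fpn_max_level + 1):
  feature_block_list.append('block{}'.format(level - 1))
print(feature_block_list)  # ['block1', 'block2', 'block3', 'block4']

Second, the updated docstring of get_box_classifier_feature_extractor_model reports a [batch_size * self.max_num_proposals, 1024] output. Mirroring the Flatten + two Dense(1024) head from the diff on an assumed batch of cropped proposal features shows that shape transformation:

import tensorflow.compat.v1 as tf

# Box classifier head as in the diff: flatten each cropped proposal feature
# map, then apply two fully connected layers of width 1024.
box_classifier_head = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=1024, activation='relu'),
    tf.keras.layers.Dense(units=1024, activation='relu')
])

# Assumed sizes: 2 images x 8 proposals, 14x14 crops with depth 256.
proposal_feature_maps = tf.random.uniform([2 * 8, 14, 14, 256])
proposal_classifier_features = box_classifier_head(proposal_feature_maps)
print(proposal_classifier_features.shape)  # (16, 1024)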