Commit 1ea5e1f6 authored by TF Object Detection Team's avatar TF Object Detection Team
Browse files

Merge pull request #8893 from syiming:move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor

PiperOrigin-RevId: 324632246
parents 507a8d3c ea8cc8cf
...@@ -20,6 +20,7 @@ import tensorflow.compat.v1 as tf ...@@ -20,6 +20,7 @@ import tensorflow.compat.v1 as tf
from object_detection.meta_architectures import faster_rcnn_meta_arch from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.models import feature_map_generators from object_detection.models import feature_map_generators
from object_detection.models.keras_models import resnet_v1 from object_detection.models.keras_models import resnet_v1
from object_detection.utils import ops
_RESNET_MODEL_OUTPUT_LAYERS = { _RESNET_MODEL_OUTPUT_LAYERS = {
...@@ -32,6 +33,78 @@ _RESNET_MODEL_OUTPUT_LAYERS = { ...@@ -32,6 +33,78 @@ _RESNET_MODEL_OUTPUT_LAYERS = {
} }
class _ResnetFPN(tf.keras.layers.Layer):
  """Keras layer combining a Resnet backbone with an FPN feature generator."""

  def __init__(self,
               backbone_classifier,
               fpn_features_generator,
               coarse_feature_layers,
               pad_to_multiple,
               fpn_min_level,
               resnet_block_names,
               base_fpn_max_level):
    """Constructor.

    Args:
      backbone_classifier: Classifier backbone. Should be one of
        'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152'.
      fpn_features_generator: KerasFpnTopDownFeatureMaps that accepts a
        dictionary of features and returns a ordered dictionary of fpn
        features.
      coarse_feature_layers: Coarse feature layers for fpn.
      pad_to_multiple: An integer multiple to pad input image.
      fpn_min_level: the highest resolution feature map to use in FPN. The
        valid values are {2, 3, 4, 5} which map to Resnet v1 layers.
      resnet_block_names: a list of block names of resnet.
      base_fpn_max_level: maximum level of fpn without coarse feature layers.
    """
    super(_ResnetFPN, self).__init__()
    self.classification_backbone = backbone_classifier
    self.fpn_features_generator = fpn_features_generator
    self.coarse_feature_layers = coarse_feature_layers
    self.pad_to_multiple = pad_to_multiple
    self._fpn_min_level = fpn_min_level
    self._resnet_block_names = resnet_block_names
    self._base_fpn_max_level = base_fpn_max_level

  def call(self, inputs):
    """Create internal Resnet FPN layer.

    Args:
      inputs: A [batch, height_out, width_out, channels] float32 tensor
        representing a batch of images.

    Returns:
      feature_maps: A list of tensors with shape [batch, height, width, depth]
        represent extracted features.
    """
    padded_inputs = ops.pad_to_multiple(inputs, self.pad_to_multiple)
    backbone_outputs = self.classification_backbone(padded_inputs)

    # Pair each backbone output with its resnet block name, then select the
    # blocks feeding the FPN. FPN level N corresponds to 'block{N-1}'.
    fpn_levels = range(self._fpn_min_level, self._base_fpn_max_level + 1)
    wanted_blocks = ['block{}'.format(level - 1) for level in fpn_levels]
    block_to_feature = dict(zip(self._resnet_block_names, backbone_outputs))
    fpn_input_image_features = [
        (block_name, block_to_feature[block_name])
        for block_name in wanted_blocks]
    fpn_features = self.fpn_features_generator(fpn_input_image_features)

    feature_maps = [
        fpn_features['top_down_block{}'.format(level - 1)]
        for level in fpn_levels]

    # Extend beyond the base FPN levels by running the coarsest map through
    # each group of coarse feature layers, appending one map per group.
    coarsest_map = fpn_features['top_down_block{}'.format(
        self._base_fpn_max_level - 1)]
    for layer_group in self.coarse_feature_layers:
      for layer in layer_group:
        coarsest_map = layer(coarsest_map)
      feature_maps.append(coarsest_map)
    return feature_maps
class FasterRCNNResnetV1FpnKerasFeatureExtractor( class FasterRCNNResnetV1FpnKerasFeatureExtractor(
faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor): faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor):
"""Faster RCNN Feature Extractor using Keras-based Resnet V1 FPN features.""" """Faster RCNN Feature Extractor using Keras-based Resnet V1 FPN features."""
...@@ -42,7 +115,8 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor( ...@@ -42,7 +115,8 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
resnet_v1_base_model_name, resnet_v1_base_model_name,
first_stage_features_stride, first_stage_features_stride,
conv_hyperparams, conv_hyperparams,
batch_norm_trainable=False, batch_norm_trainable=True,
pad_to_multiple=32,
weight_decay=0.0, weight_decay=0.0,
fpn_min_level=2, fpn_min_level=2,
fpn_max_level=6, fpn_max_level=6,
...@@ -60,6 +134,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor( ...@@ -60,6 +134,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
containing convolution hyperparameters for the layers added on top of containing convolution hyperparameters for the layers added on top of
the base feature extractor. the base feature extractor.
batch_norm_trainable: See base class. batch_norm_trainable: See base class.
pad_to_multiple: An integer multiple to pad input image.
weight_decay: See base class. weight_decay: See base class.
fpn_min_level: the highest resolution feature map to use in FPN. The valid fpn_min_level: the highest resolution feature map to use in FPN. The valid
values are {2, 3, 4, 5} which map to Resnet v1 layers. values are {2, 3, 4, 5} which map to Resnet v1 layers.
...@@ -93,6 +168,8 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor( ...@@ -93,6 +168,8 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
self._fpn_max_level = fpn_max_level self._fpn_max_level = fpn_max_level
self._additional_layer_depth = additional_layer_depth self._additional_layer_depth = additional_layer_depth
self._freeze_batchnorm = (not batch_norm_trainable) self._freeze_batchnorm = (not batch_norm_trainable)
self._pad_to_multiple = pad_to_multiple
self._override_base_feature_extractor_hyperparams = \ self._override_base_feature_extractor_hyperparams = \
override_base_feature_extractor_hyperparams override_base_feature_extractor_hyperparams
self._resnet_block_names = ['block1', 'block2', 'block3', 'block4'] self._resnet_block_names = ['block1', 'block2', 'block3', 'block4']
...@@ -156,10 +233,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor( ...@@ -156,10 +233,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
self.classification_backbone = tf.keras.Model( self.classification_backbone = tf.keras.Model(
inputs=full_resnet_v1_model.inputs, inputs=full_resnet_v1_model.inputs,
outputs=outputs) outputs=outputs)
backbone_outputs = self.classification_backbone(
full_resnet_v1_model.inputs)
# construct FPN feature generator
self._base_fpn_max_level = min(self._fpn_max_level, 5) self._base_fpn_max_level = min(self._fpn_max_level, 5)
self._num_levels = self._base_fpn_max_level + 1 - self._fpn_min_level self._num_levels = self._base_fpn_max_level + 1 - self._fpn_min_level
self._fpn_features_generator = ( self._fpn_features_generator = (
...@@ -171,16 +245,6 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor( ...@@ -171,16 +245,6 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
freeze_batchnorm=self._freeze_batchnorm, freeze_batchnorm=self._freeze_batchnorm,
name='FeatureMaps')) name='FeatureMaps'))
feature_block_list = []
for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
feature_block_list.append('block{}'.format(level - 1))
feature_block_map = dict(
list(zip(self._resnet_block_names, backbone_outputs)))
fpn_input_image_features = [
(feature_block, feature_block_map[feature_block])
for feature_block in feature_block_list]
fpn_features = self._fpn_features_generator(fpn_input_image_features)
# Construct coarse feature layers # Construct coarse feature layers
for i in range(self._base_fpn_max_level, self._fpn_max_level): for i in range(self._base_fpn_max_level, self._fpn_max_level):
layers = [] layers = []
...@@ -202,19 +266,13 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor( ...@@ -202,19 +266,13 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
name=layer_name)) name=layer_name))
self._coarse_feature_layers.append(layers) self._coarse_feature_layers.append(layers)
feature_maps = [] feature_extractor_model = _ResnetFPN(self.classification_backbone,
for level in range(self._fpn_min_level, self._base_fpn_max_level + 1): self._fpn_features_generator,
feature_maps.append(fpn_features['top_down_block{}'.format(level-1)]) self._coarse_feature_layers,
last_feature_map = fpn_features['top_down_block{}'.format( self._pad_to_multiple,
self._base_fpn_max_level - 1)] self._fpn_min_level,
self._resnet_block_names,
for coarse_feature_layers in self._coarse_feature_layers: self._base_fpn_max_level)
for layer in coarse_feature_layers:
last_feature_map = layer(last_feature_map)
feature_maps.append(last_feature_map)
feature_extractor_model = tf.keras.models.Model(
inputs=full_resnet_v1_model.inputs, outputs=feature_maps)
return feature_extractor_model return feature_extractor_model
def get_box_classifier_feature_extractor_model(self, name=None): def get_box_classifier_feature_extractor_model(self, name=None):
...@@ -233,16 +291,18 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor( ...@@ -233,16 +291,18 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
And returns proposal_classifier_features: And returns proposal_classifier_features:
A 4-D float tensor with shape A 4-D float tensor with shape
[batch_size * self.max_num_proposals, 1024] [batch_size * self.max_num_proposals, 1, 1, 1024]
representing box classifier features for each proposal. representing box classifier features for each proposal.
""" """
with tf.name_scope(name): with tf.name_scope(name):
with tf.name_scope('ResnetV1FPN'): with tf.name_scope('ResnetV1FPN'):
# TODO(yiming): Add a batchnorm layer between two fc layers.
feature_extractor_model = tf.keras.models.Sequential([ feature_extractor_model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(), tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=1024, activation='relu'), tf.keras.layers.Dense(units=1024, activation='relu'),
tf.keras.layers.Dense(units=1024, activation='relu') self._conv_hyperparams.build_batch_norm(
training=(self._is_training and not self._freeze_batchnorm)),
tf.keras.layers.Dense(units=1024, activation='relu'),
tf.keras.layers.Reshape((1, 1, 1024))
]) ])
return feature_extractor_model return feature_extractor_model
...@@ -254,8 +314,8 @@ class FasterRCNNResnet50FpnKerasFeatureExtractor( ...@@ -254,8 +314,8 @@ class FasterRCNNResnet50FpnKerasFeatureExtractor(
def __init__(self, def __init__(self,
is_training, is_training,
first_stage_features_stride=16, first_stage_features_stride=16,
batch_norm_trainable=True,
conv_hyperparams=None, conv_hyperparams=None,
batch_norm_trainable=False,
weight_decay=0.0, weight_decay=0.0,
fpn_min_level=2, fpn_min_level=2,
fpn_max_level=6, fpn_max_level=6,
...@@ -266,8 +326,8 @@ class FasterRCNNResnet50FpnKerasFeatureExtractor( ...@@ -266,8 +326,8 @@ class FasterRCNNResnet50FpnKerasFeatureExtractor(
Args: Args:
is_training: See base class. is_training: See base class.
first_stage_features_stride: See base class. first_stage_features_stride: See base class.
conv_hyperparams: See base class.
batch_norm_trainable: See base class. batch_norm_trainable: See base class.
conv_hyperparams: See base class.
weight_decay: See base class. weight_decay: See base class.
fpn_min_level: See base class. fpn_min_level: See base class.
fpn_max_level: See base class. fpn_max_level: See base class.
...@@ -297,8 +357,8 @@ class FasterRCNNResnet101FpnKerasFeatureExtractor( ...@@ -297,8 +357,8 @@ class FasterRCNNResnet101FpnKerasFeatureExtractor(
def __init__(self, def __init__(self,
is_training, is_training,
first_stage_features_stride=16, first_stage_features_stride=16,
batch_norm_trainable=True,
conv_hyperparams=None, conv_hyperparams=None,
batch_norm_trainable=False,
weight_decay=0.0, weight_decay=0.0,
fpn_min_level=2, fpn_min_level=2,
fpn_max_level=6, fpn_max_level=6,
...@@ -309,8 +369,8 @@ class FasterRCNNResnet101FpnKerasFeatureExtractor( ...@@ -309,8 +369,8 @@ class FasterRCNNResnet101FpnKerasFeatureExtractor(
Args: Args:
is_training: See base class. is_training: See base class.
first_stage_features_stride: See base class. first_stage_features_stride: See base class.
conv_hyperparams: See base class.
batch_norm_trainable: See base class. batch_norm_trainable: See base class.
conv_hyperparams: See base class.
weight_decay: See base class. weight_decay: See base class.
fpn_min_level: See base class. fpn_min_level: See base class.
fpn_max_level: See base class. fpn_max_level: See base class.
...@@ -339,8 +399,8 @@ class FasterRCNNResnet152FpnKerasFeatureExtractor( ...@@ -339,8 +399,8 @@ class FasterRCNNResnet152FpnKerasFeatureExtractor(
def __init__(self, def __init__(self,
is_training, is_training,
first_stage_features_stride=16, first_stage_features_stride=16,
batch_norm_trainable=True,
conv_hyperparams=None, conv_hyperparams=None,
batch_norm_trainable=False,
weight_decay=0.0, weight_decay=0.0,
fpn_min_level=2, fpn_min_level=2,
fpn_max_level=6, fpn_max_level=6,
...@@ -351,8 +411,8 @@ class FasterRCNNResnet152FpnKerasFeatureExtractor( ...@@ -351,8 +411,8 @@ class FasterRCNNResnet152FpnKerasFeatureExtractor(
Args: Args:
is_training: See base class. is_training: See base class.
first_stage_features_stride: See base class. first_stage_features_stride: See base class.
conv_hyperparams: See base class.
batch_norm_trainable: See base class. batch_norm_trainable: See base class.
conv_hyperparams: See base class.
weight_decay: See base class. weight_decay: See base class.
fpn_min_level: See base class. fpn_min_level: See base class.
fpn_max_level: See base class. fpn_max_level: See base class.
......
...@@ -91,4 +91,4 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractorTest(tf.test.TestCase): ...@@ -91,4 +91,4 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractorTest(tf.test.TestCase):
model(proposal_feature_maps)) model(proposal_feature_maps))
features_shape = tf.shape(proposal_classifier_features) features_shape = tf.shape(proposal_classifier_features)
self.assertAllEqual(features_shape.numpy(), [3, 1024]) self.assertAllEqual(features_shape.numpy(), [3, 1, 1, 1024])
...@@ -216,13 +216,13 @@ def pad_to_multiple(tensor, multiple): ...@@ -216,13 +216,13 @@ def pad_to_multiple(tensor, multiple):
height_pad = tf.zeros([ height_pad = tf.zeros([
batch_size, padded_tensor_height - tensor_height, tensor_width, batch_size, padded_tensor_height - tensor_height, tensor_width,
tensor_depth tensor_depth
]) ], dtype=tensor.dtype)
tensor = tf.concat([tensor, height_pad], 1) tensor = tf.concat([tensor, height_pad], 1)
if padded_tensor_width != tensor_width: if padded_tensor_width != tensor_width:
width_pad = tf.zeros([ width_pad = tf.zeros([
batch_size, padded_tensor_height, padded_tensor_width - tensor_width, batch_size, padded_tensor_height, padded_tensor_width - tensor_width,
tensor_depth tensor_depth
]) ], dtype=tensor.dtype)
tensor = tf.concat([tensor, width_pad], 2) tensor = tf.concat([tensor, width_pad], 2)
return tensor return tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment