Commit 33a4c064 authored by syiming's avatar syiming
Browse files

draft for faster rcnn resnet v1 fpn feature extractor

parent b1025b3b
......@@ -88,8 +88,10 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
`conv_hyperparams`.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16.
"""
if first_stage_features_stride != 8 and first_stage_features_stride != 16:
raise ValueError('`first_stage_features_stride` must be 8 or 16.')
super(FasterRCNNResnetV1FPNKerasFeatureExtractor, self).__init__(
is_training=is_training,
first_stage_features_stride=first_stage_features_stride,
......@@ -109,39 +111,38 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
self._resnet_block_names = ['block1', 'block2', 'block3', 'block4']
self.classification_backbone = None
self._fpn_features_generator = None
self._coarse_feature_layers = []
def build(self,):
# TODO: Refine doc string
"""Build Resnet V1 FPN architecture."""
full_resnet_v1_model = self._resnet_v1_base_model(
batchnorm_training=self._train_batch_norm,
conv_hyperparams=(self._conv_hyperparams
if self._override_base_feature_extractor_hyperparams
else None),
min_depth=self._min_depth,
depth_multiplier=self._depth_multiplier,
classes=None,
weights=None,
include_top=False)
output_layers = _RESNET_MODEL_OUTPUT_LAYERS[self._resnet_v1_base_model_name]
outputs = [full_resnet_v1_model.get_layer(output_layer_name).output
for output_layer_name in output_layers]
self.classification_backbone = tf.keras.Model(
inputs=full_resnet_v1_model.inputs,
outputs=outputs)
self._depth_fn = lambda d: max(
int(d * self._depth_multiplier), self._min_depth)
self._base_fpn_max_level = min(self._fpn_max_level, 5)
self._num_levels = self._base_fpn_max_level + 1 - self._fpn_min_level
self._fpn_features_generator = (
feature_map_generators.KerasFpnTopDownFeatureMaps(
num_levels=self._num_levels,
depth=self._depth_fn(self._additional_layer_depth),
is_training=self._is_training,
conv_hyperparams=self._conv_hyperparams,
freeze_batchnorm=self._freeze_batchnorm,
name='FeatureMaps'))
# full_resnet_v1_model = self._resnet_v1_base_model(
# batchnorm_training=self._train_batch_norm,
# conv_hyperparams=(self._conv_hyperparams
# if self._override_base_feature_extractor_hyperparams
# else None),
# min_depth=self._min_depth,
# depth_multiplier=self._depth_multiplier,
# classes=None,
# weights=None,
# include_top=False)
# output_layers = _RESNET_MODEL_OUTPUT_LAYERS[self._resnet_v1_base_model_name]
# outputs = [full_resnet_v1_model.get_layer(output_layer_name).output
# for output_layer_name in output_layers]
# self.classification_backbone = tf.keras.Model(
# inputs=full_resnet_v1_model.inputs,
# outputs=outputs)
# self._depth_fn = lambda d: max(
# int(d * self._depth_multiplier), self._min_depth)
# self._base_fpn_max_level = min(self._fpn_max_level, 5)
# self._num_levels = self._base_fpn_max_level + 1 - self._fpn_min_level
# self._fpn_features_generator = (
# feature_map_generators.KerasFpnTopDownFeatureMaps(
# num_levels=self._num_levels,
# depth=self._depth_fn(self._additional_layer_depth),
# is_training=self._is_training,
# conv_hyperparams=self._conv_hyperparams,
# freeze_batchnorm=self._freeze_batchnorm,
# name='FeatureMaps'))
# Construct coarse feature layers
depth = self._depth_fn(self._additional_layer_depth)
for i in range(self._base_fpn_max_level, self._fpn_max_level):
......@@ -188,15 +189,74 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
else:
return resized_inputs
def _extract_proposal_features(self, preprocessed_inputs, scope=None):
# TODO: doc string
""""""
preprocessed_inputs = shape_utils.check_min_image_dim(
129, preprocessed_inputs)
# def _extract_proposal_features(self, preprocessed_inputs, scope=None):
# # TODO: doc string
# """"""
# preprocessed_inputs = shape_utils.check_min_image_dim(
# 129, preprocessed_inputs)
# with tf.name_scope(scope):
# with tf.name_scope('ResnetV1FPN'):
# image_features = self.classification_backbone(preprocessed_inputs)
# feature_block_list = []
# for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
# feature_block_list.append('block{}'.format(level - 1))
# feature_block_map = dict(
# list(zip(self._resnet_block_names, image_features)))
# fpn_input_image_features = [
# (feature_block, feature_block_map[feature_block])
# for feature_block in feature_block_list]
# fpn_features = self._fpn_features_generator(fpn_input_image_features)
# return fpn_features
def get_proposal_feature_extractor_model(self, name=None):
"""Returns a model that extracts first stage RPN features.
Extracts features using the first half of the Resnet v1 network.
Args:
name: A scope name to construct all variables within.
with tf.name_scope(scope):
Returns:
A Keras model that takes preprocessed_inputs:
A [batch, height, width, channels] float32 tensor
representing a batch of images.
"""
with tf.name_scope(name):
with tf.name_scope('ResnetV1FPN'):
image_features = self.classification_backbone(preprocessed_inputs)
full_resnet_v1_model = self._resnet_v1_base_model(
batchnorm_training=self._train_batch_norm,
conv_hyperparams=(self._conv_hyperparams
if self._override_base_feature_extractor_hyperparams
else None),
min_depth=self._min_depth,
depth_multiplier=self._depth_multiplier,
classes=None,
weights=None,
include_top=False)
output_layers = _RESNET_MODEL_OUTPUT_LAYERS[self._resnet_v1_base_model_name]
outputs = [full_resnet_v1_model.get_layer(output_layer_name).output
for output_layer_name in output_layers]
self.classification_backbone = tf.keras.Model(
inputs=full_resnet_v1_model.inputs,
outputs=outputs)
backbone_outputs = self.classification_backbone(full_resnet_v1_model.inputs)
# construct FPN feature generator
self._depth_fn = lambda d: max(
int(d * self._depth_multiplier), self._min_depth)
self._base_fpn_max_level = min(self._fpn_max_level, 5)
self._num_levels = self._base_fpn_max_level + 1 - self._fpn_min_level
self._fpn_features_generator = (
feature_map_generators.KerasFpnTopDownFeatureMaps(
num_levels=self._num_levels,
depth=self._depth_fn(self._additional_layer_depth),
is_training=self._is_training,
conv_hyperparams=self._conv_hyperparams,
freeze_batchnorm=self._freeze_batchnorm,
name='FeatureMaps'))
feature_block_list = []
for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
......@@ -208,22 +268,24 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
for feature_block in feature_block_list]
fpn_features = self._fpn_features_generator(fpn_input_image_features)
return fpn_features
feature_extractor_model = tf.keras.models.Model(
inputs=self.full_resnet_v1_model.inputs, outputs=fpn_features)
return feature_extractor_model
def _extract_box_classifier_features(self, proposal_feature_maps, scope=None):
with tf.name_scope(scope):
with tf.name_scope('ResnetV1FPN'):
feature_maps = []
for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
feature_maps.append(proposal_feature_maps['top_down_block{}'.format(level-1)])
self.last_feature_map = proposal_feature_maps['top_down_block{}'.format(
self._base_fpn_max_level - 1)]
# def _extract_box_classifier_features(self, proposal_feature_maps, scope=None):
# with tf.name_scope(scope):
# with tf.name_scope('ResnetV1FPN'):
# feature_maps = []
# for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
# feature_maps.append(proposal_feature_maps['top_down_block{}'.format(level-1)])
# self.last_feature_map = proposal_feature_maps['top_down_block{}'.format(
# self._base_fpn_max_level - 1)]
for coarse_feature_layers in self._coarse_feature_layers:
for layer in coarse_feature_layers:
last_feature_map = layer(last_feature_map)
feature_maps.append(self.last_feature_map)
# for coarse_feature_layers in self._coarse_feature_layers:
# for layer in coarse_feature_layers:
# last_feature_map = layer(last_feature_map)
# feature_maps.append(self.last_feature_map)
return feature_maps
# return feature_maps
def get_box_classifier_feature_extractor_model(self, name=None):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment