Commit 145ac875 authored by syiming's avatar syiming
Browse files

fix coding style

parent 0e0f739b
......@@ -35,15 +35,15 @@ _RESNET_MODEL_OUTPUT_LAYERS = {
class ResnetFPN(tf.keras.layers.Layer):
"""Construct Resnet FPN layer."""
def __init__(self,
backbone_classifier,
fpn_features_generator,
coarse_feature_layers,
pad_to_multiple,
fpn_min_level,
resnet_block_names,
base_fpn_max_level):
"""Constructor.
def __init__(self,
backbone_classifier,
fpn_features_generator,
coarse_feature_layers,
pad_to_multiple,
fpn_min_level,
resnet_block_names,
base_fpn_max_level):
"""Constructor.
Args:
backbone_classifier: Classifier backbone. Should be one of 'resnet_v1_50',
......@@ -56,41 +56,51 @@ class ResnetFPN(tf.keras.layers.Layer):
resnet_block_names: a list of block names of resnet.
base_fpn_max_level: maximum level of fpn without coarse feature layers.
"""
super(ResnetFPN, self).__init__()
self.classification_backbone = backbone_classifier
self.fpn_features_generator = fpn_features_generator
self.coarse_feature_layers = coarse_feature_layers
self.pad_to_multiple = pad_to_multiple
self._fpn_min_level = fpn_min_level
self._resnet_block_names = resnet_block_names
self._base_fpn_max_level = base_fpn_max_level
def call(self, inputs):
inputs = ops.pad_to_multiple(inputs, self.pad_to_multiple)
backbone_outputs = self.classification_backbone(inputs)
feature_block_list = []
for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
feature_block_list.append('block{}'.format(level - 1))
feature_block_map = dict(
list(zip(self._resnet_block_names, backbone_outputs)))
fpn_input_image_features = [
(feature_block, feature_block_map[feature_block])
for feature_block in feature_block_list]
fpn_features = self.fpn_features_generator(fpn_input_image_features)
feature_maps = []
for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
feature_maps.append(fpn_features['top_down_block{}'.format(level-1)])
last_feature_map = fpn_features['top_down_block{}'.format(
self._base_fpn_max_level - 1)]
for coarse_feature_layers in self.coarse_feature_layers:
for layer in coarse_feature_layers:
last_feature_map = layer(last_feature_map)
feature_maps.append(last_feature_map)
return feature_maps
super(ResnetFPN, self).__init__()
self.classification_backbone = backbone_classifier
self.fpn_features_generator = fpn_features_generator
self.coarse_feature_layers = coarse_feature_layers
self.pad_to_multiple = pad_to_multiple
self._fpn_min_level = fpn_min_level
self._resnet_block_names = resnet_block_names
self._base_fpn_max_level = base_fpn_max_level
def call(self, inputs):
"""Create ResnetFPN layer.
Args:
inputs: A [batch, height_out, width_out, channels] float32 tensor
representing a batch of images.
Return:
feature_maps: A list of tensors with shape [batch, height, width, depth]
represent extracted features.
"""
inputs = ops.pad_to_multiple(inputs, self.pad_to_multiple)
backbone_outputs = self.classification_backbone(inputs)
feature_block_list = []
for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
feature_block_list.append('block{}'.format(level - 1))
feature_block_map = dict(
list(zip(self._resnet_block_names, backbone_outputs)))
fpn_input_image_features = [
(feature_block, feature_block_map[feature_block])
for feature_block in feature_block_list]
fpn_features = self.fpn_features_generator(fpn_input_image_features)
feature_maps = []
for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
feature_maps.append(fpn_features['top_down_block{}'.format(level-1)])
last_feature_map = fpn_features['top_down_block{}'.format(
self._base_fpn_max_level - 1)]
for coarse_feature_layers in self.coarse_feature_layers:
for layer in coarse_feature_layers:
last_feature_map = layer(last_feature_map)
feature_maps.append(last_feature_map)
return feature_maps
class FasterRCNNResnetV1FpnKerasFeatureExtractor(
......@@ -229,7 +239,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
conv_hyperparams=self._conv_hyperparams,
freeze_batchnorm=self._freeze_batchnorm,
name='FeatureMaps'))
# Construct coarse feature layers
for i in range(self._base_fpn_max_level, self._fpn_max_level):
layers = []
......@@ -250,14 +260,14 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
self._conv_hyperparams.build_activation_layer(
name=layer_name))
self._coarse_feature_layers.append(layers)
feature_extractor_model = ResnetFPN(self.classification_backbone,
self._fpn_features_generator,
self._coarse_feature_layers,
self._pad_to_multiple,
self._fpn_min_level,
self._resnet_block_names,
self._base_fpn_max_level)
self._fpn_features_generator,
self._coarse_feature_layers,
self._pad_to_multiple,
self._fpn_min_level,
self._resnet_block_names,
self._base_fpn_max_level)
return feature_extractor_model
def get_box_classifier_feature_extractor_model(self, name=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment