Commit 145ac875 authored by syiming's avatar syiming
Browse files

fix coding style

parent 0e0f739b
...@@ -35,15 +35,15 @@ _RESNET_MODEL_OUTPUT_LAYERS = { ...@@ -35,15 +35,15 @@ _RESNET_MODEL_OUTPUT_LAYERS = {
class ResnetFPN(tf.keras.layers.Layer): class ResnetFPN(tf.keras.layers.Layer):
"""Construct Resnet FPN layer.""" """Construct Resnet FPN layer."""
def __init__(self, def __init__(self,
backbone_classifier, backbone_classifier,
fpn_features_generator, fpn_features_generator,
coarse_feature_layers, coarse_feature_layers,
pad_to_multiple, pad_to_multiple,
fpn_min_level, fpn_min_level,
resnet_block_names, resnet_block_names,
base_fpn_max_level): base_fpn_max_level):
"""Constructor. """Constructor.
Args: Args:
backbone_classifier: Classifier backbone. Should be one of 'resnet_v1_50', backbone_classifier: Classifier backbone. Should be one of 'resnet_v1_50',
...@@ -56,41 +56,51 @@ class ResnetFPN(tf.keras.layers.Layer): ...@@ -56,41 +56,51 @@ class ResnetFPN(tf.keras.layers.Layer):
resnet_block_names: a list of block names of resnet. resnet_block_names: a list of block names of resnet.
base_fpn_max_level: maximum level of fpn without coarse feature layers. base_fpn_max_level: maximum level of fpn without coarse feature layers.
""" """
super(ResnetFPN, self).__init__() super(ResnetFPN, self).__init__()
self.classification_backbone = backbone_classifier self.classification_backbone = backbone_classifier
self.fpn_features_generator = fpn_features_generator self.fpn_features_generator = fpn_features_generator
self.coarse_feature_layers = coarse_feature_layers self.coarse_feature_layers = coarse_feature_layers
self.pad_to_multiple = pad_to_multiple self.pad_to_multiple = pad_to_multiple
self._fpn_min_level = fpn_min_level self._fpn_min_level = fpn_min_level
self._resnet_block_names = resnet_block_names self._resnet_block_names = resnet_block_names
self._base_fpn_max_level = base_fpn_max_level self._base_fpn_max_level = base_fpn_max_level
def call(self, inputs): def call(self, inputs):
inputs = ops.pad_to_multiple(inputs, self.pad_to_multiple) """Create ResnetFPN layer.
backbone_outputs = self.classification_backbone(inputs)
Args:
feature_block_list = [] inputs: A [batch, height_out, width_out, channels] float32 tensor
for level in range(self._fpn_min_level, self._base_fpn_max_level + 1): representing a batch of images.
feature_block_list.append('block{}'.format(level - 1))
feature_block_map = dict( Return:
list(zip(self._resnet_block_names, backbone_outputs))) feature_maps: A list of tensors with shape [batch, height, width, depth]
fpn_input_image_features = [ represent extracted features.
(feature_block, feature_block_map[feature_block]) """
for feature_block in feature_block_list] inputs = ops.pad_to_multiple(inputs, self.pad_to_multiple)
fpn_features = self.fpn_features_generator(fpn_input_image_features) backbone_outputs = self.classification_backbone(inputs)
feature_maps = [] feature_block_list = []
for level in range(self._fpn_min_level, self._base_fpn_max_level + 1): for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
feature_maps.append(fpn_features['top_down_block{}'.format(level-1)]) feature_block_list.append('block{}'.format(level - 1))
last_feature_map = fpn_features['top_down_block{}'.format( feature_block_map = dict(
self._base_fpn_max_level - 1)] list(zip(self._resnet_block_names, backbone_outputs)))
fpn_input_image_features = [
for coarse_feature_layers in self.coarse_feature_layers: (feature_block, feature_block_map[feature_block])
for layer in coarse_feature_layers: for feature_block in feature_block_list]
last_feature_map = layer(last_feature_map) fpn_features = self.fpn_features_generator(fpn_input_image_features)
feature_maps.append(last_feature_map)
feature_maps = []
return feature_maps for level in range(self._fpn_min_level, self._base_fpn_max_level + 1):
feature_maps.append(fpn_features['top_down_block{}'.format(level-1)])
last_feature_map = fpn_features['top_down_block{}'.format(
self._base_fpn_max_level - 1)]
for coarse_feature_layers in self.coarse_feature_layers:
for layer in coarse_feature_layers:
last_feature_map = layer(last_feature_map)
feature_maps.append(last_feature_map)
return feature_maps
class FasterRCNNResnetV1FpnKerasFeatureExtractor( class FasterRCNNResnetV1FpnKerasFeatureExtractor(
...@@ -229,7 +239,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor( ...@@ -229,7 +239,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
conv_hyperparams=self._conv_hyperparams, conv_hyperparams=self._conv_hyperparams,
freeze_batchnorm=self._freeze_batchnorm, freeze_batchnorm=self._freeze_batchnorm,
name='FeatureMaps')) name='FeatureMaps'))
# Construct coarse feature layers # Construct coarse feature layers
for i in range(self._base_fpn_max_level, self._fpn_max_level): for i in range(self._base_fpn_max_level, self._fpn_max_level):
layers = [] layers = []
...@@ -250,14 +260,14 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor( ...@@ -250,14 +260,14 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
self._conv_hyperparams.build_activation_layer( self._conv_hyperparams.build_activation_layer(
name=layer_name)) name=layer_name))
self._coarse_feature_layers.append(layers) self._coarse_feature_layers.append(layers)
feature_extractor_model = ResnetFPN(self.classification_backbone, feature_extractor_model = ResnetFPN(self.classification_backbone,
self._fpn_features_generator, self._fpn_features_generator,
self._coarse_feature_layers, self._coarse_feature_layers,
self._pad_to_multiple, self._pad_to_multiple,
self._fpn_min_level, self._fpn_min_level,
self._resnet_block_names, self._resnet_block_names,
self._base_fpn_max_level) self._base_fpn_max_level)
return feature_extractor_model return feature_extractor_model
def get_box_classifier_feature_extractor_model(self, name=None): def get_box_classifier_feature_extractor_model(self, name=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment