Commit 9733eeb0 authored by Soroosh Yazdani's avatar Soroosh Yazdani Committed by TF Object Detection Team
Browse files

Updating center_net_mobilenet_v2_fpn_feature_extractor to support classification finetuning.

PiperOrigin-RevId: 351619683
parent 63ec7359
...@@ -58,18 +58,18 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -58,18 +58,18 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel_means=channel_means, channel_means=channel_means,
channel_stds=channel_stds, channel_stds=channel_stds,
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
self._network = mobilenet_v2_net self._base_model = mobilenet_v2_net
output = self._network(self._network.input) output = self._base_model(self._base_model.input)
# Add pyramid feature network on every layer that has stride 2. # Add pyramid feature network on every layer that has stride 2.
skip_outputs = [ skip_outputs = [
self._network.get_layer(skip_layer_name).output self._base_model.get_layer(skip_layer_name).output
for skip_layer_name in _MOBILENET_V2_FPN_SKIP_LAYERS for skip_layer_name in _MOBILENET_V2_FPN_SKIP_LAYERS
] ]
self._fpn_model = tf.keras.models.Model( self._fpn_model = tf.keras.models.Model(
inputs=self._network.input, outputs=skip_outputs) inputs=self._base_model.input, outputs=skip_outputs)
fpn_outputs = self._fpn_model(self._network.input) fpn_outputs = self._fpn_model(self._base_model.input)
# Construct the top-down feature maps -- we start with an output of # Construct the top-down feature maps -- we start with an output of
# 7x7x1280, which we continually upsample, apply a residual on and merge. # 7x7x1280, which we continually upsample, apply a residual on and merge.
...@@ -108,8 +108,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -108,8 +108,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
output = top_down output = top_down
self._network = tf.keras.models.Model( self._feature_extractor_model = tf.keras.models.Model(
inputs=self._network.input, outputs=output) inputs=self._base_model.input, outputs=output)
def preprocess(self, resized_inputs): def preprocess(self, resized_inputs):
resized_inputs = super(CenterNetMobileNetV2FPNFeatureExtractor, resized_inputs = super(CenterNetMobileNetV2FPNFeatureExtractor,
...@@ -117,13 +117,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -117,13 +117,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
return tf.keras.applications.mobilenet_v2.preprocess_input(resized_inputs) return tf.keras.applications.mobilenet_v2.preprocess_input(resized_inputs)
def load_feature_extractor_weights(self, path): def load_feature_extractor_weights(self, path):
self._network.load_weights(path) self._base_model.load_weights(path)
def get_base_model(self): @property
return self._network def supported_sub_model_types(self):
return ['classification']
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def call(self, inputs): def call(self, inputs):
return [self._network(inputs)] return [self._feature_extractor_model(inputs)]
@property @property
def out_stride(self): def out_stride(self):
...@@ -135,9 +142,6 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -135,9 +142,6 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
"""The number of feature outputs returned by the feature extractor.""" """The number of feature outputs returned by the feature extractor."""
return 1 return 1
def get_model(self):
return self._network
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering): def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering):
"""The MobileNetV2+FPN backbone for CenterNet.""" """The MobileNetV2+FPN backbone for CenterNet."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment