"docs/code-docs/source/memory.rst" did not exist on "ab5534fc4c0f8ca21ada321f9730d723aa31288b"
Commit 9c314a03 authored by A. Unique TensorFlower, committed by TF Object Detection Team

Internal change

PiperOrigin-RevId: 321825286
parent ec2d5d8d
@@ -48,6 +48,7 @@ from object_detection.utils import tf_version
 # pylint: disable=g-import-not-at-top
 if tf_version.is_tf2():
   from object_detection.models import center_net_hourglass_feature_extractor
+  from object_detection.models import center_net_mobilenet_v2_feature_extractor
   from object_detection.models import center_net_resnet_feature_extractor
   from object_detection.models import center_net_resnet_v1_fpn_feature_extractor
   from object_detection.models import faster_rcnn_inception_resnet_v2_keras_feature_extractor as frcnn_inc_res_keras
@@ -148,7 +149,10 @@ if tf_version.is_tf2():
           center_net_resnet_v1_fpn_feature_extractor.resnet_v1_50_fpn,
       'resnet_v1_101_fpn':
           center_net_resnet_v1_fpn_feature_extractor.resnet_v1_101_fpn,
-      'hourglass_104': center_net_hourglass_feature_extractor.hourglass_104,
+      'hourglass_104':
+          center_net_hourglass_feature_extractor.hourglass_104,
+      'mobilenet_v2':
+          center_net_mobilenet_v2_feature_extractor.mobilenet_v2,
   }
   FEATURE_EXTRACTOR_MAPS = [
...
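The hunk above registers the new backbone under the string 'mobilenet_v2' in the builder's CenterNet feature-extractor map (the file name is not shown in this view, but the imports and the FEATURE_EXTRACTOR_MAPS list match object_detection/builders/model_builder.py). As a rough illustration of how that registration is consumed, here is a minimal sketch; the helper name below is hypothetical, and it assumes the map shown in the hunk is exposed as model_builder.CENTER_NET_EXTRACTOR_FUNCTION_MAP (its name is collapsed out of the excerpt):

from object_detection.builders import model_builder


def build_center_net_backbone(name, channel_means, channel_stds, bgr_ordering):
  """Hypothetical helper: resolve a config string to a feature extractor."""
  # Assumption: the dict edited in this commit is CENTER_NET_EXTRACTOR_FUNCTION_MAP,
  # mapping config strings to factory functions such as
  # center_net_mobilenet_v2_feature_extractor.mobilenet_v2.
  if name not in model_builder.CENTER_NET_EXTRACTOR_FUNCTION_MAP:
    raise ValueError('Unknown CenterNet feature extractor: %s' % name)
  factory = model_builder.CENTER_NET_EXTRACTOR_FUNCTION_MAP[name]
  return factory(channel_means=channel_means,
                 channel_stds=channel_stds,
                 bgr_ordering=bgr_ordering)


# After this change, 'mobilenet_v2' is a valid backbone choice.
extractor = build_center_net_backbone(
    'mobilenet_v2', channel_means=(0., 0., 0.),
    channel_stds=(1., 1., 1.), bgr_ordering=False)

The two new files follow: the feature extractor module itself and its TF2-only unit test.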
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MobileNet V2[1] feature extractor for CenterNet[2] meta architecture.
[1]: https://arxiv.org/abs/1801.04381
[2]: https://arxiv.org/abs/1904.07850
"""
import tensorflow.compat.v1 as tf
from object_detection.meta_architectures import center_net_meta_arch
from object_detection.models.keras_models import mobilenet_v2 as mobilenetv2
class CenterNetMobileNetV2FeatureExtractor(
center_net_meta_arch.CenterNetFeatureExtractor):
"""The MobileNet V2 feature extractor for CenterNet."""
def __init__(self,
mobilenet_v2_net,
channel_means=(0., 0., 0.),
channel_stds=(1., 1., 1.),
bgr_ordering=False):
"""Intializes the feature extractor.
Args:
mobilenet_v2_net: The underlying mobilenet_v2 network to use.
channel_means: A tuple of floats, denoting the mean of each channel
which will be subtracted from it.
channel_stds: A tuple of floats, denoting the standard deviation of each
channel. Each channel will be divided by its standard deviation value.
      bgr_ordering: bool, if set, will change the channel ordering to be in the
        [blue, green, red] order.
"""
super(CenterNetMobileNetV2FeatureExtractor, self).__init__(
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering)
self._network = mobilenet_v2_net
output = self._network(self._network.input)
# TODO(nkhadke): Try out MobileNet+FPN next (skip connections are cheap and
# should help with performance).
# MobileNet by itself transforms a 224x224x3 volume into a 7x7x1280, which
# leads to a stride of 32. We perform upsampling to get it to a target
# stride of 4.
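    # Each Conv2DTranspose below upsamples by 2x, so the three blocks take the
    # stride from 32 down to 4 (e.g. a 7x7 map becomes 56x56 for a 224x224
    # input).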
for num_filters in [256, 128, 64]:
# 1. We use a simple convolution instead of a deformable convolution
conv = tf.keras.layers.Conv2D(
filters=num_filters, kernel_size=1, strides=1, padding='same')
output = conv(output)
output = tf.keras.layers.BatchNormalization()(output)
output = tf.keras.layers.ReLU()(output)
# 2. We use the default initialization for the convolution layers
# instead of initializing it to do bilinear upsampling.
conv_transpose = tf.keras.layers.Conv2DTranspose(
filters=num_filters, kernel_size=3, strides=2, padding='same')
output = conv_transpose(output)
output = tf.keras.layers.BatchNormalization()(output)
output = tf.keras.layers.ReLU()(output)
self._network = tf.keras.models.Model(
        inputs=self._network.input, outputs=output)

  def preprocess(self, resized_inputs):
resized_inputs = super(CenterNetMobileNetV2FeatureExtractor,
self).preprocess(resized_inputs)
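    # mobilenet_v2.preprocess_input scales pixel values to the [-1, 1] range
    # expected by the MobileNetV2 backbone.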
    return tf.keras.applications.mobilenet_v2.preprocess_input(resized_inputs)

  def load_feature_extractor_weights(self, path):
    self._network.load_weights(path)

  def get_base_model(self):
    return self._network

  def call(self, inputs):
    return [self._network(inputs)]

  @property
  def out_stride(self):
    """The stride in the output image of the network."""
    return 4

  @property
  def num_feature_outputs(self):
    """The number of feature outputs returned by the feature extractor."""
    return 1

  def get_model(self):
    return self._network


def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
"""The MobileNetV2 backbone for CenterNet."""
# We set 'is_training' to True for now.
network = mobilenetv2.mobilenet_v2(True, include_top=False)
return CenterNetMobileNetV2FeatureExtractor(
network,
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering)
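As a quick sanity check of the module above, here is a minimal usage sketch (not part of the commit): the batch size of 2 and the 224x224 input size are arbitrary choices, and the expected (2, 56, 56, 64) output shape follows from the stride-4 output and the final 64-filter transpose convolution.

import numpy as np

from object_detection.models import center_net_mobilenet_v2_feature_extractor

# Build the extractor through the same factory that model_builder registers.
extractor = center_net_mobilenet_v2_feature_extractor.mobilenet_v2(
    channel_means=(0., 0., 0.),
    channel_stds=(1., 1., 1.),
    bgr_ordering=False)

# A dummy batch of two 224x224 RGB images with pixel values in [0, 255].
images = np.random.uniform(0., 255., size=(2, 224, 224, 3)).astype(np.float32)

# preprocess() normalizes pixels; calling the extractor returns a list holding
# a single stride-4 feature map, i.e. shape (2, 56, 56, 64) for this input.
features = extractor(extractor.preprocess(images))
print(features[0].shape)

The test below performs the same check with a zeros input and asserts the (8, 56, 56, 64) output shape.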
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Testing mobilenet_v2 feature extractor for CenterNet."""
import unittest

import numpy as np
import tensorflow.compat.v1 as tf

from object_detection.models import center_net_mobilenet_v2_feature_extractor
from object_detection.models.keras_models import mobilenet_v2
from object_detection.utils import test_case
from object_detection.utils import tf_version


@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class CenterNetMobileNetV2FeatureExtractorTest(test_case.TestCase):

  def test_center_net_mobilenet_v2_feature_extractor(self):
net = mobilenet_v2.mobilenet_v2(True, include_top=False)
model = center_net_mobilenet_v2_feature_extractor.CenterNetMobileNetV2FeatureExtractor(
        net)

    def graph_fn():
      img = np.zeros((8, 224, 224, 3), dtype=np.float32)
      processed_img = model.preprocess(img)
      return model(processed_img)

    outputs = self.execute(graph_fn, [])
    self.assertEqual(outputs.shape, (8, 56, 56, 64))


if __name__ == '__main__':
tf.test.main()