ModelZoo / ResNet50_tensorflow

Commit 3237c080 authored Oct 28, 2017 by Vivek Rathod
add NASnet feature extractor
parent c839310b
Showing 5 changed files with 501 additions and 0 deletions.
research/object_detection/builders/model_builder.py                          +2    -0
research/object_detection/builders/model_builder_test.py                     +68   -0
research/object_detection/models/BUILD                                       +23   -0
research/object_detection/models/faster_rcnn_nas_feature_extractor.py        +299  -0
research/object_detection/models/faster_rcnn_nas_feature_extractor_test.py   +109  -0
research/object_detection/builders/model_builder.py

@@ -47,6 +47,8 @@ SSD_FEATURE_EXTRACTOR_CLASS_MAP = {

# A map of names to Faster R-CNN feature extractors.
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
    'faster_rcnn_nas':
    frcnn_nas.FasterRCNNNASFeatureExtractor,
    'faster_rcnn_inception_resnet_v2':
    frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
    'faster_rcnn_inception_v2':
    ...
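For orientation, here is a minimal sketch of how the builder presumably consumes this map when a config requests the new 'faster_rcnn_nas' type: the registered string key is looked up and the resulting class is instantiated with the config's stride. The helper name and exact signature below are assumptions for illustration, not part of this commit.

# Hypothetical sketch (not from this diff): resolving a Faster R-CNN feature
# extractor by its registered name in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP.
def _build_faster_rcnn_feature_extractor(feature_extractor_config, is_training,
                                         reuse_weights=None):
  feature_type = feature_extractor_config.type  # e.g. 'faster_rcnn_nas'
  if feature_type not in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
        feature_type))
  feature_extractor_class = FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP[
      feature_type]
  return feature_extractor_class(
      is_training,
      feature_extractor_config.first_stage_features_stride,
      reuse_weights=reuse_weights)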
research/object_detection/builders/model_builder_test.py

@@ -24,6 +24,7 @@ from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
...

@@ -412,6 +413,73 @@ class ModelBuilderTest(tf.test.TestCase):
    model = model_builder.build(model_proto, is_training=True)
    self.assertAlmostEqual(model._second_stage_mask_loss_weight, 3.0)

  def test_create_faster_rcnn_nas_model_from_config(self):
    model_text_proto = """
      faster_rcnn {
        num_classes: 3
        image_resizer {
          keep_aspect_ratio_resizer {
            min_dimension: 600
            max_dimension: 1024
          }
        }
        feature_extractor {
          type: 'faster_rcnn_nas'
        }
        first_stage_anchor_generator {
          grid_anchor_generator {
            scales: [0.25, 0.5, 1.0, 2.0]
            aspect_ratios: [0.5, 1.0, 2.0]
            height_stride: 16
            width_stride: 16
          }
        }
        first_stage_box_predictor_conv_hyperparams {
          regularizer {
            l2_regularizer {
            }
          }
          initializer {
            truncated_normal_initializer {
            }
          }
        }
        initial_crop_size: 17
        maxpool_kernel_size: 1
        maxpool_stride: 1
        second_stage_box_predictor {
          mask_rcnn_box_predictor {
            fc_hyperparams {
              op: FC
              regularizer {
                l2_regularizer {
                }
              }
              initializer {
                truncated_normal_initializer {
                }
              }
            }
          }
        }
        second_stage_post_processing {
          batch_non_max_suppression {
            score_threshold: 0.01
            iou_threshold: 0.6
            max_detections_per_class: 100
            max_total_detections: 300
          }
          score_converter: SOFTMAX
        }
      }"""
    model_proto = model_pb2.DetectionModel()
    text_format.Merge(model_text_proto, model_proto)
    model = model_builder.build(model_proto, is_training=True)
    self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
    self.assertIsInstance(model._feature_extractor,
                          frcnn_nas.FasterRCNNNASFeatureExtractor)

  def test_create_faster_rcnn_inception_resnet_v2_model_from_config(self):
    model_text_proto = """
      faster_rcnn {
    ...
research/object_detection/models/BUILD

@@ -135,6 +135,29 @@ py_test(
    ],
)

py_test(
    name = "faster_rcnn_nas_feature_extractor_test",
    srcs = [
        "faster_rcnn_nas_feature_extractor_test.py",
    ],
    deps = [
        ":faster_rcnn_nas_feature_extractor",
        "//tensorflow",
    ],
)

py_library(
    name = "faster_rcnn_nas_feature_extractor",
    srcs = [
        "faster_rcnn_nas_feature_extractor.py",
    ],
    deps = [
        "//tensorflow",
        "//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch",
        "//tensorflow_models/slim:nasnet",
    ],
)

py_library(
    name = "faster_rcnn_inception_resnet_v2_feature_extractor",
    srcs = [
    ...
research/object_detection/models/faster_rcnn_nas_feature_extractor.py (new file, mode 100644)

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""NASNet Faster R-CNN implementation.

Learning Transferable Architectures for Scalable Image Recognition
Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le
https://arxiv.org/abs/1707.07012
"""

import tensorflow as tf

from object_detection.meta_architectures import faster_rcnn_meta_arch
from nets.nasnet import nasnet
from nets.nasnet import nasnet_utils

arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim


# Note: This is largely a copy of _build_nasnet_base inside nasnet.py but
# with special edits to remove instantiation of the stem and the special
# ability to receive as input a pair of hidden states.
def _build_nasnet_base(hidden_previous,
                       hidden,
                       normal_cell,
                       reduction_cell,
                       hparams,
                       true_cell_num,
                       start_cell_num):
  """Constructs a NASNet image model."""

  # Find where to place the reduction cells or stride normal cells.
  reduction_indices = nasnet_utils.calc_reduction_layers(
      hparams.num_cells, hparams.num_reduction_layers)

  # Note: The None is prepended to match the behavior of _imagenet_stem().
  cell_outputs = [None, hidden_previous, hidden]
  net = hidden

  # NOTE: In the nasnet.py code, filter_scaling starts at 1.0. We instead
  # start at 2.0 because 1 reduction cell has been created which would
  # update the filter_scaling to 2.0.
  filter_scaling = 2.0

  # Run the cells.
  for cell_num in range(start_cell_num, hparams.num_cells):
    stride = 1
    if hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    if cell_num in reduction_indices:
      filter_scaling *= hparams.filter_scaling_rate
      net = reduction_cell(
          net,
          scope='reduction_cell_{}'.format(reduction_indices.index(cell_num)),
          filter_scaling=filter_scaling,
          stride=2,
          prev_layer=cell_outputs[-2],
          cell_num=true_cell_num)
      true_cell_num += 1
      cell_outputs.append(net)
    if not hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    net = normal_cell(
        net,
        scope='cell_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=stride,
        prev_layer=prev_layer,
        cell_num=true_cell_num)
    true_cell_num += 1
    cell_outputs.append(net)

  # Final nonlinearity.
  # Note that we have dropped the final pooling, dropout and softmax layers
  # from the default nasnet version.
  with tf.variable_scope('final_layer'):
    net = tf.nn.relu(net)
  return net


# TODO: Only fixed_shape_resizer is currently supported for NASNet
# featurization. The reason for this is that nasnet.py only supports
# inputs with fully known shapes. We need to update nasnet.py to handle
# shapes not known at compile time.
class FasterRCNNNASFeatureExtractor(
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
  """Faster R-CNN with NASNet-A feature extractor implementation."""

  def __init__(self,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0):
    """Constructor.

    Args:
      is_training: See base class.
      first_stage_features_stride: See base class.
      batch_norm_trainable: See base class.
      reuse_weights: See base class.
      weight_decay: See base class.

    Raises:
      ValueError: If `first_stage_features_stride` is not 16.
    """
    if first_stage_features_stride != 16:
      raise ValueError('`first_stage_features_stride` must be 16.')
    super(FasterRCNNNASFeatureExtractor, self).__init__(
        is_training, first_stage_features_stride, batch_norm_trainable,
        reuse_weights, weight_decay)

  def preprocess(self, resized_inputs):
    """Faster R-CNN with NAS preprocessing.

    Maps pixel values to the range [-1, 1].

    Args:
      resized_inputs: A [batch, height_in, width_in, channels] float32 tensor
        representing a batch of images with values between 0 and 255.0.

    Returns:
      preprocessed_inputs: A [batch, height_out, width_out, channels] float32
        tensor representing a batch of images.
    """
    return (2.0 / 255.0) * resized_inputs - 1.0

  def _extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features.

    Extracts features using the first half of the NASNet network.
    We construct the network in `align_feature_maps=True` mode, which means
    that all VALID paddings in the network are changed to SAME padding so that
    the feature maps are aligned.

    Args:
      preprocessed_inputs: A [batch, height, width, channels] float32 tensor
        representing a batch of images.
      scope: A scope name.

    Returns:
      rpn_feature_map: A tensor with shape [batch, height, width, depth]

    Raises:
      ValueError: If the created network is missing the required activation.
    """
    del scope

    if len(preprocessed_inputs.get_shape().as_list()) != 4:
      raise ValueError('`preprocessed_inputs` must be 4 dimensional, got a '
                       'tensor of shape %s' % preprocessed_inputs.get_shape())

    with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
      _, end_points = nasnet.build_nasnet_large(
          preprocessed_inputs, num_classes=None,
          is_training=self._is_training,
          is_batchnorm_training=self._train_batch_norm,
          final_endpoint='Cell_11')

    # Note that both 'Cell_10' and 'Cell_11' have equal depth = 2016.
    rpn_feature_map = tf.concat([end_points['Cell_10'],
                                 end_points['Cell_11']], 3)

    # nasnet.py does not maintain the batch size in the first dimension.
    # This workaround lets us retain the batch size for use below.
    batch = preprocessed_inputs.get_shape().as_list()[0]
    shape_without_batch = rpn_feature_map.get_shape().as_list()[1:]
    rpn_feature_map_shape = [batch] + shape_without_batch
    rpn_feature_map.set_shape(rpn_feature_map_shape)

    return rpn_feature_map

  def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features.

    This function reconstructs the "second half" of the NASNet-A
    network after the part defined in `_extract_proposal_features`.

    Args:
      proposal_feature_maps: A 4-D float tensor with shape
        [batch_size * self.max_num_proposals, crop_height, crop_width, depth]
        representing the feature map cropped to each proposal.
      scope: A scope name.

    Returns:
      proposal_classifier_features: A 4-D float tensor with shape
        [batch_size * self.max_num_proposals, height, width, depth]
        representing box classifier features for each proposal.
    """
    del scope

    # Note that we always feed into 2 layers of equal depth
    # where the first N channels corresponds to previous hidden layer
    # and the second N channels correspond to the final hidden layer.
    hidden_previous, hidden = tf.split(proposal_feature_maps, 2, axis=3)

    # Note that what follows is largely a copy of build_nasnet_large() within
    # nasnet.py. We are copying to minimize code pollution in slim.

    # pylint: disable=protected-access
    hparams = nasnet._large_imagenet_config(is_training=self._is_training)
    # pylint: enable=protected-access

    # Calculate the total number of cells in the network
    # -- Add 2 for the reduction cells.
    total_num_cells = hparams.num_cells + 2
    # -- And add 2 for the stem cells for ImageNet training.
    total_num_cells += 2

    normal_cell = nasnet_utils.NasNetANormalCell(
        hparams.num_conv_filters, hparams.drop_path_keep_prob,
        total_num_cells, hparams.total_training_steps)
    reduction_cell = nasnet_utils.NasNetAReductionCell(
        hparams.num_conv_filters, hparams.drop_path_keep_prob,
        total_num_cells, hparams.total_training_steps)
    with arg_scope([slim.dropout, nasnet_utils.drop_path],
                   is_training=self._is_training):
      with arg_scope([slim.batch_norm], is_training=self._train_batch_norm):
        with arg_scope([slim.avg_pool2d,
                        slim.max_pool2d,
                        slim.conv2d,
                        slim.batch_norm,
                        slim.separable_conv2d,
                        nasnet_utils.factorized_reduction,
                        nasnet_utils.global_avg_pool,
                        nasnet_utils.get_channel_index,
                        nasnet_utils.get_channel_dim],
                       data_format=hparams.data_format):

          # This corresponds to the cell number just past 'Cell_11' used by
          # _extract_proposal_features().
          start_cell_num = 12
          # Note that this number equals:
          #  start_cell_num + 2 stem cells + 1 reduction cell
          true_cell_num = 15

          with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
            net = _build_nasnet_base(hidden_previous,
                                     hidden,
                                     normal_cell=normal_cell,
                                     reduction_cell=reduction_cell,
                                     hparams=hparams,
                                     true_cell_num=true_cell_num,
                                     start_cell_num=start_cell_num)

    proposal_classifier_features = net
    return proposal_classifier_features

  def restore_from_classification_checkpoint_fn(
      self,
      first_stage_feature_extractor_scope,
      second_stage_feature_extractor_scope):
    """Returns a map of variables to load from a foreign checkpoint.

    Note that this overrides the default implementation in
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
    NASNet-A checkpoints.

    Args:
      first_stage_feature_extractor_scope: A scope name for the first stage
        feature extractor.
      second_stage_feature_extractor_scope: A scope name for the second stage
        feature extractor.

    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables in
      the model graph.
    """
    # Note that the NAS checkpoint only contains the moving average version of
    # the Variables so we need to generate an appropriate dictionary mapping.
    variables_to_restore = {}
    for variable in tf.global_variables():
      if variable.op.name.startswith(
          first_stage_feature_extractor_scope):
        var_name = variable.op.name.replace(
            first_stage_feature_extractor_scope + '/', '')
        var_name += '/ExponentialMovingAverage'
        variables_to_restore[var_name] = variable
      if variable.op.name.startswith(
          second_stage_feature_extractor_scope):
        var_name = variable.op.name.replace(
            second_stage_feature_extractor_scope + '/', '')
        var_name += '/ExponentialMovingAverage'
        variables_to_restore[var_name] = variable
    return variables_to_restore
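As a usage note, the dict returned by restore_from_classification_checkpoint_fn maps checkpoint variable names (the ExponentialMovingAverage shadows stored in a NASNet classification checkpoint) to variables in the detection graph, which is the form a TF 1.x Saver accepts. A minimal sketch follows; it assumes the Faster R-CNN graph has already been built (so tf.global_variables() is populated), and the scope names and checkpoint path are placeholders, not values defined in this commit.

# Hypothetical sketch: initializing the NAS feature extractor from a
# NASNet-A classification checkpoint via the returned name-to-variable map.
# Assumes `feature_extractor` is a FasterRCNNNASFeatureExtractor and the
# detection graph has already been constructed under the two scopes.
variables_to_restore = feature_extractor.restore_from_classification_checkpoint_fn(
    'FirstStageFeatureExtractor',   # placeholder first-stage scope name
    'SecondStageFeatureExtractor')  # placeholder second-stage scope name
init_saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
  init_saver.restore(sess, '/path/to/nasnet_large_classification.ckpt')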
research/object_detection/models/faster_rcnn_nas_feature_extractor_test.py (new file, mode 100644)

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for models.faster_rcnn_nas_feature_extractor."""

import tensorflow as tf

from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas


class FasterRcnnNASFeatureExtractorTest(tf.test.TestCase):

  def _build_feature_extractor(self, first_stage_features_stride):
    return frcnn_nas.FasterRCNNNASFeatureExtractor(
        is_training=False,
        first_stage_features_stride=first_stage_features_stride,
        batch_norm_trainable=False,
        reuse_weights=None,
        weight_decay=0.0)

  def test_extract_proposal_features_returns_expected_size(self):
    feature_extractor = self._build_feature_extractor(
        first_stage_features_stride=16)
    preprocessed_inputs = tf.random_uniform(
        [1, 299, 299, 3], maxval=255, dtype=tf.float32)
    rpn_feature_map = feature_extractor.extract_proposal_features(
        preprocessed_inputs, scope='TestScope')
    features_shape = tf.shape(rpn_feature_map)

    init_op = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init_op)
      features_shape_out = sess.run(features_shape)
      self.assertAllEqual(features_shape_out, [1, 19, 19, 4032])

  def test_extract_proposal_features_input_size_224(self):
    feature_extractor = self._build_feature_extractor(
        first_stage_features_stride=16)
    preprocessed_inputs = tf.random_uniform(
        [1, 224, 224, 3], maxval=255, dtype=tf.float32)
    rpn_feature_map = feature_extractor.extract_proposal_features(
        preprocessed_inputs, scope='TestScope')
    features_shape = tf.shape(rpn_feature_map)

    init_op = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init_op)
      features_shape_out = sess.run(features_shape)
      self.assertAllEqual(features_shape_out, [1, 14, 14, 4032])

  def test_extract_proposal_features_input_size_112(self):
    feature_extractor = self._build_feature_extractor(
        first_stage_features_stride=16)
    preprocessed_inputs = tf.random_uniform(
        [1, 112, 112, 3], maxval=255, dtype=tf.float32)
    rpn_feature_map = feature_extractor.extract_proposal_features(
        preprocessed_inputs, scope='TestScope')
    features_shape = tf.shape(rpn_feature_map)

    init_op = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init_op)
      features_shape_out = sess.run(features_shape)
      self.assertAllEqual(features_shape_out, [1, 7, 7, 4032])

  def test_extract_proposal_features_dies_on_invalid_stride(self):
    with self.assertRaises(ValueError):
      self._build_feature_extractor(first_stage_features_stride=99)

  def test_extract_proposal_features_dies_with_incorrect_rank_inputs(self):
    feature_extractor = self._build_feature_extractor(
        first_stage_features_stride=16)
    preprocessed_inputs = tf.random_uniform(
        [224, 224, 3], maxval=255, dtype=tf.float32)
    with self.assertRaises(ValueError):
      feature_extractor.extract_proposal_features(
          preprocessed_inputs, scope='TestScope')

  def test_extract_box_classifier_features_returns_expected_size(self):
    feature_extractor = self._build_feature_extractor(
        first_stage_features_stride=16)
    proposal_feature_maps = tf.random_uniform(
        [2, 17, 17, 1088], maxval=255, dtype=tf.float32)
    proposal_classifier_features = (
        feature_extractor.extract_box_classifier_features(
            proposal_feature_maps, scope='TestScope'))
    features_shape = tf.shape(proposal_classifier_features)

    init_op = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init_op)
      features_shape_out = sess.run(features_shape)
      self.assertAllEqual(features_shape_out, [2, 9, 9, 4032])


if __name__ == '__main__':
  tf.test.main()
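Finally, a tiny illustrative check of the preprocess() mapping documented in the extractor, i.e. (2/255) * x - 1 sends 0 to -1, 127.5 to 0, and 255 to +1. This is a sketch in the same TF 1.x style as the tests, not part of the commit; the toy input shape and values are chosen only to show the element-wise normalization.

# Illustrative check of the [-1, 1] normalization used by preprocess().
import tensorflow as tf
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas

extractor = frcnn_nas.FasterRCNNNASFeatureExtractor(
    is_training=False, first_stage_features_stride=16)
images = tf.constant([[[[0.0], [127.5], [255.0]]]])  # toy tensor, shape [1, 1, 3, 1]
normalized = extractor.preprocess(images)
with tf.Session() as sess:
  print(sess.run(normalized))  # values approximately -1.0, 0.0, 1.0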