ModelZoo / ResNet50_tensorflow · Commits · d2c5bfac

Commit d2c5bfac, authored Mar 27, 2018 by Zhichao Lu, committed Apr 02, 2018 by pkulzc
Parent: 3956d90e

Provide option to perform in-place batch norm updates for ssd feature extractors.

PiperOrigin-RevId: 190688309
Showing 11 changed files with 150 additions and 32 deletions (+150 -32)

  research/object_detection/builders/model_builder.py                              +25  -6
  research/object_detection/builders/model_builder_test.py                          +2  -0
  research/object_detection/meta_architectures/ssd_meta_arch.py                    +30  -2
  research/object_detection/models/embedded_ssd_mobilenet_v1_feature_extractor.py   +9  -3
  research/object_detection/models/ssd_inception_v2_feature_extractor.py            +9  -3
  research/object_detection/models/ssd_inception_v3_feature_extractor.py            +9  -3
  research/object_detection/models/ssd_mobilenet_v1_feature_extractor.py            +9  -3
  research/object_detection/models/ssd_mobilenet_v2_feature_extractor.py            +9  -3
  research/object_detection/models/ssd_resnet_v1_fpn_feature_extractor.py          +36  -9
  research/object_detection/protos/faster_rcnn.proto                                +6  -0
  research/object_detection/protos/ssd.proto                                        +6  -0
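For orientation before the diffs: several of the docstrings added below say that when inplace_batchnorm_update is left at its default of false, the training code must gate its train op on the tf.GraphKeys.UPDATE_OPS collection so the batch norm moving averages get refreshed. A minimal TF 1.x sketch of that pattern (generic illustration, not code from this commit; the loss and optimizer are placeholders):

    import tensorflow as tf
    slim = tf.contrib.slim

    images = tf.random_normal([8, 32, 32, 3])
    net = slim.batch_norm(images, is_training=True)  # registers moving-average
                                                     # update ops in UPDATE_OPS
    loss = tf.reduce_mean(tf.square(net))
    optimizer = tf.train.GradientDescentOptimizer(0.1)

    # Without in-place updates, the train op must depend on UPDATE_OPS; otherwise
    # the batch norm moving averages are never updated during training.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss)

Setting inplace_batchnorm_update: true removes the need for this control dependency, which is what makes batch norm usable on TPUs.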
research/object_detection/builders/model_builder.py

@@ -95,13 +95,19 @@ def build(model_config, is_training, add_summaries=True):
 def _build_ssd_feature_extractor(feature_extractor_config, is_training,
-                                 reuse_weights=None):
+                                 reuse_weights=None,
+                                 inplace_batchnorm_update=False):
   """Builds a ssd_meta_arch.SSDFeatureExtractor based on config.
 
   Args:
     feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto.
     is_training: True if this feature extractor is being built for training.
     reuse_weights: if the feature extractor should reuse weights.
+    inplace_batchnorm_update: Whether to update batch_norm inplace during
+      training. This is required for batch norm to work correctly on TPUs. When
+      this is false, user must add a control dependency on
+      tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+      norm moving average parameters.
 
   Returns:
     ssd_meta_arch.SSDFeatureExtractor based on config.

@@ -126,7 +132,8 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
   return feature_extractor_class(is_training, depth_multiplier, min_depth,
                                  pad_to_multiple, conv_hyperparams,
                                  batch_norm_trainable, reuse_weights,
-                                 use_explicit_padding, use_depthwise)
+                                 use_explicit_padding, use_depthwise,
+                                 inplace_batchnorm_update)
 
 
 def _build_ssd_model(ssd_config, is_training, add_summaries):

@@ -140,6 +147,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
   Returns:
     SSDMetaArch based on the config.
+
   Raises:
     ValueError: If ssd_config.type is not recognized (i.e. not registered in
       model_class_map).

@@ -147,8 +155,9 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
   num_classes = ssd_config.num_classes
 
   # Feature extractor
-  feature_extractor = _build_ssd_feature_extractor(ssd_config.feature_extractor,
-                                                   is_training)
+  feature_extractor = _build_ssd_feature_extractor(
+      ssd_config.feature_extractor, is_training,
+      ssd_config.inplace_batchnorm_update)
 
   box_coder = box_coder_builder.build(ssd_config.box_coder)
   matcher = matcher_builder.build(ssd_config.matcher)

@@ -194,7 +203,8 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
 def _build_faster_rcnn_feature_extractor(
-    feature_extractor_config, is_training, reuse_weights=None):
+    feature_extractor_config, is_training, reuse_weights=None,
+    inplace_batchnorm_update=False):
   """Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.
 
   Args:

@@ -202,6 +212,11 @@ def _build_faster_rcnn_feature_extractor(
       faster_rcnn.proto.
     is_training: True if this feature extractor is being built for training.
     reuse_weights: if the feature extractor should reuse weights.
+    inplace_batchnorm_update: Whether to update batch_norm inplace during
+      training. This is required for batch norm to work correctly on TPUs. When
+      this is false, user must add a control dependency on
+      tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+      norm moving average parameters.
 
   Returns:
     faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

@@ -209,6 +224,8 @@ def _build_faster_rcnn_feature_extractor(
   Raises:
     ValueError: On invalid feature extractor type.
   """
+  if inplace_batchnorm_update:
+    raise ValueError('inplace batchnorm updates not supported.')
   feature_type = feature_extractor_config.type
   first_stage_features_stride = (
       feature_extractor_config.first_stage_features_stride)

@@ -238,6 +255,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
   Returns:
     FasterRCNNMetaArch based on the config.
+
   Raises:
     ValueError: If frcnn_config.type is not recognized (i.e. not registered in
       model_class_map).

@@ -246,7 +264,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
   image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)
 
   feature_extractor = _build_faster_rcnn_feature_extractor(
-      frcnn_config.feature_extractor, is_training)
+      frcnn_config.feature_extractor, is_training,
+      frcnn_config.inplace_batchnorm_update)
 
   number_of_stages = frcnn_config.number_of_stages
   first_stage_anchor_generator = anchor_generator_builder.build(
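A small sketch of how the new field is consumed by the builder changes above (illustrative only; the generated module name ssd_pb2 is the standard protoc output for ssd.proto, and only the new field is set here):

    from google.protobuf import text_format
    from object_detection.protos import ssd_pb2

    ssd_config = ssd_pb2.Ssd()
    text_format.Merge('inplace_batchnorm_update: true', ssd_config)

    # _build_ssd_model (above) now forwards this flag to the feature extractor:
    #   _build_ssd_feature_extractor(ssd_config.feature_extractor, is_training,
    #                                ssd_config.inplace_batchnorm_update)
    print(ssd_config.inplace_batchnorm_update)  # True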
research/object_detection/builders/model_builder_test.py

@@ -297,6 +297,7 @@ class ModelBuilderTest(tf.test.TestCase):
   def test_create_ssd_mobilenet_v1_model_from_config(self):
     model_text_proto = """
       ssd {
+        inplace_batchnorm_update: true
         feature_extractor {
           type: 'ssd_mobilenet_v1'
           conv_hyperparams {

@@ -519,6 +520,7 @@ class ModelBuilderTest(tf.test.TestCase):
   def test_create_faster_rcnn_resnet_v1_models_from_config(self):
     model_text_proto = """
       faster_rcnn {
+        inplace_batchnorm_update: true
         num_classes: 3
         image_resizer {
           keep_aspect_ratio_resizer {
research/object_detection/meta_architectures/ssd_meta_arch.py

@@ -46,7 +46,8 @@ class SSDFeatureExtractor(object):
                batch_norm_trainable=True,
                reuse_weights=None,
                use_explicit_padding=False,
-               use_depthwise=False):
+               use_depthwise=False,
+               inplace_batchnorm_update=False):
     """Constructor.
 
     Args:

@@ -64,6 +65,10 @@ class SSDFeatureExtractor(object):
       use_explicit_padding: Whether to use explicit padding when extracting
         features. Default is False.
       use_depthwise: Whether to use depthwise convolutions. Default is False.
+      inplace_batchnorm_update: Whether to update batch norm moving average
+        values inplace. When this is false train op must add a control
+        dependency on tf.graphkeys.UPDATE_OPS collection in order to update
+        batch norm statistics.
     """
     self._is_training = is_training
     self._depth_multiplier = depth_multiplier

@@ -71,6 +76,7 @@ class SSDFeatureExtractor(object):
     self._pad_to_multiple = pad_to_multiple
     self._conv_hyperparams = conv_hyperparams
     self._batch_norm_trainable = batch_norm_trainable
+    self._inplace_batchnorm_update = inplace_batchnorm_update
     self._reuse_weights = reuse_weights
     self._use_explicit_padding = use_explicit_padding
     self._use_depthwise = use_depthwise

@@ -108,7 +114,29 @@ class SSDFeatureExtractor(object):
       feature_maps: a list of tensors where the ith tensor has shape
         [batch, height_i, width_i, depth_i]
     """
-    pass
+    batchnorm_updates_collections = (None if self._inplace_batchnorm_update
+                                     else tf.GraphKeys.UPDATE_OPS)
+    with slim.arg_scope([slim.batch_norm],
+                        updates_collections=batchnorm_updates_collections):
+      return self._extract_features(preprocessed_inputs)
+
+  @abstractmethod
+  def _extract_features(self, preprocessed_inputs):
+    """Extracts features from preprocessed inputs.
+
+    This function is responsible for extracting feature maps from preprocessed
+    images.
+
+    Args:
+      preprocessed_inputs: a [batch, height, width, channels] float tensor
+        representing a batch of images.
+
+    Returns:
+      feature_maps: a list of tensors where the ith tensor has shape
+        [batch, height_i, width_i, depth_i]
+    """
+    raise NotImplementedError
 
 
 class SSDMetaArch(model.DetectionModel):
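The ssd_meta_arch.py hunk turns extract_features into a template method: the base class now installs a slim.arg_scope that picks the batch norm updates_collections, and concrete extractors override the new abstract _extract_features instead (which is why every model file below renames its method). A hypothetical minimal subclass, only to illustrate the new contract (ToyFeatureExtractor is not part of this commit):

    import tensorflow as tf
    slim = tf.contrib.slim

    from object_detection.meta_architectures import ssd_meta_arch


    class ToyFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
      """Illustrative extractor showing the new _extract_features contract."""

      def preprocess(self, resized_inputs):
        return (2.0 / 255.0) * resized_inputs - 1.0

      def _extract_features(self, preprocessed_inputs):
        # Any slim.batch_norm created inside this method (here via normalizer_fn)
        # inherits updates_collections from the arg_scope installed by the public
        # extract_features() wrapper in the base class.
        net = slim.conv2d(preprocessed_inputs, 32, [3, 3],
                          normalizer_fn=slim.batch_norm, scope='toy_conv')
        return [net]

Callers keep calling extract_features(); only implementers change.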
research/object_detection/models/embedded_ssd_mobilenet_v1_feature_extractor.py

@@ -53,7 +53,8 @@ class EmbeddedSSDMobileNetV1FeatureExtractor(
                batch_norm_trainable=True,
                reuse_weights=None,
                use_explicit_padding=False,
-               use_depthwise=False):
+               use_depthwise=False,
+               inplace_batchnorm_update=False):
     """MobileNetV1 Feature Extractor for Embedded-friendly SSD Models.
 
     Args:

@@ -71,6 +72,11 @@ class EmbeddedSSDMobileNetV1FeatureExtractor(
       use_explicit_padding: Whether to use explicit padding when extracting
         features. Default is False.
       use_depthwise: Whether to use depthwise convolutions. Default is False.
+      inplace_batchnorm_update: Whether to update batch_norm inplace during
+        training. This is required for batch norm to work correctly on TPUs.
+        When this is false, user must add a control dependency on
+        tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+        norm moving average parameters.
 
     Raises:
       ValueError: upon invalid `pad_to_multiple` values.

@@ -82,9 +88,9 @@ class EmbeddedSSDMobileNetV1FeatureExtractor(
     super(EmbeddedSSDMobileNetV1FeatureExtractor, self).__init__(
         is_training, depth_multiplier, min_depth, pad_to_multiple,
         conv_hyperparams, batch_norm_trainable, reuse_weights,
-        use_explicit_padding, use_depthwise)
+        use_explicit_padding, use_depthwise, inplace_batchnorm_update)
 
-  def extract_features(self, preprocessed_inputs):
+  def _extract_features(self, preprocessed_inputs):
     """Extract features from preprocessed inputs.
 
     Args:
research/object_detection/models/ssd_inception_v2_feature_extractor.py

@@ -37,7 +37,8 @@ class SSDInceptionV2FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
                batch_norm_trainable=True,
                reuse_weights=None,
                use_explicit_padding=False,
-               use_depthwise=False):
+               use_depthwise=False,
+               inplace_batchnorm_update=False):
     """InceptionV2 Feature Extractor for SSD Models.
 
     Args:

@@ -55,11 +56,16 @@ class SSDInceptionV2FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
       use_explicit_padding: Whether to use explicit padding when extracting
         features. Default is False.
       use_depthwise: Whether to use depthwise convolutions. Default is False.
+      inplace_batchnorm_update: Whether to update batch_norm inplace during
+        training. This is required for batch norm to work correctly on TPUs.
+        When this is false, user must add a control dependency on
+        tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+        norm moving average parameters.
     """
     super(SSDInceptionV2FeatureExtractor, self).__init__(
         is_training, depth_multiplier, min_depth, pad_to_multiple,
         conv_hyperparams, batch_norm_trainable, reuse_weights,
-        use_explicit_padding, use_depthwise)
+        use_explicit_padding, use_depthwise, inplace_batchnorm_update)
 
   def preprocess(self, resized_inputs):
     """SSD preprocessing.

@@ -76,7 +82,7 @@ class SSDInceptionV2FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
     """
     return (2.0 / 255.0) * resized_inputs - 1.0
 
-  def extract_features(self, preprocessed_inputs):
+  def _extract_features(self, preprocessed_inputs):
     """Extract features from preprocessed inputs.
 
     Args:
research/object_detection/models/ssd_inception_v3_feature_extractor.py

@@ -37,7 +37,8 @@ class SSDInceptionV3FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
                batch_norm_trainable=True,
                reuse_weights=None,
                use_explicit_padding=False,
-               use_depthwise=False):
+               use_depthwise=False,
+               inplace_batchnorm_update=False):
     """InceptionV3 Feature Extractor for SSD Models.
 
     Args:

@@ -55,11 +56,16 @@ class SSDInceptionV3FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
       use_explicit_padding: Whether to use explicit padding when extracting
         features. Default is False.
       use_depthwise: Whether to use depthwise convolutions. Default is False.
+      inplace_batchnorm_update: Whether to update batch_norm inplace during
+        training. This is required for batch norm to work correctly on TPUs.
+        When this is false, user must add a control dependency on
+        tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+        norm moving average parameters.
     """
     super(SSDInceptionV3FeatureExtractor, self).__init__(
         is_training, depth_multiplier, min_depth, pad_to_multiple,
         conv_hyperparams, batch_norm_trainable, reuse_weights,
-        use_explicit_padding, use_depthwise)
+        use_explicit_padding, use_depthwise, inplace_batchnorm_update)
 
   def preprocess(self, resized_inputs):
     """SSD preprocessing.

@@ -76,7 +82,7 @@ class SSDInceptionV3FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
     """
     return (2.0 / 255.0) * resized_inputs - 1.0
 
-  def extract_features(self, preprocessed_inputs):
+  def _extract_features(self, preprocessed_inputs):
     """Extract features from preprocessed inputs.
 
     Args:
research/object_detection/models/ssd_mobilenet_v1_feature_extractor.py

@@ -38,7 +38,8 @@ class SSDMobileNetV1FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
                batch_norm_trainable=True,
                reuse_weights=None,
                use_explicit_padding=False,
-               use_depthwise=False):
+               use_depthwise=False,
+               inplace_batchnorm_update=False):
     """MobileNetV1 Feature Extractor for SSD Models.
 
     Args:

@@ -57,11 +58,16 @@ class SSDMobileNetV1FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
         inputs so that the output dimensions are the same as if 'SAME' padding
         were used.
       use_depthwise: Whether to use depthwise convolutions. Default is False.
+      inplace_batchnorm_update: Whether to update batch_norm inplace during
+        training. This is required for batch norm to work correctly on TPUs.
+        When this is false, user must add a control dependency on
+        tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+        norm moving average parameters.
     """
     super(SSDMobileNetV1FeatureExtractor, self).__init__(
         is_training, depth_multiplier, min_depth, pad_to_multiple,
         conv_hyperparams, batch_norm_trainable, reuse_weights,
-        use_explicit_padding, use_depthwise)
+        use_explicit_padding, use_depthwise, inplace_batchnorm_update)
 
   def preprocess(self, resized_inputs):
     """SSD preprocessing.

@@ -78,7 +84,7 @@ class SSDMobileNetV1FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
     """
     return (2.0 / 255.0) * resized_inputs - 1.0
 
-  def extract_features(self, preprocessed_inputs):
+  def _extract_features(self, preprocessed_inputs):
     """Extract features from preprocessed inputs.
 
     Args:
research/object_detection/models/ssd_mobilenet_v2_feature_extractor.py

@@ -39,7 +39,8 @@ class SSDMobileNetV2FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
                batch_norm_trainable=True,
                reuse_weights=None,
                use_explicit_padding=False,
-               use_depthwise=False):
+               use_depthwise=False,
+               inplace_batchnorm_update=False):
     """MobileNetV2 Feature Extractor for SSD Models.
 
     Mobilenet v2 (experimental), designed by sandler@. More details can be found

@@ -60,11 +61,16 @@ class SSDMobileNetV2FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
       use_explicit_padding: Whether to use explicit padding when extracting
         features. Default is False.
       use_depthwise: Whether to use depthwise convolutions. Default is False.
+      inplace_batchnorm_update: Whether to update batch_norm inplace during
+        training. This is required for batch norm to work correctly on TPUs.
+        When this is false, user must add a control dependency on
+        tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+        norm moving average parameters.
     """
     super(SSDMobileNetV2FeatureExtractor, self).__init__(
         is_training, depth_multiplier, min_depth, pad_to_multiple,
         conv_hyperparams, batch_norm_trainable, reuse_weights,
-        use_explicit_padding, use_depthwise)
+        use_explicit_padding, use_depthwise, inplace_batchnorm_update)
 
   def preprocess(self, resized_inputs):
     """SSD preprocessing.

@@ -81,7 +87,7 @@ class SSDMobileNetV2FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
     """
     return (2.0 / 255.0) * resized_inputs - 1.0
 
-  def extract_features(self, preprocessed_inputs):
+  def _extract_features(self, preprocessed_inputs):
     """Extract features from preprocessed inputs.
 
     Args:
research/object_detection/models/ssd_resnet_v1_fpn_feature_extractor.py

@@ -43,7 +43,8 @@ class _SSDResnetV1FpnFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
                batch_norm_trainable=True,
                reuse_weights=None,
                use_explicit_padding=False,
-               use_depthwise=False):
+               use_depthwise=False,
+               inplace_batchnorm_update=False):
     """SSD FPN feature extractor based on Resnet v1 architecture.
 
     Args:

@@ -66,6 +67,11 @@ class _SSDResnetV1FpnFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
       use_explicit_padding: Whether to use explicit padding when extracting
         features. Default is False. UNUSED currently.
       use_depthwise: Whether to use depthwise convolutions. UNUSED currently.
+      inplace_batchnorm_update: Whether to update batch_norm inplace during
+        training. This is required for batch norm to work correctly on TPUs.
+        When this is false, user must add a control dependency on
+        tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+        norm moving average parameters.
 
     Raises:
       ValueError: On supplying invalid arguments for unused arguments.

@@ -73,7 +79,7 @@ class _SSDResnetV1FpnFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
     super(_SSDResnetV1FpnFeatureExtractor, self).__init__(
         is_training, depth_multiplier, min_depth, pad_to_multiple,
         conv_hyperparams, batch_norm_trainable, reuse_weights,
-        use_explicit_padding)
+        use_explicit_padding, inplace_batchnorm_update)
     if self._depth_multiplier != 1.0:
       raise ValueError('Only depth 1.0 is supported, found: {}'.
                        format(self._depth_multiplier))

@@ -110,7 +116,7 @@ class _SSDResnetV1FpnFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
         filtered_image_features[feature_name] = feature
     return filtered_image_features
 
-  def extract_features(self, preprocessed_inputs):
+  def _extract_features(self, preprocessed_inputs):
     """Extract features from preprocessed inputs.
 
     Args:

@@ -176,7 +182,8 @@ class SSDResnet50V1FpnFeatureExtractor(_SSDResnetV1FpnFeatureExtractor):
                batch_norm_trainable=True,
                reuse_weights=None,
                use_explicit_padding=False,
-               use_depthwise=False):
+               use_depthwise=False,
+               inplace_batchnorm_update=False):
     """Resnet50 v1 FPN Feature Extractor for SSD Models.
 
     Args:

@@ -194,11 +201,17 @@ class SSDResnet50V1FpnFeatureExtractor(_SSDResnetV1FpnFeatureExtractor):
       use_explicit_padding: Whether to use explicit padding when extracting
         features. Default is False. UNUSED currently.
       use_depthwise: Whether to use depthwise convolutions. UNUSED currently.
+      inplace_batchnorm_update: Whether to update batch_norm inplace during
+        training. This is required for batch norm to work correctly on TPUs.
+        When this is false, user must add a control dependency on
+        tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+        norm moving average parameters.
     """
     super(SSDResnet50V1FpnFeatureExtractor, self).__init__(
         is_training, depth_multiplier, min_depth, pad_to_multiple,
         conv_hyperparams, resnet_v1.resnet_v1_50, 'resnet_v1_50', 'fpn',
-        batch_norm_trainable, reuse_weights, use_explicit_padding)
+        batch_norm_trainable, reuse_weights, use_explicit_padding,
+        inplace_batchnorm_update)
 
 
 class SSDResnet101V1FpnFeatureExtractor(_SSDResnetV1FpnFeatureExtractor):

@@ -212,7 +225,8 @@ class SSDResnet101V1FpnFeatureExtractor(_SSDResnetV1FpnFeatureExtractor):
                batch_norm_trainable=True,
                reuse_weights=None,
                use_explicit_padding=False,
-               use_depthwise=False):
+               use_depthwise=False,
+               inplace_batchnorm_update=False):
     """Resnet101 v1 FPN Feature Extractor for SSD Models.
 
     Args:

@@ -230,11 +244,17 @@ class SSDResnet101V1FpnFeatureExtractor(_SSDResnetV1FpnFeatureExtractor):
       use_explicit_padding: Whether to use explicit padding when extracting
         features. Default is False. UNUSED currently.
       use_depthwise: Whether to use depthwise convolutions. UNUSED currently.
+      inplace_batchnorm_update: Whether to update batch_norm inplace during
+        training. This is required for batch norm to work correctly on TPUs.
+        When this is false, user must add a control dependency on
+        tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+        norm moving average parameters.
     """
     super(SSDResnet101V1FpnFeatureExtractor, self).__init__(
         is_training, depth_multiplier, min_depth, pad_to_multiple,
         conv_hyperparams, resnet_v1.resnet_v1_101, 'resnet_v1_101', 'fpn',
-        batch_norm_trainable, reuse_weights, use_explicit_padding)
+        batch_norm_trainable, reuse_weights, use_explicit_padding,
+        inplace_batchnorm_update)
 
 
 class SSDResnet152V1FpnFeatureExtractor(_SSDResnetV1FpnFeatureExtractor):

@@ -248,7 +268,8 @@ class SSDResnet152V1FpnFeatureExtractor(_SSDResnetV1FpnFeatureExtractor):
                batch_norm_trainable=True,
                reuse_weights=None,
                use_explicit_padding=False,
-               use_depthwise=False):
+               use_depthwise=False,
+               inplace_batchnorm_update=False):
     """Resnet152 v1 FPN Feature Extractor for SSD Models.
 
     Args:

@@ -266,8 +287,14 @@ class SSDResnet152V1FpnFeatureExtractor(_SSDResnetV1FpnFeatureExtractor):
       use_explicit_padding: Whether to use explicit padding when extracting
         features. Default is False. UNUSED currently.
       use_depthwise: Whether to use depthwise convolutions. UNUSED currently.
+      inplace_batchnorm_update: Whether to update batch_norm inplace during
+        training. This is required for batch norm to work correctly on TPUs.
+        When this is false, user must add a control dependency on
+        tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
+        norm moving average parameters.
     """
     super(SSDResnet152V1FpnFeatureExtractor, self).__init__(
         is_training, depth_multiplier, min_depth, pad_to_multiple,
         conv_hyperparams, resnet_v1.resnet_v1_152, 'resnet_v1_152', 'fpn',
-        batch_norm_trainable, reuse_weights, use_explicit_padding)
+        batch_norm_trainable, reuse_weights, use_explicit_padding,
+        inplace_batchnorm_update)
research/object_detection/protos/faster_rcnn.proto

@@ -131,6 +131,12 @@ message FasterRcnn {
   // to use sigmoid loss and enable merge_multiple_label_boxes.
   // If not specified, Softmax loss is used as default.
   optional ClassificationLoss second_stage_classification_loss = 29;
+
+  // Whether to update batch_norm inplace during training. This is required
+  // for batch norm to work correctly on TPUs. When this is false, user must add
+  // a control dependency on tf.GraphKeys.UPDATE_OPS for train/loss op in order
+  // to update the batch norm moving average parameters.
+  optional bool inplace_batchnorm_update = 30 [default = false];
 }
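Note that although faster_rcnn.proto gains the same field, the model_builder.py hunk earlier in this commit rejects a true value for Faster R-CNN feature extractors, so for now only SSD models can actually enable it. An illustrative sketch of that guard (faster_rcnn_pb2 is the standard protoc output assumed here, and the private builder helper is called directly only for demonstration):

    from object_detection.builders import model_builder
    from object_detection.protos import faster_rcnn_pb2

    frcnn_config = faster_rcnn_pb2.FasterRcnn()
    frcnn_config.inplace_batchnorm_update = True
    try:
      model_builder._build_faster_rcnn_feature_extractor(
          frcnn_config.feature_extractor, is_training=True,
          inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)
    except ValueError as e:
      print(e)  # 'inplace batchnorm updates not supported.'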
research/object_detection/protos/ssd.proto

@@ -59,6 +59,12 @@ message Ssd {
   // Loss configuration for training.
   optional Loss loss = 11;
+
+  // Whether to update batch_norm inplace during training. This is required
+  // for batch norm to work correctly on TPUs. When this is false, user must add
+  // a control dependency on tf.GraphKeys.UPDATE_OPS for train/loss op in order
+  // to update the batch norm moving average parameters.
+  optional bool inplace_batchnorm_update = 15 [default = false];
 }
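Mechanically, "in-place" maps onto slim.batch_norm's updates_collections argument: the new arg_scope in ssd_meta_arch.py passes None when this field is true, so the moving-average updates run as part of the batch norm op itself instead of being parked in the UPDATE_OPS collection. A generic TF 1.x sketch of the two modes (not code from this commit):

    import tensorflow as tf
    slim = tf.contrib.slim

    images = tf.random_normal([8, 32, 32, 3])

    # Default mode: updates are deferred to UPDATE_OPS and must be wired into
    # the train op via a control dependency (see the sketch near the top).
    deferred = slim.batch_norm(images, is_training=True, scope='deferred_bn')

    # inplace_batchnorm_update: true selects updates_collections=None, so the
    # moving averages are updated in place with the op itself -- no extra
    # wiring, and this is the mode that works on TPUs.
    inplace = slim.batch_norm(images, is_training=True,
                              updates_collections=None, scope='inplace_bn')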