Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5bb827c3
Commit
5bb827c3
authored
Mar 23, 2021
by
Austin Myers
Committed by
TF Object Detection Team
Mar 23, 2021
Browse files
Add 'weight_decay' and 'tpu' options to EfficientNet backbone overrides.
PiperOrigin-RevId: 364671967
parent
adff6ed3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
85 additions
and
8 deletions
+85
-8
research/object_detection/builders/hyperparams_builder.py
research/object_detection/builders/hyperparams_builder.py
+17
-0
research/object_detection/builders/hyperparams_builder_test.py
...rch/object_detection/builders/hyperparams_builder_test.py
+55
-0
research/object_detection/models/ssd_efficientnet_bifpn_feature_extractor.py
...ection/models/ssd_efficientnet_bifpn_feature_extractor.py
+13
-8
No files found.
research/object_detection/builders/hyperparams_builder.py
View file @
5bb827c3
...
@@ -90,6 +90,9 @@ class KerasLayerHyperparams(object):
...
@@ -90,6 +90,9 @@ class KerasLayerHyperparams(object):
def
use_batch_norm
(
self
):
def
use_batch_norm
(
self
):
return
self
.
_batch_norm_params
is
not
None
return
self
.
_batch_norm_params
is
not
None
def
use_sync_batch_norm
(
self
):
return
self
.
_use_sync_batch_norm
def
force_use_bias
(
self
):
def
force_use_bias
(
self
):
return
self
.
_force_use_bias
return
self
.
_force_use_bias
...
@@ -165,6 +168,20 @@ class KerasLayerHyperparams(object):
...
@@ -165,6 +168,20 @@ class KerasLayerHyperparams(object):
else
:
else
:
return
tf
.
keras
.
layers
.
Lambda
(
tf
.
identity
,
name
=
name
)
return
tf
.
keras
.
layers
.
Lambda
(
tf
.
identity
,
name
=
name
)
def
get_regularizer_weight
(
self
):
"""Returns the l1 or l2 regularizer weight.
Returns: A float value corresponding to the l1 or l2 regularization weight,
or None if neither l1 or l2 regularization is defined.
"""
regularizer
=
self
.
_op_params
[
'kernel_regularizer'
]
if
hasattr
(
regularizer
,
'l1'
):
return
regularizer
.
l1
elif
hasattr
(
regularizer
,
'l2'
):
return
regularizer
.
l2
else
:
return
None
def
params
(
self
,
include_activation
=
False
,
**
overrides
):
def
params
(
self
,
include_activation
=
False
,
**
overrides
):
"""Returns a dict containing the layer construction hyperparameters to use.
"""Returns a dict containing the layer construction hyperparameters to use.
...
...
research/object_detection/builders/hyperparams_builder_test.py
View file @
5bb827c3
...
@@ -580,6 +580,61 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase):
...
@@ -580,6 +580,61 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase):
result
=
regularizer
(
tf
.
constant
(
weights
)).
numpy
()
result
=
regularizer
(
tf
.
constant
(
weights
)).
numpy
()
self
.
assertAllClose
(
np
.
power
(
weights
,
2
).
sum
()
/
2.0
*
0.42
,
result
)
self
.
assertAllClose
(
np
.
power
(
weights
,
2
).
sum
()
/
2.0
*
0.42
,
result
)
def
test_return_l1_regularizer_weight_keras
(
self
):
conv_hyperparams_text_proto
=
"""
regularizer {
l1_regularizer {
weight: 0.5
}
}
initializer {
truncated_normal_initializer {
}
}
"""
conv_hyperparams_proto
=
hyperparams_pb2
.
Hyperparams
()
text_format
.
Parse
(
conv_hyperparams_text_proto
,
conv_hyperparams_proto
)
keras_config
=
hyperparams_builder
.
KerasLayerHyperparams
(
conv_hyperparams_proto
)
regularizer_weight
=
keras_config
.
get_regularizer_weight
()
self
.
assertAlmostEqual
(
regularizer_weight
,
0.5
)
def
test_return_l2_regularizer_weight_keras
(
self
):
conv_hyperparams_text_proto
=
"""
regularizer {
l2_regularizer {
weight: 0.5
}
}
initializer {
truncated_normal_initializer {
}
}
"""
conv_hyperparams_proto
=
hyperparams_pb2
.
Hyperparams
()
text_format
.
Parse
(
conv_hyperparams_text_proto
,
conv_hyperparams_proto
)
keras_config
=
hyperparams_builder
.
KerasLayerHyperparams
(
conv_hyperparams_proto
)
regularizer_weight
=
keras_config
.
get_regularizer_weight
()
self
.
assertAlmostEqual
(
regularizer_weight
,
0.25
)
def
test_return_undefined_regularizer_weight_keras
(
self
):
conv_hyperparams_text_proto
=
"""
initializer {
truncated_normal_initializer {
}
}
"""
conv_hyperparams_proto
=
hyperparams_pb2
.
Hyperparams
()
text_format
.
Parse
(
conv_hyperparams_text_proto
,
conv_hyperparams_proto
)
keras_config
=
hyperparams_builder
.
KerasLayerHyperparams
(
conv_hyperparams_proto
)
regularizer_weight
=
keras_config
.
get_regularizer_weight
()
self
.
assertIsNone
(
regularizer_weight
)
def
test_return_non_default_batch_norm_params_keras
(
def
test_return_non_default_batch_norm_params_keras
(
self
):
self
):
conv_hyperparams_text_proto
=
"""
conv_hyperparams_text_proto
=
"""
...
...
research/object_detection/models/ssd_efficientnet_bifpn_feature_extractor.py
View file @
5bb827c3
...
@@ -23,6 +23,7 @@ from six.moves import range
...
@@ -23,6 +23,7 @@ from six.moves import range
from
six.moves
import
zip
from
six.moves
import
zip
import
tensorflow.compat.v2
as
tf
import
tensorflow.compat.v2
as
tf
from
tensorflow.python.keras
import
backend
as
keras_backend
from
object_detection.meta_architectures
import
ssd_meta_arch
from
object_detection.meta_architectures
import
ssd_meta_arch
from
object_detection.models
import
bidirectional_feature_pyramid_generators
as
bifpn_generators
from
object_detection.models
import
bidirectional_feature_pyramid_generators
as
bifpn_generators
from
object_detection.utils
import
ops
from
object_detection.utils
import
ops
...
@@ -103,9 +104,10 @@ class SSDEfficientNetBiFPNKerasFeatureExtractor(
...
@@ -103,9 +104,10 @@ class SSDEfficientNetBiFPNKerasFeatureExtractor(
use_depthwise: unsupported by EfficientNetBiFPN, since BiFPN uses regular
use_depthwise: unsupported by EfficientNetBiFPN, since BiFPN uses regular
convolutions when inputs to a node have a differing number of channels,
convolutions when inputs to a node have a differing number of channels,
and use separable convolutions after combine operations.
and use separable convolutions after combine operations.
override_base_feature_extractor_hyperparams: unsupported. Whether to
override_base_feature_extractor_hyperparams: Whether to override the
override hyperparameters of the base feature extractor with the one from
efficientnet backbone's default weight decay with the weight decay
`conv_hyperparams`.
defined by `conv_hyperparams`. Note, only overriding of weight decay is
currently supported.
name: a string name scope to assign to the model. If 'None', Keras will
name: a string name scope to assign to the model. If 'None', Keras will
auto-generate one from the class name.
auto-generate one from the class name.
"""
"""
...
@@ -129,9 +131,6 @@ class SSDEfficientNetBiFPNKerasFeatureExtractor(
...
@@ -129,9 +131,6 @@ class SSDEfficientNetBiFPNKerasFeatureExtractor(
raise
ValueError
(
'EfficientNetBiFPN does not support explicit padding.'
)
raise
ValueError
(
'EfficientNetBiFPN does not support explicit padding.'
)
if
use_depthwise
:
if
use_depthwise
:
raise
ValueError
(
'EfficientNetBiFPN does not support use_depthwise.'
)
raise
ValueError
(
'EfficientNetBiFPN does not support use_depthwise.'
)
if
override_base_feature_extractor_hyperparams
:
raise
ValueError
(
'EfficientNetBiFPN does not support '
'override_base_feature_extractor_hyperparams.'
)
self
.
_bifpn_min_level
=
bifpn_min_level
self
.
_bifpn_min_level
=
bifpn_min_level
self
.
_bifpn_max_level
=
bifpn_max_level
self
.
_bifpn_max_level
=
bifpn_max_level
...
@@ -158,9 +157,15 @@ class SSDEfficientNetBiFPNKerasFeatureExtractor(
...
@@ -158,9 +157,15 @@ class SSDEfficientNetBiFPNKerasFeatureExtractor(
# Initialize the EfficientNet backbone.
# Initialize the EfficientNet backbone.
# Note, this is currently done in the init method rather than in the build
# Note, this is currently done in the init method rather than in the build
# method, since doing so introduces an error which is not well understood.
# method, since doing so introduces an error which is not well understood.
efficientnet_overrides
=
{
'rescale_input'
:
False
}
if
override_base_feature_extractor_hyperparams
:
efficientnet_overrides
[
'weight_decay'
]
=
conv_hyperparams
.
get_regularizer_weight
()
if
(
conv_hyperparams
.
use_sync_batch_norm
()
and
keras_backend
.
is_tpu_strategy
(
tf
.
distribute
.
get_strategy
())):
efficientnet_overrides
[
'batch_norm'
]
=
'tpu'
efficientnet_base
=
efficientnet_model
.
EfficientNet
.
from_name
(
efficientnet_base
=
efficientnet_model
.
EfficientNet
.
from_name
(
model_name
=
self
.
_efficientnet_version
,
model_name
=
self
.
_efficientnet_version
,
overrides
=
efficientnet_overrides
)
overrides
=
{
'rescale_input'
:
False
})
outputs
=
[
efficientnet_base
.
get_layer
(
output_layer_name
).
output
outputs
=
[
efficientnet_base
.
get_layer
(
output_layer_name
).
output
for
output_layer_name
in
self
.
_output_layer_names
]
for
output_layer_name
in
self
.
_output_layer_names
]
self
.
_efficientnet
=
tf
.
keras
.
Model
(
self
.
_efficientnet
=
tf
.
keras
.
Model
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment