Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
718b3070
Commit
718b3070
authored
Mar 30, 2021
by
Austin Myers
Committed by
TF Object Detection Team
Mar 30, 2021
Browse files
Enable name based definition of keras initializers in hyperparams.
PiperOrigin-RevId: 365913722
parent
f55a0eb2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
1 deletion
+34
-1
research/object_detection/builders/hyperparams_builder.py
research/object_detection/builders/hyperparams_builder.py
+8
-1
research/object_detection/builders/hyperparams_builder_test.py
...rch/object_detection/builders/hyperparams_builder_test.py
+21
-0
research/object_detection/protos/hyperparams.proto
research/object_detection/protos/hyperparams.proto
+5
-0
No files found.
research/object_detection/builders/hyperparams_builder.py
View file @
718b3070
...
...
@@ -359,7 +359,7 @@ def _build_initializer(initializer, build_for_keras=False):
operators. If false builds for Slim.
Returns:
tf initializer.
tf initializer
or string corresponding to the tf keras initializer name
.
Raises:
ValueError: On unknown initializer.
...
...
@@ -415,6 +415,13 @@ def _build_initializer(initializer, build_for_keras=False):
factor
=
initializer
.
variance_scaling_initializer
.
factor
,
mode
=
mode
,
uniform
=
initializer
.
variance_scaling_initializer
.
uniform
)
if
initializer_oneof
==
'keras_initializer_by_name'
:
if
build_for_keras
:
return
initializer
.
keras_initializer_by_name
else
:
raise
ValueError
(
'Unsupported non-Keras usage of keras_initializer_by_name: {}'
.
format
(
initializer
.
keras_initializer_by_name
))
if
initializer_oneof
is
None
:
return
None
raise
ValueError
(
'Unknown initializer function: {}'
.
format
(
...
...
research/object_detection/builders/hyperparams_builder_test.py
View file @
718b3070
...
...
@@ -1030,5 +1030,26 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase):
self
.
_assert_variance_in_range
(
initializer
,
shape
=
[
100
,
40
],
variance
=
0.64
,
tol
=
1e-1
)
def
test_keras_initializer_by_name
(
self
):
conv_hyperparams_text_proto
=
"""
regularizer {
l2_regularizer {
}
}
initializer {
keras_initializer_by_name: "glorot_uniform"
}
"""
conv_hyperparams_proto
=
hyperparams_pb2
.
Hyperparams
()
text_format
.
Parse
(
conv_hyperparams_text_proto
,
conv_hyperparams_proto
)
keras_config
=
hyperparams_builder
.
KerasLayerHyperparams
(
conv_hyperparams_proto
)
initializer_arg
=
keras_config
.
params
()[
'kernel_initializer'
]
conv_layer
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
16
,
kernel_size
=
3
,
**
keras_config
.
params
())
self
.
assertEqual
(
initializer_arg
,
'glorot_uniform'
)
self
.
assertIsInstance
(
conv_layer
.
kernel_initializer
,
type
(
tf
.
keras
.
initializers
.
get
(
'glorot_uniform'
)))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/object_detection/protos/hyperparams.proto
View file @
718b3070
...
...
@@ -88,6 +88,11 @@ message Initializer {
TruncatedNormalInitializer
truncated_normal_initializer
=
1
;
VarianceScalingInitializer
variance_scaling_initializer
=
2
;
RandomNormalInitializer
random_normal_initializer
=
3
;
// Allows specifying initializers by name, as a string, which will be passed
// directly as an argument during layer construction. Currently, this is
// only supported when using KerasLayerHyperparams, and for valid Keras
// initializers, e.g. `glorot_uniform`, `variance_scaling`, etc.
string
keras_initializer_by_name
=
4
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment