Commit 718b3070 authored by Austin Myers's avatar Austin Myers Committed by TF Object Detection Team
Browse files

Enable name based definition of keras initializers in hyperparams.

PiperOrigin-RevId: 365913722
parent f55a0eb2
......@@ -359,7 +359,7 @@ def _build_initializer(initializer, build_for_keras=False):
operators. If false builds for Slim.
Returns:
tf initializer.
tf initializer or string corresponding to the tf keras initializer name.
Raises:
ValueError: On unknown initializer.
......@@ -415,6 +415,13 @@ def _build_initializer(initializer, build_for_keras=False):
factor=initializer.variance_scaling_initializer.factor,
mode=mode,
uniform=initializer.variance_scaling_initializer.uniform)
if initializer_oneof == 'keras_initializer_by_name':
if build_for_keras:
return initializer.keras_initializer_by_name
else:
raise ValueError(
'Unsupported non-Keras usage of keras_initializer_by_name: {}'.format(
initializer.keras_initializer_by_name))
if initializer_oneof is None:
return None
raise ValueError('Unknown initializer function: {}'.format(
......
......@@ -1030,5 +1030,26 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase):
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.64, tol=1e-1)
def test_keras_initializer_by_name(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
keras_initializer_by_name: "glorot_uniform"
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Parse(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
initializer_arg = keras_config.params()['kernel_initializer']
conv_layer = tf.keras.layers.Conv2D(
filters=16, kernel_size=3, **keras_config.params())
self.assertEqual(initializer_arg, 'glorot_uniform')
self.assertIsInstance(conv_layer.kernel_initializer,
type(tf.keras.initializers.get('glorot_uniform')))
if __name__ == '__main__':
tf.test.main()
......@@ -88,6 +88,11 @@ message Initializer {
TruncatedNormalInitializer truncated_normal_initializer = 1;
VarianceScalingInitializer variance_scaling_initializer = 2;
RandomNormalInitializer random_normal_initializer = 3;
// Allows specifying initializers by name, as a string, which will be passed
// directly as an argument during layer construction. Currently, this is
// only supported when using KerasLayerHyperparams, and for valid Keras
// initializers, e.g. `glorot_uniform`, `variance_scaling`, etc.
string keras_initializer_by_name = 4;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment