Commit fc227c3b authored by Austin Myers's avatar Austin Myers Committed by TF Object Detection Team
Browse files

Fix Keras hyperparams get_regularizer_weight() to return float.

PiperOrigin-RevId: 365177695
parent 81e456dd
......@@ -176,9 +176,9 @@ class KerasLayerHyperparams(object):
"""
regularizer = self._op_params['kernel_regularizer']
if hasattr(regularizer, 'l1'):
return regularizer.l1
return float(regularizer.l1)
elif hasattr(regularizer, 'l2'):
return regularizer.l2
return float(regularizer.l2)
else:
return None
......
......@@ -558,7 +558,7 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase):
result = regularizer(tf.constant(weights)).numpy()
self.assertAllClose(np.abs(weights).sum() * 0.5, result)
def test_return_l2_regularizer_weights_keras(self):
def test_return_l2_regularized_weights_keras(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
......@@ -598,6 +598,7 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase):
conv_hyperparams_proto)
regularizer_weight = keras_config.get_regularizer_weight()
self.assertIsInstance(regularizer_weight, float)
self.assertAlmostEqual(regularizer_weight, 0.5)
def test_return_l2_regularizer_weight_keras(self):
......@@ -618,6 +619,7 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase):
conv_hyperparams_proto)
regularizer_weight = keras_config.get_regularizer_weight()
self.assertIsInstance(regularizer_weight, float)
self.assertAlmostEqual(regularizer_weight, 0.25)
def test_return_undefined_regularizer_weight_keras(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment