Commit fc227c3b authored by Austin Myers's avatar Austin Myers Committed by TF Object Detection Team
Browse files

Fix Keras hyperparams get_regularizer_weight() to return float.

PiperOrigin-RevId: 365177695
parent 81e456dd
...@@ -176,9 +176,9 @@ class KerasLayerHyperparams(object): ...@@ -176,9 +176,9 @@ class KerasLayerHyperparams(object):
""" """
regularizer = self._op_params['kernel_regularizer'] regularizer = self._op_params['kernel_regularizer']
if hasattr(regularizer, 'l1'): if hasattr(regularizer, 'l1'):
return regularizer.l1 return float(regularizer.l1)
elif hasattr(regularizer, 'l2'): elif hasattr(regularizer, 'l2'):
return regularizer.l2 return float(regularizer.l2)
else: else:
return None return None
......
...@@ -558,7 +558,7 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase): ...@@ -558,7 +558,7 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase):
result = regularizer(tf.constant(weights)).numpy() result = regularizer(tf.constant(weights)).numpy()
self.assertAllClose(np.abs(weights).sum() * 0.5, result) self.assertAllClose(np.abs(weights).sum() * 0.5, result)
def test_return_l2_regularizer_weights_keras(self): def test_return_l2_regularized_weights_keras(self):
conv_hyperparams_text_proto = """ conv_hyperparams_text_proto = """
regularizer { regularizer {
l2_regularizer { l2_regularizer {
...@@ -598,6 +598,7 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase): ...@@ -598,6 +598,7 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase):
conv_hyperparams_proto) conv_hyperparams_proto)
regularizer_weight = keras_config.get_regularizer_weight() regularizer_weight = keras_config.get_regularizer_weight()
self.assertIsInstance(regularizer_weight, float)
self.assertAlmostEqual(regularizer_weight, 0.5) self.assertAlmostEqual(regularizer_weight, 0.5)
def test_return_l2_regularizer_weight_keras(self): def test_return_l2_regularizer_weight_keras(self):
...@@ -618,6 +619,7 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase): ...@@ -618,6 +619,7 @@ class KerasHyperparamsBuilderTest(tf.test.TestCase):
conv_hyperparams_proto) conv_hyperparams_proto)
regularizer_weight = keras_config.get_regularizer_weight() regularizer_weight = keras_config.get_regularizer_weight()
self.assertIsInstance(regularizer_weight, float)
self.assertAlmostEqual(regularizer_weight, 0.25) self.assertAlmostEqual(regularizer_weight, 0.25)
def test_return_undefined_regularizer_weight_keras(self): def test_return_undefined_regularizer_weight_keras(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment