Commit 8a1dbbad authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Fix get_config() of MaskedSoftmax.

PiperOrigin-RevId: 315767729
parent 4cd908ef
......@@ -64,6 +64,9 @@ class MaskedSoftmax(tf.keras.layers.Layer):
scores, axis=self._normalization_axes, keepdims=True))
def get_config(self):
config = {'mask_expansion_axes': self._mask_expansion_axes}
config = {
'mask_expansion_axes': self._mask_expansion_axes,
'normalization_axes': self._normalization_axes
}
base_config = super(MaskedSoftmax, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
......@@ -105,6 +105,15 @@ class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
is_zeros = np.greater(output_data, 0)
self.assertAllEqual(expected_zeros, is_zeros)
def test_serialize_deserialize(self):
test_layer = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=[6, 7])
new_layer = masked_softmax.MaskedSoftmax.from_config(
test_layer.get_config())
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment