Commit 64c88b24 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Add the output projects layer in the checkpoint items of MultiClsHead.

PiperOrigin-RevId: 361916017
parent e0e8aa4a
...@@ -158,5 +158,6 @@ class MultiClsHeads(tf.keras.layers.Layer): ...@@ -158,5 +158,6 @@ class MultiClsHeads(tf.keras.layers.Layer):
@property @property
def checkpoint_items(self): def checkpoint_items(self):
# TODO(hongkuny): add output projects to the checkpoint items. items = {self.dense.name: self.dense}
return {self.dense.name: self.dense} items.update({v.name: v for v in self.out_projs})
return items
...@@ -48,7 +48,7 @@ class MultiClsHeadsTest(tf.test.TestCase): ...@@ -48,7 +48,7 @@ class MultiClsHeadsTest(tf.test.TestCase):
self.assertAllClose(outputs["foo"], [[0., 0.], [0., 0.]]) self.assertAllClose(outputs["foo"], [[0., 0.], [0., 0.]])
self.assertAllClose(outputs["bar"], [[0., 0., 0.], [0., 0., 0.]]) self.assertAllClose(outputs["bar"], [[0., 0., 0.], [0., 0., 0.]])
self.assertSameElements(test_layer.checkpoint_items.keys(), self.assertSameElements(test_layer.checkpoint_items.keys(),
["pooler_dense"]) ["pooler_dense", "foo", "bar"])
def test_layer_serialization(self): def test_layer_serialization(self):
cls_list = [("foo", 2), ("bar", 3)] cls_list = [("foo", 2), ("bar", 3)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment