"sgl-kernel/vscode:/vscode.git/clone" did not exist on "bf669606eb84e12dc1ecf15b23c1eedab204d660"
Commit d5d087ba authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Add the output projects layer in the checkpoint items of MultiClsHead.

PiperOrigin-RevId: 361916017
parent aa7dbd59
...@@ -158,5 +158,6 @@ class MultiClsHeads(tf.keras.layers.Layer): ...@@ -158,5 +158,6 @@ class MultiClsHeads(tf.keras.layers.Layer):
@property @property
def checkpoint_items(self): def checkpoint_items(self):
# TODO(hongkuny): add output projects to the checkpoint items. items = {self.dense.name: self.dense}
return {self.dense.name: self.dense} items.update({v.name: v for v in self.out_projs})
return items
...@@ -48,7 +48,7 @@ class MultiClsHeadsTest(tf.test.TestCase): ...@@ -48,7 +48,7 @@ class MultiClsHeadsTest(tf.test.TestCase):
self.assertAllClose(outputs["foo"], [[0., 0.], [0., 0.]]) self.assertAllClose(outputs["foo"], [[0., 0.], [0., 0.]])
self.assertAllClose(outputs["bar"], [[0., 0., 0.], [0., 0., 0.]]) self.assertAllClose(outputs["bar"], [[0., 0., 0.], [0., 0., 0.]])
self.assertSameElements(test_layer.checkpoint_items.keys(), self.assertSameElements(test_layer.checkpoint_items.keys(),
["pooler_dense"]) ["pooler_dense", "foo", "bar"])
def test_layer_serialization(self): def test_layer_serialization(self):
cls_list = [("foo", 2), ("bar", 3)] cls_list = [("foo", 2), ("bar", 3)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment