"vscode:/vscode.git/clone" did not exist on "f98366604b23e331422bf3c62d4e7410ae4fab87"
Commit 64c88b24 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Add the output projects layer in the checkpoint items of MultiClsHead.

PiperOrigin-RevId: 361916017
parent e0e8aa4a
......@@ -158,5 +158,6 @@ class MultiClsHeads(tf.keras.layers.Layer):
@property
def checkpoint_items(self):
# TODO(hongkuny): add output projects to the checkpoint items.
return {self.dense.name: self.dense}
items = {self.dense.name: self.dense}
items.update({v.name: v for v in self.out_projs})
return items
......@@ -48,7 +48,7 @@ class MultiClsHeadsTest(tf.test.TestCase):
self.assertAllClose(outputs["foo"], [[0., 0.], [0., 0.]])
self.assertAllClose(outputs["bar"], [[0., 0., 0.], [0., 0., 0.]])
self.assertSameElements(test_layer.checkpoint_items.keys(),
["pooler_dense"])
["pooler_dense", "foo", "bar"])
def test_layer_serialization(self):
cls_list = [("foo", 2), ("bar", 3)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment