Commit f32dea32 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[bert_encoder] Found a typo in previous experiment and turns out that the...

[bert_encoder] Found a typo in previous experiment and turns out that the checkpoints are compatible even when wrapped in models.

PiperOrigin-RevId: 406987505
parent 80143eeb
......@@ -571,6 +571,55 @@ class BertEncoderV2CompatibilityTest(tf.test.TestCase):
for key in old_net_outputs:
self.assertAllClose(old_net_outputs[key], new_net_outputs[key])
def test_keras_model_checkpoint_forward_compatible(self):
batch_size = 3
hidden_size = 32
sequence_length = 21
vocab_size = 57
num_types = 7
kwargs = dict(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
type_vocab_size=num_types,
output_range=None)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(
num_types, size=(batch_size, sequence_length))
data = dict(
input_word_ids=word_id_data,
input_mask=mask_data,
input_type_ids=type_id_data)
kwargs["dict_outputs"] = True
old_net = bert_encoder.BertEncoder(**kwargs)
inputs = old_net.inputs
outputs = old_net(inputs)
old_model = tf.keras.Model(inputs=inputs, outputs=outputs)
old_model_outputs = old_model(data)
ckpt = tf.train.Checkpoint(net=old_model)
path = ckpt.save(self.get_temp_dir())
del kwargs["dict_outputs"]
new_net = bert_encoder.BertEncoderV2(**kwargs)
inputs = new_net.inputs
outputs = new_net(inputs)
new_model = tf.keras.Model(inputs=inputs, outputs=outputs)
new_ckpt = tf.train.Checkpoint(net=new_model)
status = new_ckpt.restore(path)
status.assert_existing_objects_matched()
new_model_outputs = new_model(data)
self.assertAllEqual(old_model_outputs.keys(), new_model_outputs.keys())
for key in old_model_outputs:
self.assertAllClose(old_model_outputs[key], new_model_outputs[key])
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment