Commit 6a5ce811 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Support exporting ALBERT to a TF2 SavedModel for TF-Hub.

PiperOrigin-RevId: 288757693
parent 9302933b
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
from absl import app from absl import app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from typing import Text from typing import Optional, Text
from official.nlp import bert_modeling from official.nlp import bert_modeling
from official.nlp import bert_models from official.nlp import bert_models
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")
# Only consumed for --model_type=albert; canonical BERT uses --vocab_file.
flags.DEFINE_string("sp_model_file", None,
                    "The sentence piece model file that the ALBERT model was "
                    "trained on.")
# Selects which config class parses --bert_config_file (see main()).
flags.DEFINE_enum(
    "model_type", "bert", ["bert", "albert"],
    "Specifies the type of the model. "
    "If 'bert', will use canonical BERT; if 'albert', will use ALBERT model.")
def create_bert_model(bert_config: bert_modeling.BertConfig): def create_bert_model(bert_config: bert_modeling.BertConfig):
...@@ -65,24 +72,39 @@ def create_bert_model(bert_config: bert_modeling.BertConfig): ...@@ -65,24 +72,39 @@ def create_bert_model(bert_config: bert_modeling.BertConfig):
def export_bert_tfhub(bert_config: bert_modeling.BertConfig,
                      model_checkpoint_path: Text,
                      hub_destination: Text,
                      vocab_file: Optional[Text] = None,
                      sp_model_file: Optional[Text] = None):
  """Restores a tf.keras.Model from a checkpoint and saves it for TF-Hub.

  Args:
    bert_config: A `BertConfig` (or `AlbertConfig`) describing the model.
    model_checkpoint_path: Path to the checkpoint to restore the encoder from.
    hub_destination: Destination path of the exported TF-Hub SavedModel.
    vocab_file: Vocabulary file path; required for canonical BERT configs.
    sp_model_file: SentencePiece model file path; required for ALBERT configs.

  Raises:
    ValueError: If `bert_config` is not a supported config type, or if the
      tokenizer file required for the given config type is missing.
  """
  core_model, encoder = create_bert_model(bert_config)
  checkpoint = tf.train.Checkpoint(model=encoder)
  # Fail loudly if the checkpoint does not fully match the built encoder.
  checkpoint.restore(model_checkpoint_path).assert_consumed()

  if isinstance(bert_config, bert_modeling.AlbertConfig):
    if not sp_model_file:
      raise ValueError("sp_model_file is required.")
    # Bundle the sentence piece model as a SavedModel asset so consumers can
    # tokenize consistently with how the model was trained.
    core_model.sp_model_file = tf.saved_model.Asset(sp_model_file)
  else:
    # Explicit check instead of `assert`: asserts are stripped under `-O`.
    if not isinstance(bert_config, bert_modeling.BertConfig):
      raise ValueError("bert_config must be a BertConfig or AlbertConfig, "
                       "got: %s" % type(bert_config))
    if not vocab_file:
      raise ValueError("vocab_file is required.")
    core_model.vocab_file = tf.saved_model.Asset(vocab_file)
    # Infer lower-casing from the vocab file name, per BERT naming convention.
    core_model.do_lower_case = tf.Variable(
        "uncased" in vocab_file, trainable=False)
  core_model.save(hub_destination, include_optimizer=False, save_format="tf")
def main(_):
  """Builds the flagged model config and exports it as a TF-Hub SavedModel."""
  # Explicit check instead of `assert`: asserts are stripped under `-O`.
  if not tf.version.VERSION.startswith('2.'):
    raise RuntimeError(
        "TensorFlow 2.x is required, got: %s" % tf.version.VERSION)

  # --model_type selects which config class parses the JSON config file.
  config_cls = {
      "bert": bert_modeling.BertConfig,
      "albert": bert_modeling.AlbertConfig,
  }
  bert_config = config_cls[FLAGS.model_type].from_json_file(
      FLAGS.bert_config_file)
  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
                    FLAGS.vocab_file, FLAGS.sp_model_file)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -82,6 +82,57 @@ class ExportTfhubTest(tf.test.TestCase): ...@@ -82,6 +82,57 @@ class ExportTfhubTest(tf.test.TestCase):
self.assertAllClose(source_output.numpy(), hub_output.numpy()) self.assertAllClose(source_output.numpy(), hub_output.numpy())
self.assertAllClose(source_output.numpy(), encoder_output.numpy()) self.assertAllClose(source_output.numpy(), encoder_output.numpy())
def test_export_albert_tfhub(self):
  """Round-trips a tiny ALBERT model through export and hub.KerasLayer."""
  # Exports a savedmodel for TF-Hub. The config is deliberately tiny so the
  # test builds and saves quickly.
  bert_config = bert_modeling.AlbertConfig(
      vocab_size=100,
      embedding_size=8,
      hidden_size=16,
      intermediate_size=32,
      max_position_embeddings=128,
      num_attention_heads=2,
      num_hidden_layers=1)
  bert_model, encoder = export_tfhub.create_bert_model(bert_config)
  model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
  checkpoint = tf.train.Checkpoint(model=encoder)
  checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
  model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)

  # A dummy sentence piece file suffices: the exporter only bundles it as a
  # SavedModel asset; it is not parsed here.
  sp_model_file = os.path.join(self.get_temp_dir(), "sp_tokenizer.model")
  with tf.io.gfile.GFile(sp_model_file, "w") as f:
    f.write("dummy content")

  hub_destination = os.path.join(self.get_temp_dir(), "hub")
  export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
                                 hub_destination, sp_model_file=sp_model_file)

  # Restores a hub KerasLayer.
  hub_layer = hub.KerasLayer(hub_destination, trainable=True)

  if hasattr(hub_layer, "resolved_object"):
    # The asset content must survive the export/restore round trip.
    with tf.io.gfile.GFile(
        hub_layer.resolved_object.sp_model_file.asset_path.numpy()) as f:
      self.assertEqual("dummy content", f.read())

  # Checks the hub KerasLayer: restored weights must equal the source model's.
  for source_weight, hub_weight in zip(bert_model.trainable_weights,
                                       hub_layer.trainable_weights):
    self.assertAllClose(source_weight.numpy(), hub_weight.numpy())

  dummy_ids = np.zeros((2, 10), dtype=np.int32)
  hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
  source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])

  # The outputs of hub module are "pooled_output" and "sequence_output",
  # while the outputs of encoder is in reversed order, i.e.,
  # "sequence_output" and "pooled_output".
  encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
  self.assertEqual(hub_outputs[0].shape, (2, 16))  # (batch, hidden_size)
  self.assertEqual(hub_outputs[1].shape, (2, 10, 16))  # (batch, seq, hidden)
  for source_output, hub_output, encoder_output in zip(
      source_outputs, hub_outputs, encoder_outputs):
    self.assertAllClose(source_output.numpy(), hub_output.numpy())
    self.assertAllClose(source_output.numpy(), encoder_output.numpy())
if __name__ == "__main__": if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.') assert tf.version.VERSION.startswith('2.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment