Commit 553a4f41 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 286634090
parent 5f7bdb11
export_tfhub.py

@@ -24,6 +24,7 @@ import tensorflow as tf
 from typing import Text
 
 from official.nlp import bert_modeling
+from official.nlp import bert_models
 
 FLAGS = flags.FLAGS
 
@@ -31,8 +32,7 @@ flags.DEFINE_string("bert_config_file", None,
                     "Bert configuration file to define core bert layers.")
 flags.DEFINE_string("model_checkpoint_path", None,
                     "File path to TF model checkpoint.")
-flags.DEFINE_string("export_path", None,
-                    "TF-Hub SavedModel destination path.")
+flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
 flags.DEFINE_string("vocab_file", None,
                     "The vocabulary file that the BERT model was trained on.")
@@ -53,21 +53,23 @@ def create_bert_model(bert_config: bert_modeling.BertConfig):
       shape=(None,), dtype=tf.int32, name="input_mask")
   input_type_ids = tf.keras.layers.Input(
       shape=(None,), dtype=tf.int32, name="input_type_ids")
-  return bert_modeling.get_bert_model(
-      input_word_ids,
-      input_mask,
-      input_type_ids,
-      config=bert_config,
-      name="bert_model",
-      float_type=tf.float32)
+  transformer_encoder = bert_models.get_transformer_encoder(
+      bert_config, sequence_length=None, float_dtype=tf.float32)
+  sequence_output, pooled_output = transformer_encoder(
+      [input_word_ids, input_mask, input_type_ids])
+  # To keep consistent with legacy hub modules, the outputs are
+  # "pooled_output" and "sequence_output".
+  return tf.keras.Model(
+      inputs=[input_word_ids, input_mask, input_type_ids],
+      outputs=[pooled_output, sequence_output]), transformer_encoder
 
 
 def export_bert_tfhub(bert_config: bert_modeling.BertConfig,
                       model_checkpoint_path: Text, hub_destination: Text,
                       vocab_file: Text):
   """Restores a tf.keras.Model and saves for TF-Hub."""
-  core_model = create_bert_model(bert_config)
-  checkpoint = tf.train.Checkpoint(model=core_model)
+  core_model, encoder = create_bert_model(bert_config)
+  checkpoint = tf.train.Checkpoint(model=encoder)
   checkpoint.restore(model_checkpoint_path).assert_consumed()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
   core_model.do_lower_case = tf.Variable(
@@ -79,8 +81,8 @@ def main(_):
   assert tf.version.VERSION.startswith('2.')
   bert_config = bert_modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
-  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
-                    FLAGS.export_path, FLAGS.vocab_file)
+  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
+                    FLAGS.vocab_file)
 
 
 if __name__ == "__main__":
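For context, a minimal sketch of driving the updated exporter directly from Python rather than via flags. All paths are hypothetical; only export_bert_tfhub's signature comes from the diff above, and the checkpoint must have been written against the encoder (tf.train.Checkpoint(model=encoder)) for assert_consumed() to pass.

# Sketch: calling the exporter directly. Paths are hypothetical placeholders.
from official.nlp import bert_modeling
import export_tfhub

bert_config = bert_modeling.BertConfig.from_json_file("/tmp/bert_config.json")
export_tfhub.export_bert_tfhub(
    bert_config,
    model_checkpoint_path="/tmp/checkpoint/test-1",
    hub_destination="/tmp/bert_hub_module",
    vocab_file="/tmp/vocab.txt")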
export_tfhub_test.py

@@ -39,9 +39,9 @@ class ExportTfhubTest(tf.test.TestCase):
         max_position_embeddings=128,
         num_attention_heads=2,
         num_hidden_layers=1)
-    bert_model = export_tfhub.create_bert_model(bert_config)
+    bert_model, encoder = export_tfhub.create_bert_model(bert_config)
     model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
-    checkpoint = tf.train.Checkpoint(model=bert_model)
+    checkpoint = tf.train.Checkpoint(model=encoder)
     checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
     model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
@@ -70,10 +70,17 @@ class ExportTfhubTest(tf.test.TestCase):
     dummy_ids = np.zeros((2, 10), dtype=np.int32)
     hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
     source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
+    # The outputs of the hub module are "pooled_output" and "sequence_output",
+    # while the outputs of the encoder are in the reversed order, i.e.,
+    # "sequence_output" and "pooled_output".
+    encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
     self.assertEqual(hub_outputs[0].shape, (2, 16))
     self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
-    for source_output, hub_output in zip(source_outputs, hub_outputs):
+    for source_output, hub_output, encoder_output in zip(
+        source_outputs, hub_outputs, encoder_outputs):
       self.assertAllClose(source_output.numpy(), hub_output.numpy())
+      self.assertAllClose(source_output.numpy(), encoder_output.numpy())
 
 
 if __name__ == "__main__":
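As a companion to the test above, a minimal sketch of consuming the exported SavedModel through tensorflow_hub; the module path is hypothetical, and the output order (pooled_output, then sequence_output) follows the legacy-hub convention the diff preserves.

# Sketch: loading the exported module with tensorflow_hub (path hypothetical).
import numpy as np
import tensorflow_hub as hub

hub_layer = hub.KerasLayer("/tmp/bert_hub_module", trainable=True)
# Inputs are [input_word_ids, input_mask, input_type_ids]; zeros are dummies.
dummy_ids = np.zeros((2, 10), dtype=np.int32)
pooled_output, sequence_output = hub_layer([dummy_ids, dummy_ids, dummy_ids])
# pooled_output: shape (2, hidden_size); sequence_output: (2, 10, hidden_size).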
bert_models.py

@@ -134,7 +134,7 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss
 
 
-def _get_transformer_encoder(bert_config,
-                             sequence_length,
-                             float_dtype=tf.float32):
+def get_transformer_encoder(bert_config,
+                            sequence_length,
+                            float_dtype=tf.float32):
   """Gets a 'TransformerEncoder' object.
@@ -206,7 +206,7 @@ def pretrain_model(bert_config,
   next_sentence_labels = tf.keras.layers.Input(
       shape=(1,), name='next_sentence_labels', dtype=tf.int32)
-  transformer_encoder = _get_transformer_encoder(bert_config, seq_length)
+  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
     initializer = tf.keras.initializers.TruncatedNormal(
         stddev=bert_config.initializer_range)
@@ -294,7 +294,7 @@ def squad_model(bert_config,
     initializer = tf.keras.initializers.TruncatedNormal(
         stddev=bert_config.initializer_range)
   if not hub_module_url:
-    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length,
-                                            float_type)
+    bert_encoder = get_transformer_encoder(bert_config, max_seq_length,
+                                           float_type)
     return bert_span_labeler.BertSpanLabeler(
         network=bert_encoder, initializer=initializer), bert_encoder
@@ -359,7 +359,7 @@ def classifier_model(bert_config,
         stddev=bert_config.initializer_range)
   if not hub_module_url:
-    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
+    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
     return bert_classifier.BertClassifier(
         bert_encoder,
         num_classes=num_labels,
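Because the rename makes get_transformer_encoder public, callers outside bert_models.py can now build the encoder directly. A minimal sketch; the config values mirror the test above, except vocab_size, which is an illustrative assumption (hidden_size=16 is inferred from the shapes the test asserts).

import tensorflow as tf
from official.nlp import bert_modeling
from official.nlp import bert_models

# vocab_size=100 is an assumption; the other values mirror the test above.
bert_config = bert_modeling.BertConfig(
    vocab_size=100,
    hidden_size=16,
    max_position_embeddings=128,
    num_attention_heads=2,
    num_hidden_layers=1)
# sequence_length=None builds an encoder over variable-length inputs, as
# create_bert_model does in the first diff above.
encoder = bert_models.get_transformer_encoder(
    bert_config, sequence_length=None, float_dtype=tf.float32)
# Note: the encoder outputs [sequence_output, pooled_output], the reverse of
# the hub model's [pooled_output, sequence_output].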