"vscode:/vscode.git/clone" did not exist on "4cb5a5235e49f08e9f47165fb7f95e35145fce43"
Commit 553a4f41 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 286634090
parent 5f7bdb11
@@ -24,6 +24,7 @@ import tensorflow as tf
 from typing import Text
 from official.nlp import bert_modeling
+from official.nlp import bert_models
 FLAGS = flags.FLAGS
@@ -31,8 +32,7 @@ flags.DEFINE_string("bert_config_file", None,
                     "Bert configuration file to define core bert layers.")
 flags.DEFINE_string("model_checkpoint_path", None,
                     "File path to TF model checkpoint.")
-flags.DEFINE_string("export_path", None,
-                    "TF-Hub SavedModel destination path.")
+flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
 flags.DEFINE_string("vocab_file", None,
                     "The vocabulary file that the BERT model was trained on.")
@@ -53,21 +53,23 @@ def create_bert_model(bert_config: bert_modeling.BertConfig):
       shape=(None,), dtype=tf.int32, name="input_mask")
   input_type_ids = tf.keras.layers.Input(
       shape=(None,), dtype=tf.int32, name="input_type_ids")
-  return bert_modeling.get_bert_model(
-      input_word_ids,
-      input_mask,
-      input_type_ids,
-      config=bert_config,
-      name="bert_model",
-      float_type=tf.float32)
+  transformer_encoder = bert_models.get_transformer_encoder(
+      bert_config, sequence_length=None, float_dtype=tf.float32)
+  sequence_output, pooled_output = transformer_encoder(
+      [input_word_ids, input_mask, input_type_ids])
+  # To be consistent with legacy hub modules, the outputs are
+  # "pooled_output" and "sequence_output".
+  return tf.keras.Model(
+      inputs=[input_word_ids, input_mask, input_type_ids],
+      outputs=[pooled_output, sequence_output]), transformer_encoder
 
 
 def export_bert_tfhub(bert_config: bert_modeling.BertConfig,
                       model_checkpoint_path: Text, hub_destination: Text,
                       vocab_file: Text):
   """Restores a tf.keras.Model and saves for TF-Hub."""
-  core_model = create_bert_model(bert_config)
-  checkpoint = tf.train.Checkpoint(model=core_model)
+  core_model, encoder = create_bert_model(bert_config)
+  checkpoint = tf.train.Checkpoint(model=encoder)
   checkpoint.restore(model_checkpoint_path).assert_consumed()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
   core_model.do_lower_case = tf.Variable(
@@ -79,8 +81,8 @@ def main(_):
   assert tf.version.VERSION.startswith('2.')
   bert_config = bert_modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
-  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
-                    FLAGS.export_path, FLAGS.vocab_file)
+  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
+                    FLAGS.vocab_file)
 
 
 if __name__ == "__main__":
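With the hunks above, create_bert_model now returns both the hub-facing tf.keras.Model and the underlying transformer encoder, and export_bert_tfhub restores the checkpoint into the encoder rather than into the wrapper model, so a checkpoint written against the encoder restores cleanly under assert_consumed(). A minimal sketch of driving the updated export directly rather than through the flags; the import path and all file paths are placeholders/assumptions, not part of this change:

from official.nlp import bert_modeling
from official.nlp import export_tfhub  # assumed import path for the script changed above

# Placeholder inputs; in the script itself these come from the flags defined above.
bert_config = bert_modeling.BertConfig.from_json_file("/path/to/bert_config.json")
export_tfhub.export_bert_tfhub(
    bert_config,
    model_checkpoint_path="/path/to/bert_model.ckpt",
    hub_destination="/tmp/bert_tfhub_export",
    vocab_file="/path/to/vocab.txt")

The next hunks update ExportTfhubTest accordingly: it unpacks both return values and saves/restores the checkpoint through the encoder.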
@@ -39,9 +39,9 @@ class ExportTfhubTest(tf.test.TestCase):
         max_position_embeddings=128,
         num_attention_heads=2,
         num_hidden_layers=1)
-    bert_model = export_tfhub.create_bert_model(bert_config)
+    bert_model, encoder = export_tfhub.create_bert_model(bert_config)
 
     model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
-    checkpoint = tf.train.Checkpoint(model=bert_model)
+    checkpoint = tf.train.Checkpoint(model=encoder)
     checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
     model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
@@ -70,10 +70,17 @@ class ExportTfhubTest(tf.test.TestCase):
     dummy_ids = np.zeros((2, 10), dtype=np.int32)
     hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
     source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
+    # The outputs of the hub module are "pooled_output" and "sequence_output",
+    # while the outputs of the encoder are in reverse order, i.e.,
+    # "sequence_output" and "pooled_output".
+    encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
     self.assertEqual(hub_outputs[0].shape, (2, 16))
     self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
-    for source_output, hub_output in zip(source_outputs, hub_outputs):
+    for source_output, hub_output, encoder_output in zip(
+        source_outputs, hub_outputs, encoder_outputs):
       self.assertAllClose(source_output.numpy(), hub_output.numpy())
+      self.assertAllClose(source_output.numpy(), encoder_output.numpy())
 
 
 if __name__ == "__main__":
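The test exercises the exported SavedModel through hub.KerasLayer and checks that it agrees with both the source Keras model and the raw encoder (whose two outputs come back in the opposite order). A minimal sketch of consuming such an export, assuming a placeholder export directory and the same dummy batch shape as the test:

import numpy as np
import tensorflow_hub as hub

# Placeholder path; in the test this directory is produced by export_bert_tfhub.
hub_layer = hub.KerasLayer("/tmp/bert_tfhub_export", trainable=True)

dummy_ids = np.zeros((2, 10), dtype=np.int32)  # (batch_size, seq_length) word ids
# Legacy hub ordering: pooled_output has shape (batch, hidden_size),
# sequence_output has shape (batch, seq_length, hidden_size).
pooled_output, sequence_output = hub_layer([dummy_ids, dummy_ids, dummy_ids])

The remaining hunks make the encoder factory in bert_models public (get_transformer_encoder instead of _get_transformer_encoder) so that export_tfhub can reuse it, and update its in-module call sites.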
@@ -134,7 +134,7 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss
 
 
-def _get_transformer_encoder(bert_config,
+def get_transformer_encoder(bert_config,
                             sequence_length,
                             float_dtype=tf.float32):
   """Gets a 'TransformerEncoder' object.
@@ -206,7 +206,7 @@ def pretrain_model(bert_config,
   next_sentence_labels = tf.keras.layers.Input(
       shape=(1,), name='next_sentence_labels', dtype=tf.int32)
 
-  transformer_encoder = _get_transformer_encoder(bert_config, seq_length)
+  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
     initializer = tf.keras.initializers.TruncatedNormal(
         stddev=bert_config.initializer_range)
@@ -294,7 +294,7 @@ def squad_model(bert_config,
     initializer = tf.keras.initializers.TruncatedNormal(
         stddev=bert_config.initializer_range)
   if not hub_module_url:
-    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length,
+    bert_encoder = get_transformer_encoder(bert_config, max_seq_length,
                                            float_type)
     return bert_span_labeler.BertSpanLabeler(
         network=bert_encoder, initializer=initializer), bert_encoder
@@ -359,7 +359,7 @@ def classifier_model(bert_config,
     initializer = tf.keras.initializers.TruncatedNormal(
         stddev=bert_config.initializer_range)
   if not hub_module_url:
-    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
+    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
     return bert_classifier.BertClassifier(
         bert_encoder,
         num_classes=num_labels,
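For reference, a minimal sketch of using the now-public helper the way the new create_bert_model does; the BertConfig values below are hypothetical, chosen to resemble the small configuration in ExportTfhubTest:

import tensorflow as tf

from official.nlp import bert_modeling
from official.nlp import bert_models

# Hypothetical small config (hidden_size=16 matches the shapes asserted in the test).
bert_config = bert_modeling.BertConfig(
    vocab_size=100,
    hidden_size=16,
    max_position_embeddings=128,
    num_attention_heads=2,
    num_hidden_layers=1)

# sequence_length=None builds an encoder over variable-length inputs,
# as in create_bert_model above.
encoder = bert_models.get_transformer_encoder(
    bert_config, sequence_length=None, float_dtype=tf.float32)

input_word_ids = tf.keras.layers.Input(
    shape=(None,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.layers.Input(
    shape=(None,), dtype=tf.int32, name="input_mask")
input_type_ids = tf.keras.layers.Input(
    shape=(None,), dtype=tf.int32, name="input_type_ids")

# Note the encoder returns (sequence_output, pooled_output), the reverse of the
# legacy hub module's (pooled_output, sequence_output) ordering.
sequence_output, pooled_output = encoder(
    [input_word_ids, input_mask, input_type_ids])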