Commit 73fcb8f5 authored by Hongkun Yu, committed by A. Unique TensorFlower

Adds BERT TF2 TF-Hub usage and an export script.

As the model code is subject to a major change, we do not release a hub module at this moment.

PiperOrigin-RevId: 272688279
parent b40871e4
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to export the BERT core model as a TF-Hub SavedModel."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
from typing import Text
from official.nlp import bert_modeling
FLAGS = flags.FLAGS
flags.DEFINE_string("bert_config_file", None,
"Bert configuration file to define core bert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None,
"TF-Hub SavedModel destination path.")


def create_bert_model(bert_config: bert_modeling.BertConfig):
  """Creates a BERT keras core model from BERT configuration.

  Args:
    bert_config: A `BertConfig` to create the core model.

  Returns:
    A keras model.
  """
  # Adds input layers just as placeholders.
  input_word_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_word_ids")
  input_mask = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_mask")
  input_type_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_type_ids")
  return bert_modeling.get_bert_model(
      input_word_ids,
      input_mask,
      input_type_ids,
      config=bert_config,
      name="bert_model",
      float_type=tf.float32)


def export_bert_tfhub(bert_config: bert_modeling.BertConfig,
                      model_checkpoint_path: Text, hub_destination: Text):
  """Restores a tf.keras.Model and saves for TF-Hub."""
  core_model = create_bert_model(bert_config)
  checkpoint = tf.train.Checkpoint(model=core_model)
  checkpoint.restore(model_checkpoint_path).assert_consumed()
  core_model.save(hub_destination, include_optimizer=False, save_format="tf")


def main(_):
  assert tf.version.VERSION.startswith('2.')
  bert_config = bert_modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
                    FLAGS.export_path)


if __name__ == "__main__":
  app.run(main)
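
For reference, the exporter can be driven programmatically as well as through the flags above. A minimal sketch, assuming a BERT config and checkpoint already exist on disk; every path below is a hypothetical placeholder:

    from official.nlp import bert_modeling
    from official.nlp.bert import export_tfhub

    # Hypothetical inputs: a BERT config JSON and a checkpoint prefix.
    bert_config = bert_modeling.BertConfig.from_json_file(
        "/tmp/bert/bert_config.json")
    export_tfhub.export_bert_tfhub(
        bert_config,
        model_checkpoint_path="/tmp/bert/bert_model.ckpt",
        hub_destination="/tmp/bert_tfhub")  # SavedModel output directory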
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests official.nlp.bert.export_tfhub."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from official.nlp import bert_modeling
from official.nlp.bert import export_tfhub


class ExportTfhubTest(tf.test.TestCase):

  def test_export_tfhub(self):
    # Exports a SavedModel for TF-Hub.
    bert_config = bert_modeling.BertConfig(
        vocab_size=100,
        hidden_size=16,
        intermediate_size=32,
        max_position_embeddings=128,
        num_attention_heads=2,
        num_hidden_layers=1)
    bert_model = export_tfhub.create_bert_model(bert_config)
    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    checkpoint = tf.train.Checkpoint(model=bert_model)
    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)

    hub_destination = os.path.join(self.get_temp_dir(), "hub")
    export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
                                   hub_destination)

    # Restores a hub KerasLayer.
    hub_layer = hub.KerasLayer(hub_destination, trainable=True)

    # Checks that the hub KerasLayer carries the same trainable weights as
    # the source model.
    for source_weight, hub_weight in zip(bert_model.trainable_weights,
                                         hub_layer.trainable_weights):
      self.assertAllClose(source_weight.numpy(), hub_weight.numpy())

    # Runs both models on dummy inputs; outputs are (pooled, sequence), i.e.
    # (batch, hidden_size) and (batch, seq_len, hidden_size) with hidden=16.
    dummy_ids = np.zeros((2, 10), dtype=np.int32)
    hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
    source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
    self.assertEqual(hub_outputs[0].shape, (2, 16))
    self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
    for source_output, hub_output in zip(source_outputs, hub_outputs):
      self.assertAllClose(source_output.numpy(), hub_output.numpy())


if __name__ == "__main__":
  assert tf.version.VERSION.startswith('2.')
  tf.test.main()
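
As the test shows, the export loads back as a hub.KerasLayer that maps three int32 tensors (word ids, mask, type ids) to a pooled output and a sequence output. A standalone sketch along the same lines; the module path is a placeholder, and real token ids would come from a tokenizer:

    import numpy as np
    import tensorflow_hub as hub

    bert_layer = hub.KerasLayer("/tmp/bert_tfhub", trainable=True)  # placeholder
    input_word_ids = np.zeros((1, 128), dtype=np.int32)  # token ids
    input_mask = np.ones((1, 128), dtype=np.int32)       # 1 marks real tokens
    input_type_ids = np.zeros((1, 128), dtype=np.int32)  # segment ids
    pooled_output, sequence_output = bert_layer(
        [input_word_ids, input_mask, input_type_ids])
    # pooled_output: (1, hidden_size); sequence_output: (1, 128, hidden_size)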
@@ -55,6 +55,9 @@ flags.DEFINE_string(
     'to be used for training and evaluation.')
 flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
 flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
+flags.DEFINE_string(
+    'hub_module_url', None, 'TF-Hub path/url to Bert module. '
+    'If specified, init_checkpoint flag should not be used.')

 common_flags.define_common_bert_flags()
@@ -112,8 +115,12 @@ def run_customized_training(strategy,
   def _get_classifier_model():
     """Gets a classifier model."""
     classifier_model, core_model = (
-        bert_models.classifier_model(bert_config, tf.float32, num_classes,
-                                     max_seq_length))
+        bert_models.classifier_model(
+            bert_config,
+            tf.float32,
+            num_classes,
+            max_seq_length,
+            hub_module_url=FLAGS.hub_module_url))
     classifier_model.optimizer = optimization.create_optimizer(
         initial_lr, steps_per_epoch * epochs, warmup_steps)
     if FLAGS.fp16_implementation == 'graph_rewrite':
@@ -20,6 +20,7 @@ from __future__ import print_function
 import copy

 import tensorflow as tf
+import tensorflow_hub as hub

 from official.modeling import tf_utils
 from official.nlp import bert_modeling as modeling
@@ -380,7 +381,8 @@ def classifier_model(bert_config,
                      float_type,
                      num_labels,
                      max_seq_length,
-                     final_layer_initializer=None):
+                     final_layer_initializer=None,
+                     hub_module_url=None):
   """BERT classifier model in functional API style.

   Construct a Keras model for predicting `num_labels` outputs from an input with
@@ -393,6 +395,7 @@ def classifier_model(bert_config,
     max_seq_length: integer, the maximum input sequence length.
     final_layer_initializer: Initializer for final dense layer. Defaulted
       TruncatedNormal initializer.
+    hub_module_url: (Experimental) TF-Hub path/url to Bert module.

   Returns:
     Combined prediction model (words, mask, type) -> (one-hot labels)
@@ -404,6 +407,10 @@ def classifier_model(bert_config,
       shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
   input_type_ids = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-  bert_model = modeling.get_bert_model(
-      input_word_ids,
-      input_mask,
+  if hub_module_url:
+    bert_model = hub.KerasLayer(hub_module_url, trainable=True)
+    pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
+  else:
+    bert_model = modeling.get_bert_model(
+        input_word_ids,
+        input_mask,
@@ -411,6 +418,7 @@ def classifier_model(bert_config,
-      config=bert_config,
-      float_type=float_type)
-  pooled_output = bert_model.outputs[0]
+        config=bert_config,
+        float_type=float_type)
+    pooled_output = bert_model.outputs[0]
+
   if final_layer_initializer is not None:
     initializer = final_layer_initializer
   else:
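
Taken together, these diffs let run_classifier build its classifier head on top of a hub module instead of a freshly constructed core model. A minimal sketch of the extended classifier_model call; the label count, sequence length, vocab size, and module path are illustrative, and the bert_models import path is assumed from this commit's layout:

    import tensorflow as tf
    from official.nlp import bert_modeling
    from official.nlp import bert_models  # import path assumed

    bert_config = bert_modeling.BertConfig(vocab_size=30522)  # illustrative
    classifier_model, core_model = bert_models.classifier_model(
        bert_config,
        tf.float32,
        2,    # num_labels, e.g. binary classification
        128,  # max_seq_length
        hub_module_url="/tmp/bert_tfhub")  # placeholder export path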