Commit ce0d936c authored by Hongkun Yu, committed by A. Unique TensorFlower

Remove BertSquadLogitsLayer. The TF-Hub code path can use BertSpanLabeler instead.

PiperOrigin-RevId: 295648817
parent 41185bc7
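
For context, the deleted layer's computation reduces to a single Dense(2) head over the encoder's sequence output, split into per-token start and end logits. Below is a minimal, self-contained sketch of that computation (not code from this commit; note that Keras applies Dense to the last axis of a 3-D tensor, so the explicit reshape in the deleted layer is unnecessary):

import tensorflow as tf

def squad_span_logits(sequence_output, initializer=None):
  """Sketch of the head the deleted BertSquadLogitsLayer implemented.

  `sequence_output` is [batch, seq_len, hidden]; the head projects each
  token to 2 scores and splits them into start/end logits, each of
  shape [batch, seq_len].
  """
  dense = tf.keras.layers.Dense(units=2, kernel_initializer=initializer)
  logits = dense(sequence_output)                  # [batch, seq_len, 2]
  start_logits, end_logits = tf.unstack(logits, axis=-1)
  return start_logits, end_logits
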
@@ -186,36 +186,6 @@ def pretrain_model(bert_config,
   return keras_model, transformer_encoder
 
 
-class BertSquadLogitsLayer(tf.keras.layers.Layer):
-  """Returns a layer that computes custom logits for BERT squad model."""
-
-  def __init__(self, initializer=None, **kwargs):
-    super(BertSquadLogitsLayer, self).__init__(**kwargs)
-    self.initializer = initializer
-
-  def build(self, unused_input_shapes):
-    """Implements build() for the layer."""
-    self.final_dense = tf.keras.layers.Dense(
-        units=2, kernel_initializer=self.initializer, name='final_dense')
-    super(BertSquadLogitsLayer, self).build(unused_input_shapes)
-
-  def call(self, inputs):
-    """Implements call() for the layer."""
-    sequence_output = inputs
-
-    input_shape = tf_utils.get_shape_list(
-        sequence_output, name='sequence_output_tensor')
-    sequence_length = input_shape[1]
-    num_hidden_units = input_shape[2]
-
-    final_hidden_input = tf.reshape(sequence_output, [-1, num_hidden_units])
-    logits = self.final_dense(final_hidden_input)
-    logits = tf.reshape(logits, [-1, sequence_length, 2])
-    logits = tf.transpose(logits, [2, 0, 1])
-    unstacked_logits = tf.unstack(logits, axis=0)
-    return unstacked_logits[0], unstacked_logits[1]
-
-
 def squad_model(bert_config,
                 max_seq_length,
                 initializer=None,
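
The head deleted above now comes from BertSpanLabeler. A hedged usage sketch follows (the import path is an assumption about the modeling library and is not shown in this diff):

from official.nlp.modeling.models import bert_span_labeler  # path assumed

# `encoder` is any Keras model producing BERT-style outputs, e.g. the
# `bert_encoder` constructed in squad_model in the next hunk.
span_labeler = bert_span_labeler.BertSpanLabeler(
    network=encoder, initializer='glorot_uniform')
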
@@ -248,22 +218,18 @@ def squad_model(bert_config,
   input_type_ids = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
   core_model = hub.KerasLayer(hub_module_url, trainable=True)
-  _, sequence_output = core_model(
+  pooled_output, sequence_output = core_model(
       [input_word_ids, input_mask, input_type_ids])
-
-  squad_logits_layer = BertSquadLogitsLayer(
-      initializer=initializer, name='squad_logits')
-  start_logits, end_logits = squad_logits_layer(sequence_output)
-
-  squad = tf.keras.Model(
+  bert_encoder = tf.keras.Model(
       inputs={
           'input_word_ids': input_word_ids,
           'input_mask': input_mask,
           'input_type_ids': input_type_ids,
       },
-      outputs=[start_logits, end_logits],
-      name='squad_model')
-  return squad, core_model
+      outputs=[sequence_output, pooled_output],
+      name='core_model')
+  return bert_span_labeler.BertSpanLabeler(
+      network=bert_encoder, initializer=initializer), bert_encoder
 def classifier_model(bert_config,
...
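
With this change, the hub path of squad_model returns the BertSpanLabeler model plus the raw encoder. A hedged call sketch (the import path, hub URL, and sequence length are placeholders, not values from this commit):

from official.nlp import bert_models  # import path assumed

squad, bert_encoder = bert_models.squad_model(
    bert_config,                             # a BertConfig defined elsewhere
    max_seq_length=384,                      # placeholder value
    hub_module_url='https://tfhub.dev/...')  # placeholder URL
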
@@ -183,7 +183,9 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   # Prediction always uses float32, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   squad_model, _ = bert_models.squad_model(
-      bert_config, input_meta_data['max_seq_length'])
+      bert_config,
+      input_meta_data['max_seq_length'],
+      hub_module_url=FLAGS.hub_module_url)
   checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
   logging.info('Restoring checkpoints from %s', checkpoint_path)
...
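
Prediction now threads FLAGS.hub_module_url through to squad_model, so a hub-initialized encoder is rebuilt the same way at inference time as it was at training time. The restore step just beyond the visible diff presumably uses object-based checkpointing, as in this sketch (an assumption; the exact restore call is outside this diff):

checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()  # assumed restore style
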