Commit 223f5148 authored by Hongkun Yu, committed by A. Unique TensorFlower

BERT Hub usage for SQuAD

PiperOrigin-RevId: 274090672
parent 9f8e0646
@@ -60,6 +60,9 @@ def define_common_bert_flags():
       'use_keras_compile_fit', False,
       'If True, uses Keras compile/fit() API for training logic. Otherwise '
       'use custom training loop.')
+  flags.DEFINE_string(
+      'hub_module_url', None, 'TF-Hub path/url to Bert module. '
+      'If specified, init_checkpoint flag should not be used.')
   # Adds flags for mixed precision training.
   flags_core.define_performance(
...
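Review note: the help text says `init_checkpoint` should not be used together with `hub_module_url`, but this hunk only documents the constraint rather than enforcing it. A minimal sketch of a guard using absl's `multi_flags_validator` (hypothetical, not part of this commit; it assumes both flags are already defined and that the validator is registered after `define_common_bert_flags()` runs):

```python
from absl import flags

# Hypothetical validator, not in this commit: reject runs that set both a
# TF-Hub module and an initial checkpoint, per the flag's help text.
@flags.multi_flags_validator(
    ['hub_module_url', 'init_checkpoint'],
    message='Only one of hub_module_url and init_checkpoint may be set.')
def _check_hub_vs_init_checkpoint(flag_values):
  return not (flag_values['hub_module_url'] and
              flag_values['init_checkpoint'])
```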
@@ -56,9 +56,6 @@ flags.DEFINE_string(
     'to be used for training and evaluation.')
 flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
 flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
-flags.DEFINE_string(
-    'hub_module_url', None, 'TF-Hub path/url to Bert module. '
-    'If specified, init_checkpoint flag should not be used.')
 common_flags.define_common_bert_flags()
...
@@ -218,7 +218,8 @@ def train_squad(strategy,
   squad_model, core_model = bert_models.squad_model(
       bert_config,
       max_seq_length,
-      float_type=tf.float16 if use_float16 else tf.float32)
+      float_type=tf.float16 if use_float16 else tf.float32,
+      hub_module_url=FLAGS.hub_module_url)
   squad_model.optimizer = optimization.create_optimizer(
       FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
   if use_float16:
...
@@ -128,8 +128,7 @@ class BertPretrainLayer(tf.keras.layers.Layer):
     sequence_output = unpacked_inputs[1]
     masked_lm_positions = unpacked_inputs[2]
-    mask_lm_input_tensor = gather_indexes(
-        sequence_output, masked_lm_positions)
+    mask_lm_input_tensor = gather_indexes(sequence_output, masked_lm_positions)
     lm_output = self.lm_dense(mask_lm_input_tensor)
     lm_output = self.lm_layer_norm(lm_output)
     lm_output = tf.matmul(lm_output, self.embedding_table, transpose_b=True)
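For context, `gather_indexes` selects the encoder output vectors at the masked-LM positions before the dense projection. A minimal self-contained sketch of the idea (the repository's own helper may differ in its shape handling):

```python
import tensorflow as tf

def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at `positions` over a minibatch.

  sequence_tensor: float Tensor [batch_size, seq_length, width].
  positions: int32 Tensor [batch_size, num_positions].
  Returns: float Tensor [batch_size * num_positions, width].
  """
  shape = tf.shape(sequence_tensor)
  batch_size, seq_length, width = shape[0], shape[1], shape[2]
  # Convert per-example positions into indices into the tensor flattened
  # to [batch_size * seq_length, width].
  flat_offsets = tf.reshape(tf.range(batch_size) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence = tf.reshape(sequence_tensor, [-1, width])
  return tf.gather(flat_sequence, flat_positions)
```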
@@ -325,7 +324,11 @@ class BertSquadLogitsLayer(tf.keras.layers.Layer):
     return unstacked_logits[0], unstacked_logits[1]
 
-def squad_model(bert_config, max_seq_length, float_type, initializer=None):
+def squad_model(bert_config,
+                max_seq_length,
+                float_type,
+                initializer=None,
+                hub_module_url=None):
   """Returns BERT Squad model along with core BERT model to import weights.
 
   Args:
@@ -333,6 +336,7 @@ def squad_model(bert_config, max_seq_length, float_type, initializer=None):
     max_seq_length: integer, the maximum input sequence length.
     float_type: tf.dtype, tf.float32 or tf.bfloat16.
     initializer: Initializer for weights in BertSquadLogitsLayer.
+    hub_module_url: TF-Hub path/url to Bert module.
 
   Returns:
     A keras model that outputs start and end logits ([batch x sequence
     length] each), along with the core BERT model to import weights from.
@@ -346,17 +350,26 @@ def squad_model(bert_config, max_seq_length, float_type, initializer=None):
   input_type_ids = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='segment_ids')
-  core_model = modeling.get_bert_model(
-      input_word_ids,
-      input_mask,
-      input_type_ids,
-      config=bert_config,
-      name='bert_model',
-      float_type=float_type)
-  # `BertSquadModel` only uses the sequence_output which
-  # has dimensionality (batch_size, sequence_length, num_hidden).
-  sequence_output = core_model.outputs[1]
+  if hub_module_url:
+    core_model = hub.KerasLayer(
+        hub_module_url,
+        trainable=True)
+    _, sequence_output = core_model(
+        [input_word_ids, input_mask, input_type_ids])
+    # Sets the shape manually due to a bug in TF shape inference.
+    # TODO(hongkuny): remove this once shape inference is correct.
+    sequence_output.set_shape((None, max_seq_length, bert_config.hidden_size))
+  else:
+    core_model = modeling.get_bert_model(
+        input_word_ids,
+        input_mask,
+        input_type_ids,
+        config=bert_config,
+        name='bert_model',
+        float_type=float_type)
+    # `BertSquadModel` only uses the sequence_output which
+    # has dimensionality (batch_size, sequence_length, num_hidden).
+    sequence_output = core_model.outputs[1]
   if initializer is None:
     initializer = tf.keras.initializers.TruncatedNormal(
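Review note on the hub branch above: the module is assumed to follow the TF-Hub BERT convention of returning `[pooled_output, sequence_output]` for inputs `[input_word_ids, input_mask, input_type_ids]`, hence the `_, sequence_output` unpacking, and `set_shape` restores the static shape that is lost through `hub.KerasLayer`. A quick usage sketch of the new path (the URL and sizes below are placeholders, not from this commit; `bert_config` and `bert_models` are as in the diff above):

```python
import tensorflow as tf

# Placeholder values for illustration; the real script builds these from
# FLAGS and a BertConfig JSON file.
squad, core = bert_models.squad_model(
    bert_config,
    max_seq_length=384,
    float_type=tf.float32,
    hub_module_url='<tfhub-bert-module-url>')  # placeholder, not a real URL
```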
@@ -395,7 +408,7 @@ def classifier_model(bert_config,
     max_seq_length: integer, the maximum input sequence length.
     final_layer_initializer: Initializer for final dense layer. Defaults to a
       TruncatedNormal initializer.
-    hub_module_url: (Experimental) TF-Hub path/url to Bert module.
+    hub_module_url: TF-Hub path/url to Bert module.
 
   Returns:
     Combined prediction model (words, mask, type) -> (one-hot labels)
...