Commit 223f5148 authored by Hongkun Yu, committed by A. Unique TensorFlower

Bert Hub usage for squad

PiperOrigin-RevId: 274090672
parent 9f8e0646
@@ -60,6 +60,9 @@ def define_common_bert_flags():
       'use_keras_compile_fit', False,
       'If True, uses Keras compile/fit() API for training logic. Otherwise '
       'use custom training loop.')
+  flags.DEFINE_string(
+      'hub_module_url', None, 'TF-Hub path/url to Bert module. '
+      'If specified, init_checkpoint flag should not be used.')
   # Adds flags for mixed precision training.
   flags_core.define_performance(
...
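The new flag's help text says init_checkpoint should not be used together with hub_module_url, but the hunk above documents the constraint rather than enforcing it. A minimal sketch of how absl-py could enforce it at parse time; the validator call below is an illustration, not part of this commit:

# Sketch only: enforce the documented constraint that hub_module_url
# and init_checkpoint are mutually exclusive. Not part of this commit.
from absl import flags

flags.DEFINE_string('init_checkpoint', None, 'Initial checkpoint path.')
flags.DEFINE_string(
    'hub_module_url', None, 'TF-Hub path/url to Bert module. '
    'If specified, init_checkpoint flag should not be used.')
# absl-py raises an error at flag-parsing time if both flags are set.
flags.mark_flags_as_mutual_exclusive(['init_checkpoint', 'hub_module_url'])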
@@ -56,9 +56,6 @@ flags.DEFINE_string(
     'to be used for training and evaluation.')
 flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
 flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
-flags.DEFINE_string(
-    'hub_module_url', None, 'TF-Hub path/url to Bert module. '
-    'If specified, init_checkpoint flag should not be used.')
 common_flags.define_common_bert_flags()
...
@@ -218,7 +218,8 @@ def train_squad(strategy,
   squad_model, core_model = bert_models.squad_model(
       bert_config,
       max_seq_length,
-      float_type=tf.float16 if use_float16 else tf.float32)
+      float_type=tf.float16 if use_float16 else tf.float32,
+      hub_module_url=FLAGS.hub_module_url)
   squad_model.optimizer = optimization.create_optimizer(
       FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
   if use_float16:
...
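With this change, train_squad simply forwards the flag into bert_models.squad_model. A hedged usage sketch of the updated call; the TF-Hub URL is a placeholder and 384 is the conventional SQuAD sequence length, neither fixed by this commit:

# Illustrative only: placeholder URL; pass hub_module_url=None to build
# BERT from bert_config and a checkpoint instead of TF-Hub.
squad_model, core_model = bert_models.squad_model(
    bert_config,
    max_seq_length=384,
    float_type=tf.float32,
    hub_module_url='https://tfhub.dev/<publisher>/<bert-module>/<version>')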
@@ -128,8 +128,7 @@ class BertPretrainLayer(tf.keras.layers.Layer):
     sequence_output = unpacked_inputs[1]
     masked_lm_positions = unpacked_inputs[2]
-    mask_lm_input_tensor = gather_indexes(
-        sequence_output, masked_lm_positions)
+    mask_lm_input_tensor = gather_indexes(sequence_output, masked_lm_positions)
     lm_output = self.lm_dense(mask_lm_input_tensor)
     lm_output = self.lm_layer_norm(lm_output)
     lm_output = tf.matmul(lm_output, self.embedding_table, transpose_b=True)
@@ -325,7 +324,11 @@ class BertSquadLogitsLayer(tf.keras.layers.Layer):
     return unstacked_logits[0], unstacked_logits[1]

-def squad_model(bert_config, max_seq_length, float_type, initializer=None):
+def squad_model(bert_config,
+                max_seq_length,
+                float_type,
+                initializer=None,
+                hub_module_url=None):
   """Returns BERT Squad model along with core BERT model to import weights.

   Args:
@@ -333,6 +336,7 @@ def squad_model(bert_config, max_seq_length, float_type, initializer=None):
     max_seq_length: integer, the maximum input sequence length.
     float_type: tf.dtype, tf.float32 or tf.bfloat16.
     initializer: Initializer for weights in BertSquadLogitsLayer.
+    hub_module_url: TF-Hub path/url to Bert module.

   Returns:
     Two tensors, start logits and end logits, [batch x sequence length].
@@ -346,17 +350,26 @@ def squad_model(bert_config, max_seq_length, float_type, initializer=None):
   input_type_ids = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='segment_ids')
-  core_model = modeling.get_bert_model(
-      input_word_ids,
-      input_mask,
-      input_type_ids,
-      config=bert_config,
-      name='bert_model',
-      float_type=float_type)
-
-  # `BertSquadModel` only uses the sequence_output which
-  # has dimensionality (batch_size, sequence_length, num_hidden).
-  sequence_output = core_model.outputs[1]
+  if hub_module_url:
+    core_model = hub.KerasLayer(
+        hub_module_url,
+        trainable=True)
+    _, sequence_output = core_model(
+        [input_word_ids, input_mask, input_type_ids])
+    # Sets the shape manually due to a bug in TF shape inference.
+    # TODO(hongkuny): remove this once shape inference is correct.
+    sequence_output.set_shape((None, max_seq_length, bert_config.hidden_size))
+  else:
+    core_model = modeling.get_bert_model(
+        input_word_ids,
+        input_mask,
+        input_type_ids,
+        config=bert_config,
+        name='bert_model',
+        float_type=float_type)
+    # `BertSquadModel` only uses the sequence_output which
+    # has dimensionality (batch_size, sequence_length, num_hidden).
+    sequence_output = core_model.outputs[1]

   if initializer is None:
     initializer = tf.keras.initializers.TruncatedNormal(
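The hub branch above is the heart of the change: a hub.KerasLayer replaces modeling.get_bert_model as the source of sequence_output. A self-contained sketch of the same pattern, assuming a TF2-style BERT module that returns (pooled_output, sequence_output) for the three int32 inputs; the URL and hidden size below are placeholders, not values pinned by this commit:

import tensorflow as tf
import tensorflow_hub as hub

max_seq_length = 384  # illustrative
hidden_size = 768     # bert_config.hidden_size for BERT-Base

input_word_ids = tf.keras.layers.Input(
    shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input(
    shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
    shape=(max_seq_length,), dtype=tf.int32, name='segment_ids')

# trainable=True lets the BERT weights be fine-tuned along with the head.
core_model = hub.KerasLayer(
    'https://tfhub.dev/<publisher>/<bert-module>/<version>', trainable=True)
_, sequence_output = core_model([input_word_ids, input_mask, input_type_ids])
# Restore the static shape that inference through the hub layer drops.
sequence_output.set_shape((None, max_seq_length, hidden_size))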
@@ -395,7 +408,7 @@ def classifier_model(bert_config,
     max_seq_length: integer, the maximum input sequence length.
     final_layer_initializer: Initializer for final dense layer. Defaulted
       TruncatedNormal initializer.
-    hub_module_url: (Experimental) TF-Hub path/url to Bert module.
+    hub_module_url: TF-Hub path/url to Bert module.

   Returns:
     Combined prediction model (words, mask, type) -> (one-hot labels)
...