Commit d48d036a authored by Chen Chen, committed by A. Unique TensorFlower

Support using KerasBERT for the SQuAD task.

PiperOrigin-RevId: 279873276
parent 146a37c6
@@ -177,7 +177,6 @@ def create_classifier_dataset(file_path,
 def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
-      'unique_ids': tf.io.FixedLenFeature([], tf.int64),
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
       'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
       'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -185,15 +184,22 @@ def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
   if is_training:
     name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
     name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
+  else:
+    name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)

   input_fn = file_based_input_fn_builder(file_path, name_to_features)
   dataset = input_fn()

   def _select_data_from_record(record):
+    """Dispatches record to features and labels."""
     x, y = {}, {}
     for name, tensor in record.items():
       if name in ('start_positions', 'end_positions'):
         y[name] = tensor
+      elif name == 'input_ids':
+        x['input_word_ids'] = tensor
+      elif name == 'segment_ids':
+        x['input_type_ids'] = tensor
       else:
         x[name] = tensor
     return (x, y)
...
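To make the key remapping concrete, here is a minimal, self-contained sketch of the `_select_data_from_record` logic above, applied to a dummy eval-mode record (the sequence length of 384 and the dummy values are illustrative assumptions, not part of this commit):

import tensorflow as tf

# Dummy parsed record matching name_to_features above in eval mode
# (is_training=False, so 'unique_ids' is present and labels are absent).
record = {
    'unique_ids': tf.constant(7, dtype=tf.int64),
    'input_ids': tf.zeros([384], dtype=tf.int64),
    'input_mask': tf.ones([384], dtype=tf.int64),
    'segment_ids': tf.zeros([384], dtype=tf.int64),
}

def _select_data_from_record(record):
  """Dispatches record to features and labels, renaming keys for KerasBERT."""
  x, y = {}, {}
  for name, tensor in record.items():
    if name in ('start_positions', 'end_positions'):
      y[name] = tensor
    elif name == 'input_ids':
      x['input_word_ids'] = tensor
    elif name == 'segment_ids':
      x['input_type_ids'] = tensor
    else:
      x[name] = tensor
  return (x, y)

x, y = _select_data_from_record(record)
print(sorted(x))  # ['input_mask', 'input_type_ids', 'input_word_ids', 'unique_ids']
print(y)          # {} -- labels only appear when is_training=True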
@@ -80,6 +80,10 @@ flags.DEFINE_integer(
     'max_answer_length', 30,
     'The maximum length of an answer that can be generated. This is needed '
     'because the start and end predictions are not conditioned on one another.')
+flags.DEFINE_bool(
+    'use_keras_bert_for_squad', False, 'Whether to use keras BERT for squad '
+    'task. Note that when the FLAG "hub_module_url" is specified, '
+    '"use_keras_bert_for_squad" cannot be True.')

 common_flags.define_common_bert_flags()
@@ -108,7 +112,7 @@ def get_loss_fn(loss_factor=1.0):
   def _loss_fn(labels, model_outputs):
     start_positions = labels['start_positions']
     end_positions = labels['end_positions']
-    _, start_logits, end_logits = model_outputs
+    start_logits, end_logits = model_outputs
     return squad_loss_fn(
         start_positions,
         end_positions,
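`squad_loss_fn` itself is defined elsewhere in run_squad.py and is not part of this diff. As a hedged sketch only, a standard SQuAD span loss over the two logit tensors looks roughly like this (the name `squad_loss_sketch` and the mean reduction are assumptions):

import tensorflow as tf

def squad_loss_sketch(start_positions, end_positions, start_logits, end_logits):
  """Standard span loss: mean sparse cross-entropy over start and end."""
  xent = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  start_loss = xent(start_positions, start_logits)  # [batch] vs [batch, seq]
  end_loss = xent(end_positions, end_logits)
  return (start_loss + end_loss) / 2.0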
@@ -147,7 +151,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   # Prediction always uses float32, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   squad_model, _ = bert_models.squad_model(
-      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
+      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32,
+      use_keras_bert=FLAGS.use_keras_bert_for_squad)

   checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
   logging.info('Restoring checkpoints from %s', checkpoint_path)
@@ -161,7 +166,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   def _replicated_step(inputs):
     """Replicated prediction calculation."""
     x, _ = inputs
-    unique_ids, start_logits, end_logits = squad_model(x, training=False)
+    unique_ids = x.pop('unique_ids')
+    start_logits, end_logits = squad_model(x, training=False)
     return dict(
         unique_ids=unique_ids,
         start_logits=start_logits,
@@ -216,7 +222,8 @@ def train_squad(strategy,
       bert_config,
       max_seq_length,
       float_type=tf.float16 if use_float16 else tf.float32,
-      hub_module_url=FLAGS.hub_module_url)
+      hub_module_url=FLAGS.hub_module_url,
+      use_keras_bert=FLAGS.use_keras_bert_for_squad)
   squad_model.optimizer = optimization.create_optimizer(
       FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
   if use_float16:
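For concreteness, a hedged sketch of the optimizer setup shown above: `optimization.create_optimizer` takes an initial learning rate, the total number of training steps, and the number of warmup steps. The step counts and learning rate below are purely illustrative assumptions:

from official.nlp import optimization

steps_per_epoch = 1000  # illustrative assumption
epochs = 2
warmup_steps = 100      # illustrative assumption
optimizer = optimization.create_optimizer(
    5e-5, steps_per_epoch * epochs, warmup_steps)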
@@ -340,7 +347,8 @@ def export_squad(model_export_path, input_meta_data):
   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   squad_model, _ = bert_models.squad_model(
-      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
+      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32,
+      use_keras_bert=FLAGS.use_keras_bert_for_squad)
   model_saving_utils.export_bert_model(
       model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
...
@@ -26,6 +26,7 @@ from official.modeling import tf_utils
 from official.nlp import bert_modeling as modeling
 from official.nlp.modeling import networks
 from official.nlp.modeling.networks import bert_classifier
+from official.nlp.modeling.networks import bert_span_labeler


 def gather_indexes(sequence_tensor, positions):
@@ -224,6 +225,32 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss


+def _get_transformer_encoder(bert_config, sequence_length):
+  """Gets a 'TransformerEncoder' object.
+
+  Args:
+    bert_config: A 'modeling.BertConfig' object.
+    sequence_length: Maximum sequence length of the training data.
+
+  Returns:
+    A networks.TransformerEncoder object.
+  """
+  return networks.TransformerEncoder(
+      vocab_size=bert_config.vocab_size,
+      hidden_size=bert_config.hidden_size,
+      num_layers=bert_config.num_hidden_layers,
+      num_attention_heads=bert_config.num_attention_heads,
+      intermediate_size=bert_config.intermediate_size,
+      activation=tf_utils.get_activation('gelu'),
+      dropout_rate=bert_config.hidden_dropout_prob,
+      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+      sequence_length=sequence_length,
+      max_sequence_length=bert_config.max_position_embeddings,
+      type_vocab_size=bert_config.type_vocab_size,
+      initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=bert_config.initializer_range))
+
+
 def pretrain_model(bert_config,
                    seq_length,
                    max_predictions_per_seq,
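A hedged usage sketch of the new helper; the BertConfig values below are illustrative BERT-base-style defaults, not something this commit fixes (real runs load the config via `modeling.BertConfig.from_json_file(FLAGS.bert_config_file)`):

# Illustrative config values for a BERT-base-sized encoder.
bert_config = modeling.BertConfig(
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072)
encoder = _get_transformer_encoder(bert_config, sequence_length=384)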
@@ -333,7 +360,8 @@ def squad_model(bert_config,
                 max_seq_length,
                 float_type,
                 initializer=None,
-                hub_module_url=None):
+                hub_module_url=None,
+                use_keras_bert=False):
   """Returns BERT Squad model along with core BERT model to import weights.

   Args:
@@ -342,19 +370,31 @@ def squad_model(bert_config,
     float_type: tf.dtype, tf.float32 or tf.bfloat16.
     initializer: Initializer for weights in BertSquadLogitsLayer.
     hub_module_url: TF-Hub path/url to Bert module.
+    use_keras_bert: Whether to use keras BERT. Note that when the above
+      'hub_module_url' is specified, 'use_keras_bert' cannot be True.

   Returns:
-    Two tensors, start logits and end logits, [batch x sequence length].
+    A tuple of (1) keras model that outputs start logits and end logits and
+    (2) the core BERT transformer encoder.
+
+  Raises:
+    ValueError: When 'hub_module_url' is specified and 'use_keras_bert' is
+      True.
   """
-  unique_ids = tf.keras.layers.Input(
-      shape=(1,), dtype=tf.int32, name='unique_ids')
+  if hub_module_url and use_keras_bert:
+    raise ValueError(
+        'Cannot use hub_module_url and keras BERT at the same time.')
+
+  if use_keras_bert:
+    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
+    return bert_span_labeler.BertSpanLabeler(
+        network=bert_encoder), bert_encoder
+
   input_word_ids = tf.keras.layers.Input(
-      shape=(max_seq_length,), dtype=tf.int32, name='input_ids')
+      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
   input_mask = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
   input_type_ids = tf.keras.layers.Input(
-      shape=(max_seq_length,), dtype=tf.int32, name='segment_ids')
+      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
   if hub_module_url:
     core_model = hub.KerasLayer(hub_module_url, trainable=True)
     _, sequence_output = core_model(
@@ -383,12 +423,11 @@ def squad_model(bert_config,
   squad = tf.keras.Model(
       inputs={
-          'unique_ids': unique_ids,
-          'input_ids': input_word_ids,
+          'input_word_ids': input_word_ids,
           'input_mask': input_mask,
-          'segment_ids': input_type_ids,
+          'input_type_ids': input_type_ids,
       },
-      outputs=[unique_ids, start_logits, end_logits],
+      outputs=[start_logits, end_logits],
       name='squad_model')
   return squad, core_model
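Putting the KerasBERT path together, a rough usage sketch under stated assumptions: feature shapes and values are illustrative, the renamed input keys match the dataset change above, and calling the span labeler with a feature dict keyed by the encoder's input names is assumed to work as with other Keras functional models:

import tensorflow as tf

squad, encoder = squad_model(
    bert_config, max_seq_length=384, float_type=tf.float32,
    use_keras_bert=True)
features = {
    'input_word_ids': tf.zeros([1, 384], dtype=tf.int32),
    'input_mask': tf.ones([1, 384], dtype=tf.int32),
    'input_type_ids': tf.zeros([1, 384], dtype=tf.int32),
}
# The model now returns exactly two outputs; unique_ids are carried by the
# input pipeline and popped before the model call (see _replicated_step).
start_logits, end_logits = squad(features, training=False)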
@@ -424,20 +463,7 @@ def classifier_model(bert_config,
       stddev=bert_config.initializer_range)
   if not hub_module_url:
-    bert_encoder = networks.TransformerEncoder(
-        vocab_size=bert_config.vocab_size,
-        hidden_size=bert_config.hidden_size,
-        num_layers=bert_config.num_hidden_layers,
-        num_attention_heads=bert_config.num_attention_heads,
-        intermediate_size=bert_config.intermediate_size,
-        activation=tf_utils.get_activation('gelu'),
-        dropout_rate=bert_config.hidden_dropout_prob,
-        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
-        sequence_length=max_seq_length,
-        max_sequence_length=bert_config.max_position_embeddings,
-        type_vocab_size=bert_config.type_vocab_size,
-        initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=bert_config.initializer_range))
+    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
     return bert_classifier.BertClassifier(
         bert_encoder,
         num_classes=num_labels,
...