Commit 3db5347d authored by Chen Chen, committed by A. Unique TensorFlower

Allow not using next sentence labels in pretraining.

PiperOrigin-RevId: 306324960
parent 1025682f
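For orientation (not part of the diff), the sketch below shows how the new `use_next_sentence_label` argument introduced by this change might be used. The `BertConfig` import path and the hyperparameter values are assumptions for illustration, not taken from this commit.

```python
# Hedged sketch: build the BERT pretraining model without next sentence labels.
# Module paths and the config/hyperparameter values below are assumptions.
from official.nlp.bert import bert_models
from official.nlp.bert import configs

bert_config = configs.BertConfig(vocab_size=30522)
pretrain_model, core_model = bert_models.pretrain_model(
    bert_config,
    seq_length=128,
    max_predictions_per_seq=20,
    use_next_sentence_label=False)

# With use_next_sentence_label=False the Keras model's input dict omits
# 'next_sentence_labels', and the training loss reduces to the masked-LM loss.
print([t.name for t in pretrain_model.inputs])
```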
@@ -54,29 +54,41 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
-    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
-        sentence_labels, sentence_output)
-    self.add_metric(
-        next_sentence_accuracy,
-        name='next_sentence_accuracy',
-        aggregation='mean')
-    self.add_metric(
-        next_sentence_loss, name='next_sentence_loss', aggregation='mean')
-
-  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
-           sentence_labels):
+    if sentence_labels is not None:
+      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
+          sentence_labels, sentence_output)
+      self.add_metric(
+          next_sentence_accuracy,
+          name='next_sentence_accuracy',
+          aggregation='mean')
+
+    if next_sentence_loss is not None:
+      self.add_metric(
+          next_sentence_loss, name='next_sentence_loss', aggregation='mean')
+
+  def call(self,
+           lm_output,
+           sentence_output,
+           lm_label_ids,
+           lm_label_weights,
+           sentence_labels=None):
     """Implements call() for the layer."""
     lm_label_weights = tf.cast(lm_label_weights, tf.float32)
     lm_output = tf.cast(lm_output, tf.float32)
-    sentence_output = tf.cast(sentence_output, tf.float32)
     mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
         labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
-    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
-        labels=sentence_labels, predictions=sentence_output)
-    loss = mask_label_loss + sentence_loss
-    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
+
+    if sentence_labels is not None:
+      sentence_output = tf.cast(sentence_output, tf.float32)
+      sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
+          labels=sentence_labels, predictions=sentence_output)
+      loss = mask_label_loss + sentence_loss
+    else:
+      sentence_loss = None
+      loss = mask_label_loss
+
+    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
     # TODO(hongkuny): Avoids the hack and switches add_loss.
     final_loss = tf.fill(batch_shape, loss)
@@ -155,7 +167,8 @@ def get_transformer_encoder(bert_config,
 def pretrain_model(bert_config,
                    seq_length,
                    max_predictions_per_seq,
-                   initializer=None):
+                   initializer=None,
+                   use_next_sentence_label=True):
   """Returns model to be used for pre-training.
 
   Args:
@@ -164,6 +177,7 @@ def pretrain_model(bert_config,
     max_predictions_per_seq: Maximum number of tokens in sequence to mask out
       and use for pretraining.
     initializer: Initializer for weights in BertPretrainer.
+    use_next_sentence_label: Whether to use the next sentence label.
 
   Returns:
     Pretraining model as well as core BERT submodel from which to save
@@ -185,8 +199,12 @@ def pretrain_model(bert_config,
       shape=(max_predictions_per_seq,),
       name='masked_lm_weights',
       dtype=tf.int32)
-  next_sentence_labels = tf.keras.layers.Input(
-      shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+
+  if use_next_sentence_label:
+    next_sentence_labels = tf.keras.layers.Input(
+        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  else:
+    next_sentence_labels = None
 
   transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
@@ -206,17 +224,18 @@ def pretrain_model(bert_config,
       vocab_size=bert_config.vocab_size)
   output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                     masked_lm_weights, next_sentence_labels)
 
-  keras_model = tf.keras.Model(
-      inputs={
-          'input_word_ids': input_word_ids,
-          'input_mask': input_mask,
-          'input_type_ids': input_type_ids,
-          'masked_lm_positions': masked_lm_positions,
-          'masked_lm_ids': masked_lm_ids,
-          'masked_lm_weights': masked_lm_weights,
-          'next_sentence_labels': next_sentence_labels,
-      },
-      outputs=output_loss)
+  inputs = {
+      'input_word_ids': input_word_ids,
+      'input_mask': input_mask,
+      'input_type_ids': input_type_ids,
+      'masked_lm_positions': masked_lm_positions,
+      'masked_lm_ids': masked_lm_ids,
+      'masked_lm_weights': masked_lm_weights,
+  }
+  if use_next_sentence_label:
+    inputs['next_sentence_labels'] = next_sentence_labels
+
+  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
   return keras_model, transformer_encoder
@@ -313,8 +332,7 @@ def classifier_model(bert_config,
       shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
   input_type_ids = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-  bert_model = hub.KerasLayer(
-      hub_module_url, trainable=hub_module_trainable)
+  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
   pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
   output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
       pooled_output)
......
@@ -59,7 +59,8 @@ def create_pretrain_dataset(input_patterns,
                             max_predictions_per_seq,
                             batch_size,
                             is_training=True,
-                            input_pipeline_context=None):
+                            input_pipeline_context=None,
+                            use_next_sentence_label=True):
   """Creates input dataset from (tf)records files for pretraining."""
   name_to_features = {
       'input_ids':
@@ -74,9 +75,10 @@ def create_pretrain_dataset(input_patterns,
           tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
       'masked_lm_weights':
           tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
-      'next_sentence_labels':
-          tf.io.FixedLenFeature([1], tf.int64),
   }
+  if use_next_sentence_label:
+    name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
+                                                                     tf.int64)
 
   for input_pattern in input_patterns:
     if not tf.io.gfile.glob(input_pattern):
@@ -118,8 +120,9 @@ def create_pretrain_dataset(input_patterns,
         'masked_lm_positions': record['masked_lm_positions'],
         'masked_lm_ids': record['masked_lm_ids'],
         'masked_lm_weights': record['masked_lm_weights'],
-        'next_sentence_labels': record['next_sentence_labels'],
     }
+    if use_next_sentence_label:
+      x['next_sentence_labels'] = record['next_sentence_labels']
     y = record['masked_lm_weights']
......
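A matching sketch for the input-pipeline side of the change (again not part of the diff): with `use_next_sentence_label=False`, the parsed features simply omit `next_sentence_labels`. The module path and the file pattern are assumptions for illustration.

```python
# Hedged sketch: read pretraining TFRecords that carry no next sentence labels.
# The module path and the file pattern below are assumptions.
from official.nlp.bert import input_pipeline

dataset = input_pipeline.create_pretrain_dataset(
    ['/tmp/pretrain/tf_examples.tfrecord*'],  # input_patterns (assumed path)
    128,  # seq_length
    20,   # max_predictions_per_seq
    32,   # batch_size
    is_training=True,
    use_next_sentence_label=False)

# Each batch is (features, masked_lm_weights); features has no
# 'next_sentence_labels' key, matching the model inputs built above.
```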
@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_steps_per_epoch', 1000,
                      'Total number of training steps to run per epoch.')
 flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')
+flags.DEFINE_bool('use_next_sentence_label', True,
+                  'Whether to use next sentence label to compute final loss.')
 
 common_flags.define_common_bert_flags()
 common_flags.define_gin_flags()
@@ -55,7 +57,8 @@ FLAGS = flags.FLAGS
 def get_pretrain_dataset_fn(input_file_pattern, seq_length,
-                            max_predictions_per_seq, global_batch_size):
+                            max_predictions_per_seq, global_batch_size,
+                            use_next_sentence_label=True):
   """Returns input dataset from input file string."""
 
   def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
@@ -67,7 +70,8 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
         max_predictions_per_seq,
         batch_size,
         is_training=True,
-        input_pipeline_context=ctx)
+        input_pipeline_context=ctx,
+        use_next_sentence_label=use_next_sentence_label)
     return train_dataset
 
   return _dataset_fn
@@ -95,17 +99,20 @@ def run_customized_training(strategy,
                             end_lr,
                             optimizer_type,
                             input_files,
-                            train_batch_size):
+                            train_batch_size,
+                            use_next_sentence_label=True):
   """Run BERT pretrain model training using low-level API."""
 
   train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                            max_predictions_per_seq,
-                                           train_batch_size)
+                                           train_batch_size,
+                                           use_next_sentence_label)
 
   def _get_pretrain_model():
     """Gets a pretraining model."""
     pretrain_model, core_model = bert_models.pretrain_model(
-        bert_config, max_seq_length, max_predictions_per_seq)
+        bert_config, max_seq_length, max_predictions_per_seq,
+        use_next_sentence_label=use_next_sentence_label)
     optimizer = optimization.create_optimizer(
         initial_lr, steps_per_epoch * epochs, warmup_steps,
         end_lr, optimizer_type)
@@ -157,7 +164,8 @@ def run_bert_pretrain(strategy):
       FLAGS.end_lr,
       FLAGS.optimizer_type,
      FLAGS.input_files,
-      FLAGS.train_batch_size)
+      FLAGS.train_batch_size,
+      FLAGS.use_next_sentence_label)
 
 
 def main(_):
......
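Finally, a small self-contained sketch of how the new boolean flag behaves with absl once defined as above; the program name in `argv` is a placeholder, and this is only an illustration of toggling the flag from the command line.

```python
# Hedged sketch: the new flag as an absl boolean, toggled from the command line.
from absl import flags

flags.DEFINE_bool('use_next_sentence_label', True,
                  'Whether to use next sentence label to compute final loss.')
FLAGS = flags.FLAGS

# 'run_pretraining' is a placeholder program name for illustration only.
FLAGS(['run_pretraining', '--nouse_next_sentence_label'])
print(FLAGS.use_next_sentence_label)  # -> False
```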