Commit 3db5347d authored by Chen Chen, committed by A. Unique TensorFlower

Allow not using next sentence labels in pretraining.

PiperOrigin-RevId: 306324960
parent 1025682f
@@ -54,6 +54,7 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
+    if sentence_labels is not None:
       next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
           sentence_labels, sentence_output)
       self.add_metric(
@@ -61,22 +62,33 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
           name='next_sentence_accuracy',
           aggregation='mean')
+    if next_sentence_loss is not None:
       self.add_metric(
           next_sentence_loss, name='next_sentence_loss', aggregation='mean')

-  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
-           sentence_labels):
+  def call(self,
+           lm_output,
+           sentence_output,
+           lm_label_ids,
+           lm_label_weights,
+           sentence_labels=None):
     """Implements call() for the layer."""
     lm_label_weights = tf.cast(lm_label_weights, tf.float32)
     lm_output = tf.cast(lm_output, tf.float32)
-    sentence_output = tf.cast(sentence_output, tf.float32)
     mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
         labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
+    if sentence_labels is not None:
+      sentence_output = tf.cast(sentence_output, tf.float32)
       sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
           labels=sentence_labels, predictions=sentence_output)
       loss = mask_label_loss + sentence_loss
-    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
+    else:
+      sentence_loss = None
+      loss = mask_label_loss
+    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
     # TODO(hongkuny): Avoids the hack and switches add_loss.
     final_loss = tf.fill(batch_shape, loss)
@@ -155,7 +167,8 @@ def get_transformer_encoder(bert_config,
 def pretrain_model(bert_config,
                    seq_length,
                    max_predictions_per_seq,
-                   initializer=None):
+                   initializer=None,
+                   use_next_sentence_label=True):
   """Returns model to be used for pre-training.

   Args:
@@ -164,6 +177,7 @@ def pretrain_model(bert_config,
     max_predictions_per_seq: Maximum number of tokens in sequence to mask out
       and use for pretraining.
     initializer: Initializer for weights in BertPretrainer.
+    use_next_sentence_label: Whether to use the next sentence label.

   Returns:
     Pretraining model as well as core BERT submodel from which to save
@@ -185,8 +199,12 @@ def pretrain_model(bert_config,
       shape=(max_predictions_per_seq,),
       name='masked_lm_weights',
       dtype=tf.int32)
+  if use_next_sentence_label:
     next_sentence_labels = tf.keras.layers.Input(
         shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  else:
+    next_sentence_labels = None

   transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
@@ -206,17 +224,18 @@ def pretrain_model(bert_config,
       vocab_size=bert_config.vocab_size)
   output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                     masked_lm_weights, next_sentence_labels)
-  keras_model = tf.keras.Model(
-      inputs={
+  inputs = {
       'input_word_ids': input_word_ids,
       'input_mask': input_mask,
       'input_type_ids': input_type_ids,
       'masked_lm_positions': masked_lm_positions,
       'masked_lm_ids': masked_lm_ids,
       'masked_lm_weights': masked_lm_weights,
-          'next_sentence_labels': next_sentence_labels,
-      },
-      outputs=output_loss)
+  }
+  if use_next_sentence_label:
+    inputs['next_sentence_labels'] = next_sentence_labels
+  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
   return keras_model, transformer_encoder
@@ -313,8 +332,7 @@ def classifier_model(bert_config,
       shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
   input_type_ids = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-  bert_model = hub.KerasLayer(
-      hub_module_url, trainable=hub_module_trainable)
+  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
   pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
   output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
       pooled_output)
...
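After this change, callers can build the pretraining model without the next-sentence objective by passing the new argument. A minimal sketch follows (not part of this commit: bert_config is assumed to be an already-constructed BERT config object, and the sequence sizes are placeholder values):

# Minimal sketch, assuming bert_models is the module patched above and that
# bert_config already exists; the sizes below are placeholders.
pretrain_model, core_model = bert_models.pretrain_model(
    bert_config,
    seq_length=128,                  # placeholder
    max_predictions_per_seq=20,      # placeholder
    use_next_sentence_label=False)

# With the flag off, the Keras model exposes no 'next_sentence_labels' input
# and its training loss is the masked-LM loss alone.
print([inp.name for inp in pretrain_model.inputs])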
@@ -59,7 +59,8 @@ def create_pretrain_dataset(input_patterns,
                             max_predictions_per_seq,
                             batch_size,
                             is_training=True,
-                            input_pipeline_context=None):
+                            input_pipeline_context=None,
+                            use_next_sentence_label=True):
   """Creates input dataset from (tf)records files for pretraining."""
   name_to_features = {
       'input_ids':
@@ -74,9 +75,10 @@ def create_pretrain_dataset(input_patterns,
           tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
       'masked_lm_weights':
           tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
-      'next_sentence_labels':
-          tf.io.FixedLenFeature([1], tf.int64),
   }
+  if use_next_sentence_label:
+    name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
+                                                                     tf.int64)

   for input_pattern in input_patterns:
     if not tf.io.gfile.glob(input_pattern):
@@ -118,8 +120,9 @@ def create_pretrain_dataset(input_patterns,
         'masked_lm_positions': record['masked_lm_positions'],
         'masked_lm_ids': record['masked_lm_ids'],
         'masked_lm_weights': record['masked_lm_weights'],
-        'next_sentence_labels': record['next_sentence_labels'],
     }
+    if use_next_sentence_label:
+      x['next_sentence_labels'] = record['next_sentence_labels']

     y = record['masked_lm_weights']
...
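The dataset side mirrors this: with use_next_sentence_label set to False, the parsed feature dict simply omits next_sentence_labels. A rough sketch (the module alias input_pipeline, the file pattern, and the sizes are illustrative assumptions, not from this commit):

# Sketch only: input_pipeline stands for the module containing
# create_pretrain_dataset(); paths and sizes are placeholders.
dataset = input_pipeline.create_pretrain_dataset(
    input_patterns=['/tmp/pretrain/train-*.tfrecord'],  # placeholder
    seq_length=128,                                     # placeholder
    max_predictions_per_seq=20,                         # placeholder
    batch_size=32,                                      # placeholder
    is_training=True,
    use_next_sentence_label=False)

for features, _ in dataset.take(1):
  # 'next_sentence_labels' is absent from the feature dict.
  print(sorted(features.keys()))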
@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_steps_per_epoch', 1000,
                      'Total number of training steps to run per epoch.')
 flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')
+flags.DEFINE_bool('use_next_sentence_label', True,
+                  'Whether to use next sentence label to compute final loss.')

 common_flags.define_common_bert_flags()
 common_flags.define_gin_flags()
@@ -55,7 +57,8 @@ FLAGS = flags.FLAGS
 def get_pretrain_dataset_fn(input_file_pattern, seq_length,
-                            max_predictions_per_seq, global_batch_size):
+                            max_predictions_per_seq, global_batch_size,
+                            use_next_sentence_label=True):
   """Returns input dataset from input file string."""

   def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
@@ -67,7 +70,8 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
         max_predictions_per_seq,
         batch_size,
         is_training=True,
-        input_pipeline_context=ctx)
+        input_pipeline_context=ctx,
+        use_next_sentence_label=use_next_sentence_label)
     return train_dataset

   return _dataset_fn
@@ -95,17 +99,20 @@ def run_customized_training(strategy,
                             end_lr,
                             optimizer_type,
                             input_files,
-                            train_batch_size):
+                            train_batch_size,
+                            use_next_sentence_label=True):
   """Run BERT pretrain model training using low-level API."""

   train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                            max_predictions_per_seq,
-                                           train_batch_size)
+                                           train_batch_size,
+                                           use_next_sentence_label)

   def _get_pretrain_model():
     """Gets a pretraining model."""
     pretrain_model, core_model = bert_models.pretrain_model(
-        bert_config, max_seq_length, max_predictions_per_seq)
+        bert_config, max_seq_length, max_predictions_per_seq,
+        use_next_sentence_label=use_next_sentence_label)
     optimizer = optimization.create_optimizer(
         initial_lr, steps_per_epoch * epochs, warmup_steps,
         end_lr, optimizer_type)
@@ -157,7 +164,8 @@ def run_bert_pretrain(strategy):
       FLAGS.end_lr,
       FLAGS.optimizer_type,
       FLAGS.input_files,
-      FLAGS.train_batch_size)
+      FLAGS.train_batch_size,
+      FLAGS.use_next_sentence_label)


 def main(_):
...
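With the plumbing above, disabling the next-sentence objective is a single flag, e.g. passing --use_next_sentence_label=false (or --nouse_next_sentence_label) to the pretraining script. A programmatic sketch, assumed to run inside main() after absl has parsed flags, with placeholder values:

# Sketch only: flipping the flag removes the 'next_sentence_labels' feature
# from the input pipeline and drops the next-sentence loss and metrics from
# the pretraining model; masked-LM training is unchanged.
FLAGS.use_next_sentence_label = False
FLAGS.input_files = '/tmp/pretrain/train-*.tfrecord'  # placeholder
FLAGS.train_batch_size = 32  # placeholder
run_bert_pretrain(strategy)  # strategy as set up by the surrounding script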