"examples/git@developer.sourcefind.cn:change/sglang.git" did not exist on "96a5e4dd795b675210b0d18f5e9fab69ec69bb6e"
Commit 3db5347d authored by Chen Chen, committed by A. Unique TensorFlower

Allow not using next sentence labels in pretraining.

PiperOrigin-RevId: 306324960
parent 1025682f
@@ -54,29 +54,41 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
-    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
-        sentence_labels, sentence_output)
-    self.add_metric(
-        next_sentence_accuracy,
-        name='next_sentence_accuracy',
-        aggregation='mean')
+    if sentence_labels is not None:
+      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
+          sentence_labels, sentence_output)
+      self.add_metric(
+          next_sentence_accuracy,
+          name='next_sentence_accuracy',
+          aggregation='mean')
 
-    self.add_metric(
-        next_sentence_loss, name='next_sentence_loss', aggregation='mean')
+    if next_sentence_loss is not None:
+      self.add_metric(
+          next_sentence_loss, name='next_sentence_loss', aggregation='mean')
 
-  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
-           sentence_labels):
+  def call(self,
+           lm_output,
+           sentence_output,
+           lm_label_ids,
+           lm_label_weights,
+           sentence_labels=None):
     """Implements call() for the layer."""
     lm_label_weights = tf.cast(lm_label_weights, tf.float32)
     lm_output = tf.cast(lm_output, tf.float32)
-    sentence_output = tf.cast(sentence_output, tf.float32)
+
     mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
         labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
-    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
-        labels=sentence_labels, predictions=sentence_output)
-    loss = mask_label_loss + sentence_loss
-    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
+
+    if sentence_labels is not None:
+      sentence_output = tf.cast(sentence_output, tf.float32)
+      sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
+          labels=sentence_labels, predictions=sentence_output)
+      loss = mask_label_loss + sentence_loss
+    else:
+      sentence_loss = None
+      loss = mask_label_loss
+
+    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
     # TODO(hongkuny): Avoids the hack and switches add_loss.
     final_loss = tf.fill(batch_shape, loss)
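The hunk above makes the next-sentence metric and loss optional: when no sentence labels are passed, the layer reports only the masked-LM loss, and the scalar total is still broadcast to batch shape with tf.fill so Keras can aggregate it per example. A minimal, self-contained sketch of that loss-combination pattern (illustrative names, not the actual layer):

import tensorflow as tf

def combine_losses(mask_label_loss, sentence_loss=None, batch_shape=(8,)):
  # Mirrors the `sentence_labels is None` branch above: without a
  # next-sentence loss, the total is just the masked-LM loss.
  if sentence_loss is None:
    loss = mask_label_loss
  else:
    loss = mask_label_loss + sentence_loss
  # Broadcast the scalar loss to per-example shape, as the layer does.
  return tf.fill(batch_shape, loss)

print(combine_losses(tf.constant(2.3)))                    # MLM-only pretraining
print(combine_losses(tf.constant(2.3), tf.constant(0.7)))  # MLM + NSP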
@@ -155,7 +167,8 @@ def get_transformer_encoder(bert_config,
 def pretrain_model(bert_config,
                    seq_length,
                    max_predictions_per_seq,
-                   initializer=None):
+                   initializer=None,
+                   use_next_sentence_label=True):
   """Returns model to be used for pre-training.
 
   Args:
@@ -164,6 +177,7 @@ def pretrain_model(bert_config,
     max_predictions_per_seq: Maximum number of tokens in sequence to mask out
       and use for pretraining.
     initializer: Initializer for weights in BertPretrainer.
+    use_next_sentence_label: Whether to use the next sentence label.
 
   Returns:
     Pretraining model as well as core BERT submodel from which to save
@@ -185,8 +199,12 @@ def pretrain_model(bert_config,
       shape=(max_predictions_per_seq,),
       name='masked_lm_weights',
       dtype=tf.int32)
-  next_sentence_labels = tf.keras.layers.Input(
-      shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+
+  if use_next_sentence_label:
+    next_sentence_labels = tf.keras.layers.Input(
+        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  else:
+    next_sentence_labels = None
 
   transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
@@ -206,17 +224,18 @@ def pretrain_model(bert_config,
       vocab_size=bert_config.vocab_size)
   output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                     masked_lm_weights, next_sentence_labels)
-  keras_model = tf.keras.Model(
-      inputs={
-          'input_word_ids': input_word_ids,
-          'input_mask': input_mask,
-          'input_type_ids': input_type_ids,
-          'masked_lm_positions': masked_lm_positions,
-          'masked_lm_ids': masked_lm_ids,
-          'masked_lm_weights': masked_lm_weights,
-          'next_sentence_labels': next_sentence_labels,
-      },
-      outputs=output_loss)
+
+  inputs = {
+      'input_word_ids': input_word_ids,
+      'input_mask': input_mask,
+      'input_type_ids': input_type_ids,
+      'masked_lm_positions': masked_lm_positions,
+      'masked_lm_ids': masked_lm_ids,
+      'masked_lm_weights': masked_lm_weights,
+  }
+  if use_next_sentence_label:
+    inputs['next_sentence_labels'] = next_sentence_labels
+
+  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
+
   return keras_model, transformer_encoder
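With use_next_sentence_label=False, pretrain_model now skips the next_sentence_labels Input and leaves it out of the inputs dict, so the Keras model simply has one input fewer. A hedged, toy-scale sketch of that optional-input pattern (the tiny model below is illustrative only, not BERT):

import tensorflow as tf

def toy_pretrain_model(use_next_sentence_label=True):
  word_ids = tf.keras.layers.Input(shape=(16,), dtype=tf.int32, name='input_word_ids')
  inputs = {'input_word_ids': word_ids}
  # Stand-in for the masked-LM head: one scalar per example.
  loss = tf.keras.layers.Dense(1)(
      tf.keras.layers.GlobalAveragePooling1D()(
          tf.keras.layers.Embedding(100, 8)(word_ids)))
  if use_next_sentence_label:
    nsp_labels = tf.keras.layers.Input(
        shape=(1,), dtype=tf.int32, name='next_sentence_labels')
    inputs['next_sentence_labels'] = nsp_labels
    # Stand-in for folding the next-sentence term into the total loss.
    loss = tf.keras.layers.Add()([loss, tf.cast(nsp_labels, tf.float32)])
  return tf.keras.Model(inputs=inputs, outputs=loss)

print(toy_pretrain_model(True).input_names)   # includes 'next_sentence_labels'
print(toy_pretrain_model(False).input_names)  # no 'next_sentence_labels' input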
@@ -313,8 +332,7 @@ def classifier_model(bert_config,
         shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
     input_type_ids = tf.keras.layers.Input(
         shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-    bert_model = hub.KerasLayer(
-        hub_module_url, trainable=hub_module_trainable)
+    bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
     pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
     output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
         pooled_output)
...
@@ -59,7 +59,8 @@ def create_pretrain_dataset(input_patterns,
                             max_predictions_per_seq,
                             batch_size,
                             is_training=True,
-                            input_pipeline_context=None):
+                            input_pipeline_context=None,
+                            use_next_sentence_label=True):
   """Creates input dataset from (tf)records files for pretraining."""
   name_to_features = {
       'input_ids':
@@ -74,9 +75,10 @@ def create_pretrain_dataset(input_patterns,
           tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
       'masked_lm_weights':
           tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
-      'next_sentence_labels':
-          tf.io.FixedLenFeature([1], tf.int64),
   }
+  if use_next_sentence_label:
+    name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
+                                                                     tf.int64)
 
   for input_pattern in input_patterns:
     if not tf.io.gfile.glob(input_pattern):
@@ -118,8 +120,9 @@ def create_pretrain_dataset(input_patterns,
         'masked_lm_positions': record['masked_lm_positions'],
         'masked_lm_ids': record['masked_lm_ids'],
         'masked_lm_weights': record['masked_lm_weights'],
-        'next_sentence_labels': record['next_sentence_labels'],
     }
+    if use_next_sentence_label:
+      x['next_sentence_labels'] = record['next_sentence_labels']
 
     y = record['masked_lm_weights']
...
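The input pipeline mirrors the model change: next_sentence_labels only enters the tf.Example parsing spec, and the feature dict handed to Keras, when the flag is on. A self-contained sketch of that conditional decode (toy feature lengths; the real spec uses seq_length and max_predictions_per_seq):

import tensorflow as tf

def make_decode_fn(use_next_sentence_label=True):
  name_to_features = {
      'input_ids': tf.io.FixedLenFeature([16], tf.int64),
      'masked_lm_weights': tf.io.FixedLenFeature([4], tf.float32),
  }
  if use_next_sentence_label:
    name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1], tf.int64)

  def decode_fn(serialized_example):
    record = tf.io.parse_single_example(serialized_example, name_to_features)
    x = {
        'input_ids': record['input_ids'],
        'masked_lm_weights': record['masked_lm_weights'],
    }
    if use_next_sentence_label:
      x['next_sentence_labels'] = record['next_sentence_labels']
    # As in the diff above, the masked-LM weights double as the dummy Keras target.
    y = record['masked_lm_weights']
    return x, y

  return decode_fn

# Usage (placeholder path):
# dataset = tf.data.TFRecordDataset('pretrain.tfrecord').map(make_decode_fn(False))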
@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_steps_per_epoch', 1000,
                      'Total number of training steps to run per epoch.')
 flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')
+flags.DEFINE_bool('use_next_sentence_label', True,
+                  'Whether to use next sentence label to compute final loss.')
 
 common_flags.define_common_bert_flags()
 common_flags.define_gin_flags()
@@ -55,7 +57,8 @@ FLAGS = flags.FLAGS
 def get_pretrain_dataset_fn(input_file_pattern, seq_length,
-                            max_predictions_per_seq, global_batch_size):
+                            max_predictions_per_seq, global_batch_size,
+                            use_next_sentence_label=True):
   """Returns input dataset from input file string."""
 
   def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
@@ -67,7 +70,8 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
         max_predictions_per_seq,
         batch_size,
         is_training=True,
-        input_pipeline_context=ctx)
+        input_pipeline_context=ctx,
+        use_next_sentence_label=use_next_sentence_label)
     return train_dataset
 
   return _dataset_fn
@@ -95,17 +99,20 @@ def run_customized_training(strategy,
                             end_lr,
                             optimizer_type,
                             input_files,
-                            train_batch_size):
+                            train_batch_size,
+                            use_next_sentence_label=True):
   """Run BERT pretrain model training using low-level API."""
 
   train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                            max_predictions_per_seq,
-                                           train_batch_size)
+                                           train_batch_size,
+                                           use_next_sentence_label)
 
   def _get_pretrain_model():
     """Gets a pretraining model."""
     pretrain_model, core_model = bert_models.pretrain_model(
-        bert_config, max_seq_length, max_predictions_per_seq)
+        bert_config, max_seq_length, max_predictions_per_seq,
+        use_next_sentence_label=use_next_sentence_label)
     optimizer = optimization.create_optimizer(
         initial_lr, steps_per_epoch * epochs, warmup_steps,
         end_lr, optimizer_type)
@@ -157,7 +164,8 @@ def run_bert_pretrain(strategy):
       FLAGS.end_lr,
       FLAGS.optimizer_type,
       FLAGS.input_files,
-      FLAGS.train_batch_size)
+      FLAGS.train_batch_size,
+      FLAGS.use_next_sentence_label)
 
 
 def main(_):
...
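Taken together, the new --use_next_sentence_label flag (default True) is threaded from the flag definition through run_customized_training down to the dataset function and bert_models.pretrain_model. A minimal absl sketch of that flag plumbing, with a stand-in main; passing --nouse_next_sentence_label (or --use_next_sentence_label=false) turns NSP off:

from absl import app
from absl import flags

flags.DEFINE_bool('use_next_sentence_label', True,
                  'Whether to use next sentence label to compute final loss.')
FLAGS = flags.FLAGS


def main(_):
  # In the real script this value is forwarded to run_customized_training,
  # which passes it on to the pretraining dataset and model builders.
  print('use_next_sentence_label =', FLAGS.use_next_sentence_label)


if __name__ == '__main__':
  app.run(main)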