Unverified Commit 965cc3ee authored by Ayushman Kumar's avatar Ayushman Kumar Committed by GitHub
Browse files

Merge pull request #7 from tensorflow/master

updated
parents 1f3247f4 1f685c54
...@@ -32,9 +32,10 @@ class ExportTfhubTest(tf.test.TestCase): ...@@ -32,9 +32,10 @@ class ExportTfhubTest(tf.test.TestCase):
def test_export_tfhub(self): def test_export_tfhub(self):
# Exports a savedmodel for TF-Hub # Exports a savedmodel for TF-Hub
hidden_size = 16
bert_config = configs.BertConfig( bert_config = configs.BertConfig(
vocab_size=100, vocab_size=100,
hidden_size=16, hidden_size=hidden_size,
intermediate_size=32, intermediate_size=32,
max_position_embeddings=128, max_position_embeddings=128,
num_attention_heads=2, num_attention_heads=2,
...@@ -67,7 +68,8 @@ class ExportTfhubTest(tf.test.TestCase): ...@@ -67,7 +68,8 @@ class ExportTfhubTest(tf.test.TestCase):
hub_layer.trainable_weights): hub_layer.trainable_weights):
self.assertAllClose(source_weight.numpy(), hub_weight.numpy()) self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
dummy_ids = np.zeros((2, 10), dtype=np.int32) seq_length = 10
dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids]) hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids]) source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
...@@ -75,13 +77,23 @@ class ExportTfhubTest(tf.test.TestCase): ...@@ -75,13 +77,23 @@ class ExportTfhubTest(tf.test.TestCase):
# while the outputs of encoder is in reversed order, i.e., # while the outputs of encoder is in reversed order, i.e.,
# "sequence_output" and "pooled_output". # "sequence_output" and "pooled_output".
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids])) encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
self.assertEqual(hub_outputs[0].shape, (2, 16)) self.assertEqual(hub_outputs[0].shape, (2, hidden_size))
self.assertEqual(hub_outputs[1].shape, (2, 10, 16)) self.assertEqual(hub_outputs[1].shape, (2, seq_length, hidden_size))
for source_output, hub_output, encoder_output in zip( for source_output, hub_output, encoder_output in zip(
source_outputs, hub_outputs, encoder_outputs): source_outputs, hub_outputs, encoder_outputs):
self.assertAllClose(source_output.numpy(), hub_output.numpy()) self.assertAllClose(source_output.numpy(), hub_output.numpy())
self.assertAllClose(source_output.numpy(), encoder_output.numpy()) self.assertAllClose(source_output.numpy(), encoder_output.numpy())
# Test propagation of seq_length in shape inference.
input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
pooled_output, sequence_output = hub_layer(
[input_word_ids, input_mask, input_type_ids])
self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
self.assertEqual(sequence_output.shape.as_list(),
[None, seq_length, hidden_size])
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
...@@ -59,7 +59,9 @@ def create_pretrain_dataset(input_patterns, ...@@ -59,7 +59,9 @@ def create_pretrain_dataset(input_patterns,
max_predictions_per_seq, max_predictions_per_seq,
batch_size, batch_size,
is_training=True, is_training=True,
input_pipeline_context=None): input_pipeline_context=None,
use_next_sentence_label=True,
use_position_id=False):
"""Creates input dataset from (tf)records files for pretraining.""" """Creates input dataset from (tf)records files for pretraining."""
name_to_features = { name_to_features = {
'input_ids': 'input_ids':
...@@ -74,10 +76,13 @@ def create_pretrain_dataset(input_patterns, ...@@ -74,10 +76,13 @@ def create_pretrain_dataset(input_patterns,
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64), tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
'masked_lm_weights': 'masked_lm_weights':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32), tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
'next_sentence_labels':
tf.io.FixedLenFeature([1], tf.int64),
} }
if use_next_sentence_label:
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
tf.int64)
if use_position_id:
name_to_features['position_ids'] = tf.io.FixedLenFeature([seq_length],
tf.int64)
for input_pattern in input_patterns: for input_pattern in input_patterns:
if not tf.io.gfile.glob(input_pattern): if not tf.io.gfile.glob(input_pattern):
raise ValueError('%s does not match any files.' % input_pattern) raise ValueError('%s does not match any files.' % input_pattern)
...@@ -118,8 +123,11 @@ def create_pretrain_dataset(input_patterns, ...@@ -118,8 +123,11 @@ def create_pretrain_dataset(input_patterns,
'masked_lm_positions': record['masked_lm_positions'], 'masked_lm_positions': record['masked_lm_positions'],
'masked_lm_ids': record['masked_lm_ids'], 'masked_lm_ids': record['masked_lm_ids'],
'masked_lm_weights': record['masked_lm_weights'], 'masked_lm_weights': record['masked_lm_weights'],
'next_sentence_labels': record['next_sentence_labels'],
} }
if use_next_sentence_label:
x['next_sentence_labels'] = record['next_sentence_labels']
if use_position_id:
x['position_ids'] = record['position_ids']
y = record['masked_lm_weights'] y = record['masked_lm_weights']
...@@ -148,7 +156,6 @@ def create_classifier_dataset(file_path, ...@@ -148,7 +156,6 @@ def create_classifier_dataset(file_path,
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64), 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64), 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64), 'label_ids': tf.io.FixedLenFeature([], tf.int64),
'is_real_example': tf.io.FixedLenFeature([], tf.int64),
} }
dataset = single_file_dataset(file_path, name_to_features) dataset = single_file_dataset(file_path, name_to_features)
......
...@@ -153,7 +153,8 @@ def run_customized_training_loop( ...@@ -153,7 +153,8 @@ def run_customized_training_loop(
`model_fn`. `model_fn`.
custom_callbacks: A list of Keras Callbacks objects to run during custom_callbacks: A list of Keras Callbacks objects to run during
training. More specifically, `on_batch_begin()`, `on_batch_end()`, training. More specifically, `on_batch_begin()`, `on_batch_end()`,
methods are invoked during training. `on_epoch_begin()`, `on_epoch_end()` methods are invoked during
training. Note that some metrics may be missing from `logs`.
run_eagerly: Whether to run model training in pure eager execution. This run_eagerly: Whether to run model training in pure eager execution. This
should be disable for TPUStrategy. should be disable for TPUStrategy.
sub_model_export_name: If not None, will export `sub_model` returned by sub_model_export_name: If not None, will export `sub_model` returned by
...@@ -225,6 +226,8 @@ def run_customized_training_loop( ...@@ -225,6 +226,8 @@ def run_customized_training_loop(
raise ValueError( raise ValueError(
'if `metric_fn` is specified, metric_fn must be a callable.') 'if `metric_fn` is specified, metric_fn must be a callable.')
callback_list = tf.keras.callbacks.CallbackList(custom_callbacks)
total_training_steps = steps_per_epoch * epochs total_training_steps = steps_per_epoch * epochs
train_iterator = _get_input_iterator(train_input_fn, strategy) train_iterator = _get_input_iterator(train_input_fn, strategy)
...@@ -361,37 +364,37 @@ def run_customized_training_loop( ...@@ -361,37 +364,37 @@ def run_customized_training_loop(
test_step = tf.function(test_step) test_step = tf.function(test_step)
def _run_evaluation(current_training_step, test_iterator): def _run_evaluation(current_training_step, test_iterator):
"""Runs validation steps and aggregate metrics.""" """Runs validation steps and aggregate metrics.
Args:
current_training_step: tf.int32 tensor containing the current step.
test_iterator: distributed iterator of test datasets.
Returns:
A dict of metic names and values.
"""
for _ in range(eval_steps): for _ in range(eval_steps):
test_step(test_iterator) test_step(test_iterator)
logs = {}
with eval_summary_writer.as_default(): with eval_summary_writer.as_default():
for metric in eval_metrics + model.metrics: for metric in eval_metrics + model.metrics:
metric_value = _float_metric_value(metric) metric_value = _float_metric_value(metric)
logs[metric.name] = metric_value
logging.info('Step: [%d] Validation %s = %f', current_training_step, logging.info('Step: [%d] Validation %s = %f', current_training_step,
metric.name, metric_value) metric.name, metric_value)
tf.summary.scalar( tf.summary.scalar(
metric.name, metric_value, step=current_training_step) metric.name, metric_value, step=current_training_step)
eval_summary_writer.flush() eval_summary_writer.flush()
def _run_callbacks_on_batch_begin(batch): return logs
"""Runs custom callbacks at the start of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch, logs):
"""Runs custom callbacks at the end of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_end(batch, logs)
# Training loop starts here. # Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) checkpoint = tf.train.Checkpoint(
model=model, optimizer=optimizer, global_step=optimizer.iterations)
sub_model_checkpoint = tf.train.Checkpoint( sub_model_checkpoint = tf.train.Checkpoint(
model=sub_model) if sub_model_export_name else None model=sub_model,
global_step=optimizer.iterations) if sub_model_export_name else None
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir) latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file: if latest_checkpoint_file:
...@@ -405,13 +408,16 @@ def run_customized_training_loop( ...@@ -405,13 +408,16 @@ def run_customized_training_loop(
checkpoint_name = 'ctl_step_{step}.ckpt' checkpoint_name = 'ctl_step_{step}.ckpt'
while current_step < total_training_steps: while current_step < total_training_steps:
if current_step % steps_per_epoch == 0:
callback_list.on_epoch_begin(int(current_step / steps_per_epoch) + 1)
# Training loss/metric are taking average over steps inside micro # Training loss/metric are taking average over steps inside micro
# training loop. We reset the their values before each round. # training loop. We reset the their values before each round.
train_loss_metric.reset_states() train_loss_metric.reset_states()
for metric in train_metrics + model.metrics: for metric in train_metrics + model.metrics:
metric.reset_states() metric.reset_states()
_run_callbacks_on_batch_begin(current_step) callback_list.on_batch_begin(current_step)
# Runs several steps in the host while loop. # Runs several steps in the host while loop.
steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop) steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
...@@ -426,7 +432,7 @@ def run_customized_training_loop( ...@@ -426,7 +432,7 @@ def run_customized_training_loop(
tf.convert_to_tensor(steps, dtype=tf.int32)) tf.convert_to_tensor(steps, dtype=tf.int32))
train_loss = _float_metric_value(train_loss_metric) train_loss = _float_metric_value(train_loss_metric)
current_step += steps current_step += steps
_run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss}) callback_list.on_batch_end(current_step - 1, {'loss': train_loss})
# Updates training logging. # Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % ( training_status = 'Train Step: %d/%d / loss = %s' % (
...@@ -443,35 +449,43 @@ def run_customized_training_loop( ...@@ -443,35 +449,43 @@ def run_customized_training_loop(
train_summary_writer.flush() train_summary_writer.flush()
logging.info(training_status) logging.info(training_status)
# Saves model checkpoints and run validation steps at every epoch end.
if current_step % steps_per_epoch == 0: if current_step % steps_per_epoch == 0:
# To avoid repeated model saving, we do not save after the last # Save a submodel with the step in the file name after each epoch.
# step of training. if sub_model_export_name:
_save_checkpoint(
strategy, sub_model_checkpoint, model_dir,
'%s_step_%d.ckpt' % (sub_model_export_name, current_step))
# Save model checkpoints and run validation steps after each epoch
# (with the exception of the final epoch which is handled after the
# training loop).
if current_step < total_training_steps: if current_step < total_training_steps:
_save_checkpoint(strategy, checkpoint, model_dir, _save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step)) checkpoint_name.format(step=current_step))
if sub_model_export_name: logs = None
_save_checkpoint( if eval_input_fn:
strategy, sub_model_checkpoint, model_dir, logging.info('Running evaluation after step: %s.', current_step)
'%s_step_%d.ckpt' % (sub_model_export_name, current_step)) logs = _run_evaluation(current_step,
if eval_input_fn: _get_input_iterator(eval_input_fn, strategy))
logging.info('Running evaluation after step: %s.', current_step) # Re-initialize evaluation metric.
_run_evaluation(current_step, for metric in eval_metrics + model.metrics:
_get_input_iterator(eval_input_fn, strategy)) metric.reset_states()
# Re-initialize evaluation metric.
for metric in eval_metrics + model.metrics: callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
metric.reset_states()
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if sub_model_export_name: if sub_model_export_name:
_save_checkpoint(strategy, sub_model_checkpoint, model_dir, _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
'%s.ckpt' % sub_model_export_name) '%s.ckpt' % sub_model_export_name)
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
logs = None
if eval_input_fn: if eval_input_fn:
logging.info('Running final evaluation after training is complete.') logging.info('Running final evaluation after training is complete.')
_run_evaluation(current_step, logs = _run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy)) _get_input_iterator(eval_input_fn, strategy))
callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
training_summary = { training_summary = {
'total_training_steps': total_training_steps, 'total_training_steps': total_training_steps,
......
...@@ -135,6 +135,27 @@ def check_eventfile_for_keyword(keyword, summary_dir): ...@@ -135,6 +135,27 @@ def check_eventfile_for_keyword(keyword, summary_dir):
return any(summaries_with_matching_keyword(keyword, summary_dir)) return any(summaries_with_matching_keyword(keyword, summary_dir))
class RecordingCallback(tf.keras.callbacks.Callback):
def __init__(self):
self.batch_begin = [] # (batch, logs)
self.batch_end = [] # (batch, logs)
self.epoch_begin = [] # (epoch, logs)
self.epoch_end = [] # (epoch, logs)
def on_batch_begin(self, batch, logs=None):
self.batch_begin.append((batch, logs))
def on_batch_end(self, batch, logs=None):
self.batch_end.append((batch, logs))
def on_epoch_begin(self, epoch, logs=None):
self.epoch_begin.append((epoch, logs))
def on_epoch_end(self, epoch, logs=None):
self.epoch_end.append((epoch, logs))
class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase): class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
...@@ -156,6 +177,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -156,6 +177,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
eval_input_fn=input_fn, eval_input_fn=input_fn,
eval_steps=10, eval_steps=10,
init_checkpoint=None, init_checkpoint=None,
sub_model_export_name='my_submodel_name',
metric_fn=metric_fn, metric_fn=metric_fn,
custom_callbacks=None, custom_callbacks=None,
run_eagerly=run_eagerly) run_eagerly=run_eagerly)
...@@ -188,7 +210,20 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -188,7 +210,20 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
distribution, model_dir, steps_per_loop=10, run_eagerly=False) distribution, model_dir, steps_per_loop=10, run_eagerly=False)
# Two checkpoints should be saved after two epochs. # Two checkpoints should be saved after two epochs.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*'))) files = map(os.path.basename,
tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*index')))
self.assertCountEqual(['ctl_step_20.ckpt-1.index',
'ctl_step_40.ckpt-2.index'], files)
# Three submodel checkpoints should be saved after two epochs (one after
# each epoch plus one final).
files = map(os.path.basename,
tf.io.gfile.glob(os.path.join(model_dir,
'my_submodel_name*index')))
self.assertCountEqual(['my_submodel_name.ckpt-3.index',
'my_submodel_name_step_20.ckpt-1.index',
'my_submodel_name_step_40.ckpt-2.index'], files)
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.glob( tf.io.gfile.glob(
os.path.join(model_dir, 'summaries/training_summary*'))) os.path.join(model_dir, 'summaries/training_summary*')))
...@@ -210,6 +245,41 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -210,6 +245,41 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword('mean_input', check_eventfile_for_keyword('mean_input',
os.path.join(model_dir, 'summaries/eval'))) os.path.join(model_dir, 'summaries/eval')))
@combinations.generate(eager_strategy_combinations())
def test_train_check_callbacks(self, distribution):
model_dir = self.get_temp_dir()
callback = RecordingCallback()
callbacks = [callback]
input_fn = create_fake_data_input_fn(
batch_size=8, features_shape=[128], num_classes=3)
model_training_utils.run_customized_training_loop(
strategy=distribution,
model_fn=self._model_fn,
loss_fn=tf.keras.losses.categorical_crossentropy,
model_dir=model_dir,
steps_per_epoch=20,
steps_per_loop=10,
epochs=2,
train_input_fn=input_fn,
eval_input_fn=input_fn,
eval_steps=10,
init_checkpoint=None,
metric_fn=metric_fn,
custom_callbacks=callbacks,
run_eagerly=False)
self.assertEqual(callback.epoch_begin, [(1, {}), (2, {})])
epoch_ends, epoch_end_infos = zip(*callback.epoch_end)
self.assertEqual(list(epoch_ends), [1, 2])
for info in epoch_end_infos:
self.assertIn('accuracy', info)
self.assertEqual(callback.batch_begin,
[(0, {}), (10, {}), (20, {}), (30, {})])
batch_ends, batch_end_infos = zip(*callback.batch_end)
self.assertEqual(list(batch_ends), [9, 19, 29, 39])
for info in batch_end_infos:
self.assertIn('loss', info)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
......
...@@ -56,6 +56,7 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.') ...@@ -56,6 +56,7 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.') flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
common_flags.define_common_bert_flags() common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -85,7 +86,7 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size, ...@@ -85,7 +86,7 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
batch_size = ctx.get_per_replica_batch_size( batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_classifier_dataset( dataset = input_pipeline.create_classifier_dataset(
input_file_pattern, tf.io.gfile.glob(input_file_pattern),
max_seq_length, max_seq_length,
batch_size, batch_size,
is_training=is_training, is_training=is_training,
...@@ -125,7 +126,8 @@ def run_bert_classifier(strategy, ...@@ -125,7 +126,8 @@ def run_bert_classifier(strategy,
hub_module_url=FLAGS.hub_module_url, hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable)) hub_module_trainable=FLAGS.hub_module_trainable))
optimizer = optimization.create_optimizer( optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps) initial_lr, steps_per_epoch * epochs, warmup_steps,
FLAGS.end_lr, FLAGS.optimizer_type)
classifier_model.optimizer = performance.configure_optimizer( classifier_model.optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_float16=common_flags.use_float16(), use_float16=common_flags.use_float16(),
...@@ -196,7 +198,7 @@ def run_keras_compile_fit(model_dir, ...@@ -196,7 +198,7 @@ def run_keras_compile_fit(model_dir,
with strategy.scope(): with strategy.scope():
training_dataset = train_input_fn() training_dataset = train_input_fn()
evaluation_dataset = eval_input_fn() evaluation_dataset = eval_input_fn() if eval_input_fn else None
bert_model, sub_model = model_fn() bert_model, sub_model = model_fn()
optimizer = bert_model.optimizer optimizer = bert_model.optimizer
...@@ -329,7 +331,9 @@ def run_bert(strategy, ...@@ -329,7 +331,9 @@ def run_bert(strategy,
input_meta_data, input_meta_data,
model_config, model_config,
train_input_fn=None, train_input_fn=None,
eval_input_fn=None): eval_input_fn=None,
init_checkpoint=None,
custom_callbacks=None):
"""Run BERT training.""" """Run BERT training."""
if FLAGS.mode == 'export_only': if FLAGS.mode == 'export_only':
# As Keras ModelCheckpoint callback used with Keras compile/fit() API # As Keras ModelCheckpoint callback used with Keras compile/fit() API
...@@ -356,14 +360,14 @@ def run_bert(strategy, ...@@ -356,14 +360,14 @@ def run_bert(strategy,
if not strategy: if not strategy:
raise ValueError('Distribution strategy has not been specified.') raise ValueError('Distribution strategy has not been specified.')
if not custom_callbacks:
custom_callbacks = []
if FLAGS.log_steps: if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory( custom_callbacks.append(keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size, batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps, log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir, logdir=FLAGS.model_dir))
)]
else:
custom_callbacks = None
trained_model = run_bert_classifier( trained_model = run_bert_classifier(
strategy, strategy,
...@@ -376,7 +380,7 @@ def run_bert(strategy, ...@@ -376,7 +380,7 @@ def run_bert(strategy,
eval_steps, eval_steps,
warmup_steps, warmup_steps,
FLAGS.learning_rate, FLAGS.learning_rate,
FLAGS.init_checkpoint, init_checkpoint or FLAGS.init_checkpoint,
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
run_eagerly=FLAGS.run_eagerly, run_eagerly=FLAGS.run_eagerly,
...@@ -394,9 +398,12 @@ def run_bert(strategy, ...@@ -394,9 +398,12 @@ def run_bert(strategy,
return trained_model return trained_model
def main(_): def custom_main(custom_callbacks=None):
# Users should always run this script under TF 2.x """Run classification.
Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop.
"""
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
...@@ -421,7 +428,12 @@ def main(_): ...@@ -421,7 +428,12 @@ def main(_):
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_bert(strategy, input_meta_data, bert_config, train_input_fn, run_bert(strategy, input_meta_data, bert_config, train_input_fn,
eval_input_fn) eval_input_fn, custom_callbacks=custom_callbacks)
def main(_):
# Users should always run this script under TF 2.x
custom_main(custom_callbacks=None)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_steps_per_epoch', 1000, ...@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_steps_per_epoch', 1000,
'Total number of training steps to run per epoch.') 'Total number of training steps to run per epoch.')
flags.DEFINE_float('warmup_steps', 10000, flags.DEFINE_float('warmup_steps', 10000,
'Warmup steps for Adam weight decay optimizer.') 'Warmup steps for Adam weight decay optimizer.')
flags.DEFINE_bool('use_next_sentence_label', True,
'Whether to use next sentence label to compute final loss.')
common_flags.define_common_bert_flags() common_flags.define_common_bert_flags()
common_flags.define_gin_flags() common_flags.define_gin_flags()
...@@ -55,7 +57,8 @@ FLAGS = flags.FLAGS ...@@ -55,7 +57,8 @@ FLAGS = flags.FLAGS
def get_pretrain_dataset_fn(input_file_pattern, seq_length, def get_pretrain_dataset_fn(input_file_pattern, seq_length,
max_predictions_per_seq, global_batch_size): max_predictions_per_seq, global_batch_size,
use_next_sentence_label=True):
"""Returns input dataset from input file string.""" """Returns input dataset from input file string."""
def _dataset_fn(ctx=None): def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining.""" """Returns tf.data.Dataset for distributed BERT pretraining."""
...@@ -67,7 +70,8 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length, ...@@ -67,7 +70,8 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
max_predictions_per_seq, max_predictions_per_seq,
batch_size, batch_size,
is_training=True, is_training=True,
input_pipeline_context=ctx) input_pipeline_context=ctx,
use_next_sentence_label=use_next_sentence_label)
return train_dataset return train_dataset
return _dataset_fn return _dataset_fn
...@@ -92,20 +96,26 @@ def run_customized_training(strategy, ...@@ -92,20 +96,26 @@ def run_customized_training(strategy,
epochs, epochs,
initial_lr, initial_lr,
warmup_steps, warmup_steps,
end_lr,
optimizer_type,
input_files, input_files,
train_batch_size): train_batch_size,
use_next_sentence_label=True):
"""Run BERT pretrain model training using low-level API.""" """Run BERT pretrain model training using low-level API."""
train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length, train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
max_predictions_per_seq, max_predictions_per_seq,
train_batch_size) train_batch_size,
use_next_sentence_label)
def _get_pretrain_model(): def _get_pretrain_model():
"""Gets a pretraining model.""" """Gets a pretraining model."""
pretrain_model, core_model = bert_models.pretrain_model( pretrain_model, core_model = bert_models.pretrain_model(
bert_config, max_seq_length, max_predictions_per_seq) bert_config, max_seq_length, max_predictions_per_seq,
use_next_sentence_label=use_next_sentence_label)
optimizer = optimization.create_optimizer( optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps) initial_lr, steps_per_epoch * epochs, warmup_steps,
end_lr, optimizer_type)
pretrain_model.optimizer = performance.configure_optimizer( pretrain_model.optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_float16=common_flags.use_float16(), use_float16=common_flags.use_float16(),
...@@ -151,8 +161,11 @@ def run_bert_pretrain(strategy): ...@@ -151,8 +161,11 @@ def run_bert_pretrain(strategy):
FLAGS.num_train_epochs, FLAGS.num_train_epochs,
FLAGS.learning_rate, FLAGS.learning_rate,
FLAGS.warmup_steps, FLAGS.warmup_steps,
FLAGS.end_lr,
FLAGS.optimizer_type,
FLAGS.input_files, FLAGS.input_files,
FLAGS.train_batch_size) FLAGS.train_batch_size,
FLAGS.use_next_sentence_label)
def main(_): def main(_):
......
...@@ -20,7 +20,6 @@ from __future__ import print_function ...@@ -20,7 +20,6 @@ from __future__ import print_function
import json import json
import os import os
import tempfile
import time import time
from absl import app from absl import app
...@@ -48,11 +47,13 @@ FLAGS = flags.FLAGS ...@@ -48,11 +47,13 @@ FLAGS = flags.FLAGS
def train_squad(strategy, def train_squad(strategy,
input_meta_data, input_meta_data,
custom_callbacks=None, custom_callbacks=None,
run_eagerly=False): run_eagerly=False,
init_checkpoint=None):
"""Run bert squad training.""" """Run bert squad training."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
run_squad_helper.train_squad(strategy, input_meta_data, bert_config, run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly) custom_callbacks, run_eagerly, init_checkpoint)
def predict_squad(strategy, input_meta_data): def predict_squad(strategy, input_meta_data):
...@@ -130,18 +131,15 @@ def main(_): ...@@ -130,18 +131,15 @@ def main(_):
eval_metrics = eval_squad(strategy, input_meta_data) eval_metrics = eval_squad(strategy, input_meta_data)
f1_score = eval_metrics['final_f1'] f1_score = eval_metrics['final_f1']
logging.info('SQuAD eval F1-score: %f', f1_score) logging.info('SQuAD eval F1-score: %f', f1_score)
if (not strategy) or strategy.extended.should_save_summary: summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
summary_dir = os.path.join(FLAGS.model_dir, 'summaries') summary_writer = tf.summary.create_file_writer(summary_dir)
else:
summary_dir = tempfile.mkdtemp()
summary_writer = tf.summary.create_file_writer(
os.path.join(summary_dir, 'eval'))
with summary_writer.as_default(): with summary_writer.as_default():
# TODO(lehou): write to the correct step number. # TODO(lehou): write to the correct step number.
tf.summary.scalar('F1-score', f1_score, step=0) tf.summary.scalar('F1-score', f1_score, step=0)
summary_writer.flush() summary_writer.flush()
# Wait for some time, for the depending mldash/tensorboard jobs to finish # Also write eval_metrics to json file.
# exporting the final F1-score. squad_lib_wp.write_to_json_files(
eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
time.sleep(60) time.sleep(60)
......
...@@ -88,6 +88,7 @@ def define_common_squad_flags(): ...@@ -88,6 +88,7 @@ def define_common_squad_flags():
'another.') 'another.')
common_flags.define_common_bert_flags() common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -159,8 +160,12 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size, ...@@ -159,8 +160,12 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
return _dataset_fn return _dataset_fn
def predict_squad_customized(strategy, input_meta_data, bert_config, def predict_squad_customized(strategy,
predict_tfrecord_path, num_steps): input_meta_data,
bert_config,
checkpoint_path,
predict_tfrecord_path,
num_steps):
"""Make predictions using a Bert-based squad model.""" """Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn( predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path, predict_tfrecord_path,
...@@ -179,7 +184,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config, ...@@ -179,7 +184,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
input_meta_data['max_seq_length'], input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url) hub_module_url=FLAGS.hub_module_url)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir) if checkpoint_path is None:
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path) logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model) checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial() checkpoint.restore(checkpoint_path).expect_partial()
...@@ -215,7 +221,8 @@ def train_squad(strategy, ...@@ -215,7 +221,8 @@ def train_squad(strategy,
input_meta_data, input_meta_data,
bert_config, bert_config,
custom_callbacks=None, custom_callbacks=None,
run_eagerly=False): run_eagerly=False,
init_checkpoint=None):
"""Run bert squad training.""" """Run bert squad training."""
if strategy: if strategy:
logging.info('Training using customized training loop with distribution' logging.info('Training using customized training loop with distribution'
...@@ -244,7 +251,9 @@ def train_squad(strategy, ...@@ -244,7 +251,9 @@ def train_squad(strategy,
hub_module_trainable=FLAGS.hub_module_trainable) hub_module_trainable=FLAGS.hub_module_trainable)
optimizer = optimization.create_optimizer(FLAGS.learning_rate, optimizer = optimization.create_optimizer(FLAGS.learning_rate,
steps_per_epoch * epochs, steps_per_epoch * epochs,
warmup_steps) warmup_steps,
FLAGS.end_lr,
FLAGS.optimizer_type)
squad_model.optimizer = performance.configure_optimizer( squad_model.optimizer = performance.configure_optimizer(
optimizer, optimizer,
...@@ -270,7 +279,7 @@ def train_squad(strategy, ...@@ -270,7 +279,7 @@ def train_squad(strategy,
steps_per_loop=FLAGS.steps_per_loop, steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs, epochs=epochs,
train_input_fn=train_input_fn, train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint, init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
run_eagerly=run_eagerly, run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks, custom_callbacks=custom_callbacks,
explicit_allreduce=False, explicit_allreduce=False,
...@@ -278,7 +287,7 @@ def train_squad(strategy, ...@@ -278,7 +287,7 @@ def train_squad(strategy,
def prediction_output_squad( def prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib): strategy, input_meta_data, tokenizer, bert_config, squad_lib, checkpoint):
"""Makes predictions for a squad dataset.""" """Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride'] doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length'] max_query_length = input_meta_data['max_query_length']
...@@ -326,8 +335,9 @@ def prediction_output_squad( ...@@ -326,8 +335,9 @@ def prediction_output_squad(
logging.info(' Batch size = %d', FLAGS.predict_batch_size) logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size) num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config, all_results = predict_squad_customized(
eval_writer.filename, num_steps) strategy, input_meta_data, bert_config,
checkpoint, eval_writer.filename, num_steps)
all_predictions, all_nbest_json, scores_diff_json = ( all_predictions, all_nbest_json, scores_diff_json = (
squad_lib.postprocess_output( squad_lib.postprocess_output(
...@@ -359,18 +369,34 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json, ...@@ -359,18 +369,34 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file) squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib): def predict_squad(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None):
"""Get prediction results and evaluate them to hard drive.""" """Get prediction results and evaluate them to hard drive."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad( all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib) strategy, input_meta_data, tokenizer,
bert_config, squad_lib, init_checkpoint)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib, dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False)) input_meta_data.get('version_2_with_negative', False))
def eval_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib): def eval_squad(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None):
"""Get prediction results and evaluate them against ground truth.""" """Get prediction results and evaluate them against ground truth."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad( all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib) strategy, input_meta_data, tokenizer,
bert_config, squad_lib, init_checkpoint)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib, dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False)) input_meta_data.get('version_2_with_negative', False))
......
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
import collections import collections
import csv import csv
import importlib
import os import os
from absl import logging from absl import logging
...@@ -403,6 +404,7 @@ class TfdsProcessor(DataProcessor): ...@@ -403,6 +404,7 @@ class TfdsProcessor(DataProcessor):
(TFDS) for the meaning of individual parameters): (TFDS) for the meaning of individual parameters):
dataset: Required dataset name (potentially with subset and version number). dataset: Required dataset name (potentially with subset and version number).
data_dir: Optional TFDS source root directory. data_dir: Optional TFDS source root directory.
module_import: Optional Dataset module to import.
train_split: Name of the train split (defaults to `train`). train_split: Name of the train split (defaults to `train`).
dev_split: Name of the dev split (defaults to `validation`). dev_split: Name of the dev split (defaults to `validation`).
test_split: Name of the test split (defaults to `test`). test_split: Name of the test split (defaults to `test`).
...@@ -418,6 +420,9 @@ class TfdsProcessor(DataProcessor): ...@@ -418,6 +420,9 @@ class TfdsProcessor(DataProcessor):
process_text_fn=tokenization.convert_to_unicode): process_text_fn=tokenization.convert_to_unicode):
super(TfdsProcessor, self).__init__(process_text_fn) super(TfdsProcessor, self).__init__(process_text_fn)
self._process_tfds_params_str(tfds_params) self._process_tfds_params_str(tfds_params)
if self.module_import:
importlib.import_module(self.module_import)
self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir, self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir,
with_info=True) with_info=True)
self._labels = list(range(info.features[self.label_key].num_classes)) self._labels = list(range(info.features[self.label_key].num_classes))
...@@ -428,6 +433,7 @@ class TfdsProcessor(DataProcessor): ...@@ -428,6 +433,7 @@ class TfdsProcessor(DataProcessor):
d = {k.strip(): v.strip() for k, v in tuples} d = {k.strip(): v.strip() for k, v in tuples}
self.dataset_name = d["dataset"] # Required. self.dataset_name = d["dataset"] # Required.
self.data_dir = d.get("data_dir", None) self.data_dir = d.get("data_dir", None)
self.module_import = d.get("module_import", None)
self.train_split = d.get("train_split", "train") self.train_split = d.get("train_split", "train")
self.dev_split = d.get("dev_split", "validation") self.dev_split = d.get("dev_split", "validation")
self.test_split = d.get("test_split", "test") self.test_split = d.get("test_split", "test")
...@@ -578,6 +584,7 @@ def file_based_convert_examples_to_features(examples, label_list, ...@@ -578,6 +584,7 @@ def file_based_convert_examples_to_features(examples, label_list,
output_file): output_file):
"""Convert a set of `InputExample`s to a TFRecord file.""" """Convert a set of `InputExample`s to a TFRecord file."""
tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file) writer = tf.io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples): for (ex_index, example) in enumerate(examples):
......
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
import functools import functools
import json import json
import os
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -191,6 +192,7 @@ def main(_): ...@@ -191,6 +192,7 @@ def main(_):
else: else:
input_meta_data = generate_squad_dataset() input_meta_data = generate_squad_dataset()
tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer: with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
writer.write(json.dumps(input_meta_data, indent=4) + "\n") writer.write(json.dumps(input_meta_data, indent=4) + "\n")
......
...@@ -100,13 +100,14 @@ class TrainingInstance(object): ...@@ -100,13 +100,14 @@ class TrainingInstance(object):
def write_instance_to_example_files(instances, tokenizer, max_seq_length, def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files): max_predictions_per_seq, output_files,
gzip_compress):
"""Create TF example files from `TrainingInstance`s.""" """Create TF example files from `TrainingInstance`s."""
writers = [] writers = []
for output_file in output_files: for output_file in output_files:
writers.append( writers.append(
tf.io.TFRecordWriter( tf.io.TFRecordWriter(
output_file, options="GZIP" if FLAGS.gzip_compress else "")) output_file, options="GZIP" if gzip_compress else ""))
writer_index = 0 writer_index = 0
...@@ -183,9 +184,15 @@ def create_float_feature(values): ...@@ -183,9 +184,15 @@ def create_float_feature(values):
return feature return feature
def create_training_instances(input_files, tokenizer, max_seq_length, def create_training_instances(input_files,
dupe_factor, short_seq_prob, masked_lm_prob, tokenizer,
max_predictions_per_seq, rng): max_seq_length,
dupe_factor,
short_seq_prob,
masked_lm_prob,
max_predictions_per_seq,
rng,
do_whole_word_mask=False):
"""Create `TrainingInstance`s from raw text.""" """Create `TrainingInstance`s from raw text."""
all_documents = [[]] all_documents = [[]]
...@@ -221,7 +228,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length, ...@@ -221,7 +228,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
instances.extend( instances.extend(
create_instances_from_document( create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob, all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask))
rng.shuffle(instances) rng.shuffle(instances)
return instances return instances
...@@ -229,7 +237,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length, ...@@ -229,7 +237,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
def create_instances_from_document( def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob, all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng): masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask=False):
"""Creates `TrainingInstance`s for a single document.""" """Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index] document = all_documents[document_index]
...@@ -327,7 +336,8 @@ def create_instances_from_document( ...@@ -327,7 +336,8 @@ def create_instances_from_document(
(tokens, masked_lm_positions, (tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions( masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask)
instance = TrainingInstance( instance = TrainingInstance(
tokens=tokens, tokens=tokens,
segment_ids=segment_ids, segment_ids=segment_ids,
...@@ -347,7 +357,8 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ...@@ -347,7 +357,8 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
def create_masked_lm_predictions(tokens, masked_lm_prob, def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng): max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask):
"""Creates the predictions for the masked LM objective.""" """Creates the predictions for the masked LM objective."""
cand_indexes = [] cand_indexes = []
...@@ -363,7 +374,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob, ...@@ -363,7 +374,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
# Note that Whole Word Masking does *not* change the training code # Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed # at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary. # over the entire vocabulary.
if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and if (do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")): token.startswith("##")):
cand_indexes[-1].append(i) cand_indexes[-1].append(i)
else: else:
...@@ -456,7 +467,7 @@ def main(_): ...@@ -456,7 +467,7 @@ def main(_):
instances = create_training_instances( instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng) rng, FLAGS.do_whole_word_mask)
output_files = FLAGS.output_file.split(",") output_files = FLAGS.output_file.split(",")
logging.info("*** Writing to output files ***") logging.info("*** Writing to output files ***")
...@@ -464,7 +475,8 @@ def main(_): ...@@ -464,7 +475,8 @@ def main(_):
logging.info(" %s", output_file) logging.info(" %s", output_file)
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files) FLAGS.max_predictions_per_seq, output_files,
FLAGS.gzip_compress)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -14,9 +14,15 @@ If `from_tensor` and `to_tensor` are the same, then this is self-attention. ...@@ -14,9 +14,15 @@ If `from_tensor` and `to_tensor` are the same, then this is self-attention.
* [CachedAttention](attention.py) implements an attention layer with cache used * [CachedAttention](attention.py) implements an attention layer with cache used
for auto-agressive decoding. for auto-agressive decoding.
* [TalkingHeadsAttention](talking_heads_attention.py) implements the talking
heads attention, as decribed in ["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436).
* [Transformer](transformer.py) implements an optionally masked transformer as * [Transformer](transformer.py) implements an optionally masked transformer as
described in ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). described in ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [ReZeroTransformer](rezero_transformer.py) implements Transformer with ReZero
described in ["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).
* [OnDeviceEmbedding](on_device_embedding.py) implements efficient embedding lookups designed for TPU-based models. * [OnDeviceEmbedding](on_device_embedding.py) implements efficient embedding lookups designed for TPU-based models.
* [PositionalEmbedding](position_embedding.py) creates a positional embedding * [PositionalEmbedding](position_embedding.py) creates a positional embedding
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,6 +18,8 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum ...@@ -18,6 +18,8 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.transformer import Transformer from official.nlp.modeling.layers.transformer import Transformer
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
...@@ -186,15 +186,9 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -186,15 +186,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
class CachedAttention(MultiHeadAttention): class CachedAttention(MultiHeadAttention):
"""Attention layer with cache used for auto-agressive decoding. """Attention layer with cache used for auto-agressive decoding.
Arguments: Arguments are the same as `MultiHeadAttention` layer.
num_heads: Number of attention heads.
head_size: Size of each attention head.
**kwargs: Other keyword arguments inherit from `Attention` class.
""" """
def __init__(self, num_heads, head_size, **kwargs):
super(CachedAttention, self).__init__(num_heads, head_size, **kwargs)
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step): def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors.""" """Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values. # Combines cached keys and values with new keys and values.
......
...@@ -147,6 +147,8 @@ class DenseEinsum(tf.keras.layers.Layer): ...@@ -147,6 +147,8 @@ class DenseEinsum(tf.keras.layers.Layer):
config = { config = {
"output_shape": "output_shape":
self._output_shape, self._output_shape,
"num_summed_dimensions":
self._num_summed_dimensions,
"activation": "activation":
tf.keras.activations.serialize(self._activation), tf.keras.activations.serialize(self._activation),
"use_bias": "use_bias":
......
...@@ -21,8 +21,6 @@ from __future__ import print_function ...@@ -21,8 +21,6 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package="Text") @tf.keras.utils.register_keras_serializable(package="Text")
class OnDeviceEmbedding(tf.keras.layers.Layer): class OnDeviceEmbedding(tf.keras.layers.Layer):
...@@ -78,8 +76,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -78,8 +76,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
super(OnDeviceEmbedding, self).build(input_shape) super(OnDeviceEmbedding, self).build(input_shape)
def call(self, inputs): def call(self, inputs):
input_shape = tf_utils.get_shape_list(inputs, expected_rank=2)
input_shape.append(self._embedding_width)
flat_inputs = tf.reshape(inputs, [-1]) flat_inputs = tf.reshape(inputs, [-1])
if self._use_one_hot: if self._use_one_hot:
one_hot_data = tf.one_hot( one_hot_data = tf.one_hot(
...@@ -87,6 +83,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -87,6 +83,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
embeddings = tf.matmul(one_hot_data, self.embeddings) embeddings = tf.matmul(one_hot_data, self.embeddings)
else: else:
embeddings = tf.gather(self.embeddings, flat_inputs) embeddings = tf.gather(self.embeddings, flat_inputs)
embeddings = tf.reshape(embeddings, input_shape) embeddings = tf.reshape(
embeddings,
# Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
return embeddings return embeddings
...@@ -111,15 +111,10 @@ class PositionEmbedding(tf.keras.layers.Layer): ...@@ -111,15 +111,10 @@ class PositionEmbedding(tf.keras.layers.Layer):
def call(self, inputs): def call(self, inputs):
"""Implements call() for the layer.""" """Implements call() for the layer."""
input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
if self._use_dynamic_slicing: if self._use_dynamic_slicing:
input_shape = tf_utils.get_shape_list(inputs, expected_rank=3) position_embeddings = self._position_embeddings[:input_shape[1], :]
seq_length = input_shape[1]
width = input_shape[2]
position_embeddings = tf.expand_dims(
tf.slice(self._position_embeddings, [0, 0], [seq_length, width]),
axis=0)
else: else:
position_embeddings = tf.expand_dims(self._position_embeddings, axis=0) position_embeddings = self._position_embeddings
return position_embeddings return tf.broadcast_to(position_embeddings, input_shape)
...@@ -40,7 +40,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase): ...@@ -40,7 +40,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# When using static positional embedding shapes, the output is expected # When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch. # to be the same as the input shape in all dimensions save batch.
expected_output_shape = [1, sequence_length, width] expected_output_shape = [None, sequence_length, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list()) self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32. # The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float32, output_tensor.dtype) self.assertEqual(tf.float32, output_tensor.dtype)
...@@ -55,7 +55,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase): ...@@ -55,7 +55,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# When using static positional embedding shapes, the output is expected # When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch. # to be the same as the input shape in all dimensions save batch.
expected_output_shape = [1, sequence_length, width] expected_output_shape = [None, sequence_length, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list()) self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32. # The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float16, output_tensor.dtype) self.assertEqual(tf.float16, output_tensor.dtype)
...@@ -72,7 +72,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase): ...@@ -72,7 +72,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# When using dynamic positional embedding shapes, the output is expected # When using dynamic positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions - but may be None if # to be the same as the input shape in all dimensions - but may be None if
# the input shape is None there. # the input shape is None there.
expected_output_shape = [1, None, width] expected_output_shape = [None, None, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list()) self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
def test_dynamic_layer_slicing(self): def test_dynamic_layer_slicing(self):
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based rezero-transformer block layer (Transformer with ReZero)."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class ReZeroTransformer(tf.keras.layers.Layer):
"""Transformer layer with ReZero.
This layer implements the Transformer from "Attention Is All You Need".
(https://arxiv.org/abs/1706.03762).
The residual connection implements the ReZero method.
(https://arxiv.org/abs/2003.04887)
Arguments:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_layer_norm: If add layer_norm on top of the ReZero.
"""
def __init__(self,
num_attention_heads,
intermediate_size,
intermediate_activation,
dropout_rate=0.0,
attention_dropout_rate=0.0,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_layer_norm=False,
**kwargs):
super(ReZeroTransformer, self).__init__(**kwargs)
self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_layer_norm = use_layer_norm
  def build(self, input_shape):
    """Creates the layer's weights and sub-layers from the input shape.

    Accepts either a single [batch, sequence, width] tensor shape or a
    two-element list/tuple of (input_tensor_shape, attention_mask_shape).
    Raises ValueError on a non-rank-3 input, an incompatible mask shape, or a
    hidden size that is not divisible by the number of attention heads.
    """
    # input_shape may be a pair (data, mask); the data shape is element 0.
    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
    input_tensor_shape = tf.TensorShape(input_tensor)
    if len(input_tensor_shape) != 3:
      raise ValueError("TransformerLayer expects a three-dimensional input of "
                       "shape [batch, sequence, width].")
    batch_size, sequence_length, hidden_size = input_tensor_shape

    if len(input_shape) == 2:
      # Validate the attention-mask shape against the data shape. Dimensions
      # that are None are treated as compatible by is_compatible_with().
      mask_tensor_shape = tf.TensorShape(input_shape[1])
      expected_mask_tensor_shape = tf.TensorShape(
          [batch_size, sequence_length, sequence_length])
      if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
        raise ValueError("When passing a mask tensor to TransformerLayer, the "
                         "mask tensor must be of shape [batch, "
                         "sequence_length, sequence_length] (here %s). Got a "
                         "mask tensor of shape %s." %
                         (expected_mask_tensor_shape, mask_tensor_shape))
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    # Each head attends over hidden_size / num_heads channels.
    self._attention_head_size = int(hidden_size // self._num_heads)
    # NOTE: self._activity_regularizer is read here but this class's __init__
    # does not set it — it comes from the tf.keras base Layer (which defaults
    # it to None when not forwarded); verify the intended regularizer reaches
    # the sub-layers.
    self._attention_layer = attention.MultiHeadAttention(
        num_heads=self._num_heads,
        head_size=self._attention_head_size,
        dropout_rate=self._attention_dropout_rate,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="self_attention")
    # Projects the per-head attention output (2 summed dims: heads x head
    # size) back to hidden_size.
    self._attention_output_dense = dense_einsum.DenseEinsum(
        output_shape=hidden_size,
        num_summed_dimensions=2,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="self_attention_output")
    self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    if self._use_layer_norm:
      # Use float32 in layernorm for numeric stability.
      # It is probably safe in mixed_float16, but we haven't validated this yet.
      self._attention_layer_norm = (
          tf.keras.layers.LayerNormalization(
              name="self_attention_layer_norm",
              axis=-1,
              epsilon=1e-12,
              dtype=tf.float32))
    # Feed-forward sub-block: expand to intermediate_size, apply the
    # activation as a separate layer, then project back to hidden_size.
    self._intermediate_dense = dense_einsum.DenseEinsum(
        output_shape=self._intermediate_size,
        activation=None,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="intermediate")
    self._intermediate_activation_layer = tf.keras.layers.Activation(
        self._intermediate_activation)
    self._output_dense = dense_einsum.DenseEinsum(
        output_shape=hidden_size,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="output")
    self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    if self._use_layer_norm:
      # Use float32 in layernorm for numeric stability.
      self._output_layer_norm = tf.keras.layers.LayerNormalization(
          name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
    # The ReZero residual scale alpha, initialized to zero so the layer
    # starts as the identity mapping (https://arxiv.org/abs/2003.04887).
    # Kept in float32 even under mixed precision.
    self._rezero_a = self.add_weight(
        name="rezero_alpha",
        initializer=tf.keras.initializers.Zeros(),
        trainable=True, dtype=tf.float32)

    super(ReZeroTransformer, self).build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"use_layer_norm":
self._use_layer_norm,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
}
base_config = super(ReZeroTransformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def reset_rezero(self):
self._rezero_a.assign(0.)
  def call(self, inputs):
    """Forward pass: self-attention and feed-forward, each added back to the
    residual stream scaled by the shared learned ReZero alpha.

    Args:
      inputs: Either a single embeddings tensor, or a list/tuple of
        `[input_tensor, attention_mask]`.

    Returns:
      The block's output tensor. NOTE(review): when `use_layer_norm` is
      False the output is explicitly cast to float32 (see below), so under
      mixed_float16 the output dtype may differ from the input's — confirm
      callers expect this.
    """
    # Accept either a bare tensor or a (tensor, mask) pair.
    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
      input_tensor, attention_mask = inputs
    else:
      input_tensor, attention_mask = (inputs, None)
    # The input serves as both the query and the key/value source
    # (self-attention).
    attention_inputs = [input_tensor, input_tensor]
    if attention_mask is not None:
      attention_inputs.append(attention_mask)
    attention_output = self._attention_layer(attention_inputs)
    attention_output = self._attention_output_dense(attention_output)
    attention_output = self._attention_dropout(attention_output)
    # ReZero residual: the branch is scaled by alpha, which is initialized to
    # zero in build(), so the block starts out as the identity function.
    attention_output = input_tensor + self._rezero_a * attention_output
    if self._use_layer_norm:
      attention_output = self._attention_layer_norm(attention_output)
    else:
      # Mirror the layer-norm path's float32 dtype (the layer norm above is
      # built with dtype=tf.float32) so the residual add below behaves the
      # same with or without layer norm.
      attention_output = tf.cast(attention_output, tf.float32)
    intermediate_output = self._intermediate_dense(attention_output)
    intermediate_output = self._intermediate_activation_layer(
        intermediate_output)
    layer_output = self._output_dense(intermediate_output)
    layer_output = self._output_dropout(layer_output)
    # During mixed precision training, attention_output is from layer norm and
    # is always fp32 for now. Cast layer_output to fp32 for the subsequent add.
    layer_output = attention_output + tf.cast(self._rezero_a * layer_output,
                                              tf.float32)
    if self._use_layer_norm:
      layer_output = self._output_layer_norm(layer_output)
    return layer_output
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based rezero-transformer block layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import rezero_transformer
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
  """Tests for the ReZero transformer block layer."""

  def tearDown(self):
    super(TransformerWithReZeroLayerTest, self).tearDown()
    # Restore the default dtype policy so later tests are unaffected.
    tf.keras.mixed_precision.experimental.set_policy('float32')

  def test_layer_invocation_with_float16_dtype(self):
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
    test_layer = rezero_transformer.ReZeroTransformer(
        num_attention_heads=10,
        intermediate_size=2048,
        intermediate_activation='relu')
    seq_len = 21
    hidden_width = 80
    # Functional-API inputs; the batch dimension is implicit.
    word_embeddings = tf.keras.Input(shape=(seq_len, hidden_width))
    attention_mask = tf.keras.Input(shape=(seq_len, seq_len))
    outputs = test_layer([word_embeddings, attention_mask])
    model = tf.keras.Model([word_embeddings, attention_mask], outputs)
    # We only verify that inference runs without structural/runtime errors;
    # the numeric output of the full block is not validated here.
    num_examples = 6
    embedding_data = (10 * np.random.random_sample(
        (num_examples, seq_len, hidden_width)))
    # The attention mask has shape (batch, from_seq_len, to_seq_len).
    mask_data = np.random.randint(
        2, size=(num_examples, seq_len, seq_len))
    _ = model.predict([embedding_data, mask_data])

  def test_rezero_without_layer_norm(self):
    test_layer = rezero_transformer.ReZeroTransformer(
        num_attention_heads=10,
        intermediate_size=2048,
        intermediate_activation='relu',
        use_layer_norm=False)
    seq_len, hidden_width = 16, 30
    inputs = tf.keras.Input(shape=(seq_len, hidden_width))
    outputs = test_layer(inputs)
    model = tf.keras.Model(inputs, outputs)
    sample = np.random.rand(2, seq_len, hidden_width)
    # Force alpha away from zero, then reset it; with alpha == 0 and no
    # layer norm the block must act as the identity function.
    test_layer._rezero_a.assign(1.0)
    test_layer.reset_rezero()
    predictions = model.predict(sample)
    self.assertAllClose(sample, predictions)

  def test_rezero_with_layer_norm(self):
    test_layer = rezero_transformer.ReZeroTransformer(
        num_attention_heads=10,
        intermediate_size=2048,
        intermediate_activation='relu',
        use_layer_norm=True)
    seq_len, hidden_width = 16, 30
    inputs = tf.keras.Input(shape=(seq_len, hidden_width))
    outputs = test_layer(inputs)
    model = tf.keras.Model(inputs, outputs)
    sample = np.random.rand(2, seq_len, hidden_width) + 2.0
    predictions = model.predict(sample)
    # A freshly built block has alpha == 0, so the output is simply the
    # layer norm of the input.
    expected = (
        sample - np.mean(sample, axis=-1, keepdims=True)) / (
            np.std(sample, axis=-1, keepdims=True))
    self.assertAllClose(expected, predictions)
if __name__ == '__main__':
  # Run all test cases in this module when executed as a script.
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment