Unverified Commit 965cc3ee authored by Ayushman Kumar, committed by GitHub

Merge pull request #7 from tensorflow/master

updated
parents 1f3247f4 1f685c54
......@@ -32,9 +32,10 @@ class ExportTfhubTest(tf.test.TestCase):
def test_export_tfhub(self):
# Exports a SavedModel for TF-Hub.
hidden_size = 16
bert_config = configs.BertConfig(
vocab_size=100,
hidden_size=16,
hidden_size=hidden_size,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
......@@ -67,7 +68,8 @@ class ExportTfhubTest(tf.test.TestCase):
hub_layer.trainable_weights):
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
dummy_ids = np.zeros((2, 10), dtype=np.int32)
seq_length = 10
dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
......@@ -75,13 +77,23 @@ class ExportTfhubTest(tf.test.TestCase):
# while the outputs of the encoder are in reversed order, i.e.,
# "sequence_output" and "pooled_output".
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
self.assertEqual(hub_outputs[0].shape, (2, 16))
self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
self.assertEqual(hub_outputs[0].shape, (2, hidden_size))
self.assertEqual(hub_outputs[1].shape, (2, seq_length, hidden_size))
for source_output, hub_output, encoder_output in zip(
source_outputs, hub_outputs, encoder_outputs):
self.assertAllClose(source_output.numpy(), hub_output.numpy())
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
# Test propagation of seq_length in shape inference.
input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
pooled_output, sequence_output = hub_layer(
[input_word_ids, input_mask, input_type_ids])
self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
self.assertEqual(sequence_output.shape.as_list(),
[None, seq_length, hidden_size])
if __name__ == "__main__":
tf.test.main()
......@@ -59,7 +59,9 @@ def create_pretrain_dataset(input_patterns,
max_predictions_per_seq,
batch_size,
is_training=True,
input_pipeline_context=None):
input_pipeline_context=None,
use_next_sentence_label=True,
use_position_id=False):
"""Creates input dataset from (tf)records files for pretraining."""
name_to_features = {
'input_ids':
......@@ -74,10 +76,13 @@ def create_pretrain_dataset(input_patterns,
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
'masked_lm_weights':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
'next_sentence_labels':
tf.io.FixedLenFeature([1], tf.int64),
}
if use_next_sentence_label:
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
tf.int64)
if use_position_id:
name_to_features['position_ids'] = tf.io.FixedLenFeature([seq_length],
tf.int64)
for input_pattern in input_patterns:
if not tf.io.gfile.glob(input_pattern):
raise ValueError('%s does not match any files.' % input_pattern)
......@@ -118,8 +123,11 @@ def create_pretrain_dataset(input_patterns,
'masked_lm_positions': record['masked_lm_positions'],
'masked_lm_ids': record['masked_lm_ids'],
'masked_lm_weights': record['masked_lm_weights'],
'next_sentence_labels': record['next_sentence_labels'],
}
if use_next_sentence_label:
x['next_sentence_labels'] = record['next_sentence_labels']
if use_position_id:
x['position_ids'] = record['position_ids']
y = record['masked_lm_weights']
......@@ -148,7 +156,6 @@ def create_classifier_dataset(file_path,
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64),
'is_real_example': tf.io.FixedLenFeature([], tf.int64),
}
dataset = single_file_dataset(file_path, name_to_features)
......
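For orientation, here is a standalone sketch of the feature spec that create_pretrain_dataset now builds under the two new flags. It is not the library function itself; the entries not visible in the hunk above (input_mask, segment_ids) and the default lengths are assumptions.

import tensorflow as tf

def pretrain_feature_spec(seq_length=128,
                          max_predictions_per_seq=20,
                          use_next_sentence_label=True,
                          use_position_id=False):
  """Illustrative copy of the optional-feature logic added above."""
  name_to_features = {
      'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),   # assumed
      'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),  # assumed
      'masked_lm_positions':
          tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
      'masked_lm_ids':
          tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
      'masked_lm_weights':
          tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
  }
  # 'next_sentence_labels' and 'position_ids' are now opt-in, mirroring the
  # use_next_sentence_label / use_position_id arguments added above.
  if use_next_sentence_label:
    name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1], tf.int64)
  if use_position_id:
    name_to_features['position_ids'] = tf.io.FixedLenFeature([seq_length], tf.int64)
  return name_to_features

print(sorted(pretrain_feature_spec(use_position_id=True).keys()))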
......@@ -153,7 +153,8 @@ def run_customized_training_loop(
`model_fn`.
custom_callbacks: A list of Keras Callback objects to run during
training. More specifically, `on_batch_begin()`, `on_batch_end()`,
methods are invoked during training.
`on_epoch_begin()`, `on_epoch_end()` methods are invoked during
training. Note that some metrics may be missing from `logs`.
run_eagerly: Whether to run model training in pure eager execution. This
should be disabled for TPUStrategy.
sub_model_export_name: If not None, will export `sub_model` returned by
......@@ -225,6 +226,8 @@ def run_customized_training_loop(
raise ValueError(
'if `metric_fn` is specified, metric_fn must be a callable.')
callback_list = tf.keras.callbacks.CallbackList(custom_callbacks)
total_training_steps = steps_per_epoch * epochs
train_iterator = _get_input_iterator(train_input_fn, strategy)
......@@ -361,37 +364,37 @@ def run_customized_training_loop(
test_step = tf.function(test_step)
def _run_evaluation(current_training_step, test_iterator):
"""Runs validation steps and aggregates metrics."""
"""Runs validation steps and aggregates metrics.
Args:
current_training_step: tf.int32 tensor containing the current step.
test_iterator: distributed iterator of test datasets.
Returns:
A dict of metric names and values.
"""
for _ in range(eval_steps):
test_step(test_iterator)
logs = {}
with eval_summary_writer.as_default():
for metric in eval_metrics + model.metrics:
metric_value = _float_metric_value(metric)
logs[metric.name] = metric_value
logging.info('Step: [%d] Validation %s = %f', current_training_step,
metric.name, metric_value)
tf.summary.scalar(
metric.name, metric_value, step=current_training_step)
eval_summary_writer.flush()
return logs
def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch, logs):
"""Runs custom callbacks at the end of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_end(batch, logs)
return logs
# Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint = tf.train.Checkpoint(
model=model, optimizer=optimizer, global_step=optimizer.iterations)
sub_model_checkpoint = tf.train.Checkpoint(
model=sub_model) if sub_model_export_name else None
model=sub_model,
global_step=optimizer.iterations) if sub_model_export_name else None
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file:
......@@ -405,13 +408,16 @@ def run_customized_training_loop(
checkpoint_name = 'ctl_step_{step}.ckpt'
while current_step < total_training_steps:
if current_step % steps_per_epoch == 0:
callback_list.on_epoch_begin(int(current_step / steps_per_epoch) + 1)
# Training loss/metrics are averaged over the steps inside the micro
# training loop. We reset their values before each round.
train_loss_metric.reset_states()
for metric in train_metrics + model.metrics:
metric.reset_states()
_run_callbacks_on_batch_begin(current_step)
callback_list.on_batch_begin(current_step)
# Runs several steps in the host while loop.
steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
......@@ -426,7 +432,7 @@ def run_customized_training_loop(
tf.convert_to_tensor(steps, dtype=tf.int32))
train_loss = _float_metric_value(train_loss_metric)
current_step += steps
_run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})
callback_list.on_batch_end(current_step - 1, {'loss': train_loss})
# Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % (
......@@ -443,35 +449,43 @@ def run_customized_training_loop(
train_summary_writer.flush()
logging.info(training_status)
# Saves model checkpoints and run validation steps at every epoch end.
if current_step % steps_per_epoch == 0:
# To avoid repeated model saving, we do not save after the last
# step of training.
# Save a submodel with the step in the file name after each epoch.
if sub_model_export_name:
_save_checkpoint(
strategy, sub_model_checkpoint, model_dir,
'%s_step_%d.ckpt' % (sub_model_export_name, current_step))
# Save model checkpoints and run validation steps after each epoch
# (with the exception of the final epoch which is handled after the
# training loop).
if current_step < total_training_steps:
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if sub_model_export_name:
_save_checkpoint(
strategy, sub_model_checkpoint, model_dir,
'%s_step_%d.ckpt' % (sub_model_export_name, current_step))
if eval_input_fn:
logging.info('Running evaluation after step: %s.', current_step)
_run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
# Re-initialize evaluation metric.
for metric in eval_metrics + model.metrics:
metric.reset_states()
logs = None
if eval_input_fn:
logging.info('Running evaluation after step: %s.', current_step)
logs = _run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
# Re-initialize evaluation metric.
for metric in eval_metrics + model.metrics:
metric.reset_states()
callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if sub_model_export_name:
_save_checkpoint(strategy, sub_model_checkpoint, model_dir,
'%s.ckpt' % sub_model_export_name)
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
logs = None
if eval_input_fn:
logging.info('Running final evaluation after training is complete.')
_run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
logs = _run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
training_summary = {
'total_training_steps': total_training_steps,
......
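As a side note on the checkpoint change above: tf.train.Checkpoint accepts arbitrary trackable objects, so passing global_step=optimizer.iterations stores the step counter alongside the weights, while the numeric suffix in the saved filenames ('-1', '-2', ...) comes from Checkpoint.save()'s own save counter. A minimal, self-contained sketch; the model, paths, and step value are illustrative.

import os
import tempfile
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.SGD()
checkpoint = tf.train.Checkpoint(
    model=model, optimizer=optimizer, global_step=optimizer.iterations)

model_dir = tempfile.mkdtemp()
# Mirrors checkpoint_name = 'ctl_step_{step}.ckpt' in the training loop.
path = checkpoint.save(os.path.join(model_dir, 'ctl_step_20.ckpt'))
print(path)                                   # .../ctl_step_20.ckpt-1
print(tf.train.latest_checkpoint(model_dir))  # same prefix as above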
......@@ -135,6 +135,27 @@ def check_eventfile_for_keyword(keyword, summary_dir):
return any(summaries_with_matching_keyword(keyword, summary_dir))
class RecordingCallback(tf.keras.callbacks.Callback):
def __init__(self):
self.batch_begin = [] # (batch, logs)
self.batch_end = [] # (batch, logs)
self.epoch_begin = [] # (epoch, logs)
self.epoch_end = [] # (epoch, logs)
def on_batch_begin(self, batch, logs=None):
self.batch_begin.append((batch, logs))
def on_batch_end(self, batch, logs=None):
self.batch_end.append((batch, logs))
def on_epoch_begin(self, epoch, logs=None):
self.epoch_begin.append((epoch, logs))
def on_epoch_end(self, epoch, logs=None):
self.epoch_end.append((epoch, logs))
class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
......@@ -156,6 +177,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
eval_input_fn=input_fn,
eval_steps=10,
init_checkpoint=None,
sub_model_export_name='my_submodel_name',
metric_fn=metric_fn,
custom_callbacks=None,
run_eagerly=run_eagerly)
......@@ -188,7 +210,20 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
distribution, model_dir, steps_per_loop=10, run_eagerly=False)
# Two checkpoints should be saved after two epochs.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*')))
files = map(os.path.basename,
tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*index')))
self.assertCountEqual(['ctl_step_20.ckpt-1.index',
'ctl_step_40.ckpt-2.index'], files)
# Three submodel checkpoints should be saved after two epochs (one after
# each epoch plus one final).
files = map(os.path.basename,
tf.io.gfile.glob(os.path.join(model_dir,
'my_submodel_name*index')))
self.assertCountEqual(['my_submodel_name.ckpt-3.index',
'my_submodel_name_step_20.ckpt-1.index',
'my_submodel_name_step_40.ckpt-2.index'], files)
self.assertNotEmpty(
tf.io.gfile.glob(
os.path.join(model_dir, 'summaries/training_summary*')))
......@@ -210,6 +245,41 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword('mean_input',
os.path.join(model_dir, 'summaries/eval')))
@combinations.generate(eager_strategy_combinations())
def test_train_check_callbacks(self, distribution):
model_dir = self.get_temp_dir()
callback = RecordingCallback()
callbacks = [callback]
input_fn = create_fake_data_input_fn(
batch_size=8, features_shape=[128], num_classes=3)
model_training_utils.run_customized_training_loop(
strategy=distribution,
model_fn=self._model_fn,
loss_fn=tf.keras.losses.categorical_crossentropy,
model_dir=model_dir,
steps_per_epoch=20,
steps_per_loop=10,
epochs=2,
train_input_fn=input_fn,
eval_input_fn=input_fn,
eval_steps=10,
init_checkpoint=None,
metric_fn=metric_fn,
custom_callbacks=callbacks,
run_eagerly=False)
self.assertEqual(callback.epoch_begin, [(1, {}), (2, {})])
epoch_ends, epoch_end_infos = zip(*callback.epoch_end)
self.assertEqual(list(epoch_ends), [1, 2])
for info in epoch_end_infos:
self.assertIn('accuracy', info)
self.assertEqual(callback.batch_begin,
[(0, {}), (10, {}), (20, {}), (30, {})])
batch_ends, batch_end_infos = zip(*callback.batch_end)
self.assertEqual(list(batch_ends), [9, 19, 29, 39])
for info in batch_end_infos:
self.assertIn('loss', info)
@combinations.generate(
combinations.combine(
distribution=[
......
......@@ -56,6 +56,7 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS
......@@ -85,7 +86,7 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_classifier_dataset(
input_file_pattern,
tf.io.gfile.glob(input_file_pattern),
max_seq_length,
batch_size,
is_training=is_training,
......@@ -125,7 +126,8 @@ def run_bert_classifier(strategy,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable))
optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps)
initial_lr, steps_per_epoch * epochs, warmup_steps,
FLAGS.end_lr, FLAGS.optimizer_type)
classifier_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
......@@ -196,7 +198,7 @@ def run_keras_compile_fit(model_dir,
with strategy.scope():
training_dataset = train_input_fn()
evaluation_dataset = eval_input_fn()
evaluation_dataset = eval_input_fn() if eval_input_fn else None
bert_model, sub_model = model_fn()
optimizer = bert_model.optimizer
......@@ -329,7 +331,9 @@ def run_bert(strategy,
input_meta_data,
model_config,
train_input_fn=None,
eval_input_fn=None):
eval_input_fn=None,
init_checkpoint=None,
custom_callbacks=None):
"""Run BERT training."""
if FLAGS.mode == 'export_only':
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
......@@ -356,14 +360,14 @@ def run_bert(strategy,
if not strategy:
raise ValueError('Distribution strategy has not been specified.')
if not custom_callbacks:
custom_callbacks = []
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
custom_callbacks.append(keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
logdir=FLAGS.model_dir))
trained_model = run_bert_classifier(
strategy,
......@@ -376,7 +380,7 @@ def run_bert(strategy,
eval_steps,
warmup_steps,
FLAGS.learning_rate,
FLAGS.init_checkpoint,
init_checkpoint or FLAGS.init_checkpoint,
train_input_fn,
eval_input_fn,
run_eagerly=FLAGS.run_eagerly,
......@@ -394,9 +398,12 @@ def run_bert(strategy,
return trained_model
def main(_):
# Users should always run this script under TF 2.x
def custom_main(custom_callbacks=None):
"""Run classification.
Args:
custom_callbacks: a list of tf.keras.callbacks.Callback objects passed to the training loop.
"""
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
......@@ -421,7 +428,12 @@ def main(_):
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_bert(strategy, input_meta_data, bert_config, train_input_fn,
eval_input_fn)
eval_input_fn, custom_callbacks=custom_callbacks)
def main(_):
# Users should always run this script under TF 2.x
custom_main(custom_callbacks=None)
if __name__ == '__main__':
......
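A hypothetical wrapper illustrating why main() was split into custom_main(): another script can now inject its own callbacks into the classifier training loop. The import path follows this repository's layout and the callback choice is arbitrary; both are assumptions, not part of this change, and the usual run_classifier flags still have to be supplied on the command line.

from absl import app
import tensorflow as tf

from official.nlp.bert import run_classifier  # assumed module path

def main(_):
  # Any Keras callback works; the custom training loop forwards
  # on_batch_begin/end and on_epoch_begin/end to it via CallbackList.
  log_cb = tf.keras.callbacks.LambdaCallback(
      on_batch_end=lambda batch, logs: print('step', batch, logs))
  run_classifier.custom_main(custom_callbacks=[log_cb])

if __name__ == '__main__':
  app.run(main)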
......@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_steps_per_epoch', 1000,
'Total number of training steps to run per epoch.')
flags.DEFINE_float('warmup_steps', 10000,
'Warmup steps for Adam weight decay optimizer.')
flags.DEFINE_bool('use_next_sentence_label', True,
'Whether to use next sentence label to compute final loss.')
common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
......@@ -55,7 +57,8 @@ FLAGS = flags.FLAGS
def get_pretrain_dataset_fn(input_file_pattern, seq_length,
max_predictions_per_seq, global_batch_size):
max_predictions_per_seq, global_batch_size,
use_next_sentence_label=True):
"""Returns input dataset from input file string."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
......@@ -67,7 +70,8 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
max_predictions_per_seq,
batch_size,
is_training=True,
input_pipeline_context=ctx)
input_pipeline_context=ctx,
use_next_sentence_label=use_next_sentence_label)
return train_dataset
return _dataset_fn
......@@ -92,20 +96,26 @@ def run_customized_training(strategy,
epochs,
initial_lr,
warmup_steps,
end_lr,
optimizer_type,
input_files,
train_batch_size):
train_batch_size,
use_next_sentence_label=True):
"""Run BERT pretrain model training using low-level API."""
train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
max_predictions_per_seq,
train_batch_size)
train_batch_size,
use_next_sentence_label)
def _get_pretrain_model():
"""Gets a pretraining model."""
pretrain_model, core_model = bert_models.pretrain_model(
bert_config, max_seq_length, max_predictions_per_seq)
bert_config, max_seq_length, max_predictions_per_seq,
use_next_sentence_label=use_next_sentence_label)
optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps)
initial_lr, steps_per_epoch * epochs, warmup_steps,
end_lr, optimizer_type)
pretrain_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
......@@ -151,8 +161,11 @@ def run_bert_pretrain(strategy):
FLAGS.num_train_epochs,
FLAGS.learning_rate,
FLAGS.warmup_steps,
FLAGS.end_lr,
FLAGS.optimizer_type,
FLAGS.input_files,
FLAGS.train_batch_size)
FLAGS.train_batch_size,
FLAGS.use_next_sentence_label)
def main(_):
......
......@@ -20,7 +20,6 @@ from __future__ import print_function
import json
import os
import tempfile
import time
from absl import app
......@@ -48,11 +47,13 @@ FLAGS = flags.FLAGS
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False):
run_eagerly=False,
init_checkpoint=None):
"""Run bert squad training."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly)
custom_callbacks, run_eagerly, init_checkpoint)
def predict_squad(strategy, input_meta_data):
......@@ -130,18 +131,15 @@ def main(_):
eval_metrics = eval_squad(strategy, input_meta_data)
f1_score = eval_metrics['final_f1']
logging.info('SQuAD eval F1-score: %f', f1_score)
if (not strategy) or strategy.extended.should_save_summary:
summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
else:
summary_dir = tempfile.mkdtemp()
summary_writer = tf.summary.create_file_writer(
os.path.join(summary_dir, 'eval'))
summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
summary_writer = tf.summary.create_file_writer(summary_dir)
with summary_writer.as_default():
# TODO(lehou): write to the correct step number.
tf.summary.scalar('F1-score', f1_score, step=0)
summary_writer.flush()
# Wait for some time for the dependent mldash/tensorboard jobs to finish
# exporting the final F1-score.
# Also write eval_metrics to json file.
squad_lib_wp.write_to_json_files(
eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
time.sleep(60)
......
......@@ -88,6 +88,7 @@ def define_common_squad_flags():
'another.')
common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS
......@@ -159,8 +160,12 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
return _dataset_fn
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
def predict_squad_customized(strategy,
input_meta_data,
bert_config,
checkpoint_path,
predict_tfrecord_path,
num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
......@@ -179,7 +184,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
if checkpoint_path is None:
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()
......@@ -215,7 +221,8 @@ def train_squad(strategy,
input_meta_data,
bert_config,
custom_callbacks=None,
run_eagerly=False):
run_eagerly=False,
init_checkpoint=None):
"""Run bert squad training."""
if strategy:
logging.info('Training using customized training loop with distribution'
......@@ -244,7 +251,9 @@ def train_squad(strategy,
hub_module_trainable=FLAGS.hub_module_trainable)
optimizer = optimization.create_optimizer(FLAGS.learning_rate,
steps_per_epoch * epochs,
warmup_steps)
warmup_steps,
FLAGS.end_lr,
FLAGS.optimizer_type)
squad_model.optimizer = performance.configure_optimizer(
optimizer,
......@@ -270,7 +279,7 @@ def train_squad(strategy,
steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks,
explicit_allreduce=False,
......@@ -278,7 +287,7 @@ def train_squad(strategy,
def prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib):
strategy, input_meta_data, tokenizer, bert_config, squad_lib, checkpoint):
"""Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
......@@ -326,8 +335,9 @@ def prediction_output_squad(
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)
all_results = predict_squad_customized(
strategy, input_meta_data, bert_config,
checkpoint, eval_writer.filename, num_steps)
all_predictions, all_nbest_json, scores_diff_json = (
squad_lib.postprocess_output(
......@@ -359,18 +369,34 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
def predict_squad(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None):
"""Get prediction results and write them to disk."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib)
strategy, input_meta_data, tokenizer,
bert_config, squad_lib, init_checkpoint)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
def eval_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
def eval_squad(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None):
"""Get prediction results and evaluate them against ground truth."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib)
strategy, input_meta_data, tokenizer,
bert_config, squad_lib, init_checkpoint)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
......
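The init_checkpoint plumbing added above follows one simple fallback rule, sketched standalone below (directory and file names are illustrative): an explicitly passed checkpoint wins, otherwise the latest checkpoint under the model directory is used.

import tensorflow as tf

def resolve_checkpoint(init_checkpoint, model_dir):
  # Mirrors the fallback added to predict_squad / eval_squad /
  # predict_squad_customized above.
  if init_checkpoint is None:
    init_checkpoint = tf.train.latest_checkpoint(model_dir)
  return init_checkpoint

print(resolve_checkpoint('/tmp/squad/ctl_step_100.ckpt-1', '/tmp/squad'))
print(resolve_checkpoint(None, '/tmp/empty_dir'))  # None when nothing saved yet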
......@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import csv
import importlib
import os
from absl import logging
......@@ -403,6 +404,7 @@ class TfdsProcessor(DataProcessor):
(TFDS) for the meaning of individual parameters):
dataset: Required dataset name (potentially with subset and version number).
data_dir: Optional TFDS source root directory.
module_import: Optional Dataset module to import.
train_split: Name of the train split (defaults to `train`).
dev_split: Name of the dev split (defaults to `validation`).
test_split: Name of the test split (defaults to `test`).
......@@ -418,6 +420,9 @@ class TfdsProcessor(DataProcessor):
process_text_fn=tokenization.convert_to_unicode):
super(TfdsProcessor, self).__init__(process_text_fn)
self._process_tfds_params_str(tfds_params)
if self.module_import:
importlib.import_module(self.module_import)
self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir,
with_info=True)
self._labels = list(range(info.features[self.label_key].num_classes))
......@@ -428,6 +433,7 @@ class TfdsProcessor(DataProcessor):
d = {k.strip(): v.strip() for k, v in tuples}
self.dataset_name = d["dataset"] # Required.
self.data_dir = d.get("data_dir", None)
self.module_import = d.get("module_import", None)
self.train_split = d.get("train_split", "train")
self.dev_split = d.get("dev_split", "validation")
self.test_split = d.get("test_split", "test")
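An illustrative tfds_params string using the new module_import key (the dataset and module names are made up), together with the parsing convention used by _process_tfds_params_str above:

tfds_params = ("dataset=my_dataset/my_config:1.0.0,"
               "module_import=my_project.my_tfds_dataset,"  # hypothetical module
               "train_split=train,dev_split=validation")
tuples = [x.split("=") for x in tfds_params.split(",")]
d = {k.strip(): v.strip() for k, v in tuples}
print(d["module_import"])  # imported via importlib.import_module in __init__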
......@@ -578,6 +584,7 @@ def file_based_convert_examples_to_features(examples, label_list,
output_file):
"""Convert a set of `InputExample`s to a TFRecord file."""
tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples):
......
......@@ -20,6 +20,7 @@ from __future__ import print_function
import functools
import json
import os
from absl import app
from absl import flags
......@@ -191,6 +192,7 @@ def main(_):
else:
input_meta_data = generate_squad_dataset()
tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
writer.write(json.dumps(input_meta_data, indent=4) + "\n")
......
......@@ -100,13 +100,14 @@ class TrainingInstance(object):
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files):
max_predictions_per_seq, output_files,
gzip_compress):
"""Create TF example files from `TrainingInstance`s."""
writers = []
for output_file in output_files:
writers.append(
tf.io.TFRecordWriter(
output_file, options="GZIP" if FLAGS.gzip_compress else ""))
output_file, options="GZIP" if gzip_compress else ""))
writer_index = 0
......@@ -183,9 +184,15 @@ def create_float_feature(values):
return feature
def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng):
def create_training_instances(input_files,
tokenizer,
max_seq_length,
dupe_factor,
short_seq_prob,
masked_lm_prob,
max_predictions_per_seq,
rng,
do_whole_word_mask=False):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
......@@ -221,7 +228,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
instances.extend(
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask))
rng.shuffle(instances)
return instances
......@@ -229,7 +237,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask=False):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
......@@ -327,7 +336,8 @@ def create_instances_from_document(
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
......@@ -347,7 +357,8 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng):
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
......@@ -363,7 +374,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
if (do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
......@@ -456,7 +467,7 @@ def main(_):
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng)
rng, FLAGS.do_whole_word_mask)
output_files = FLAGS.output_file.split(",")
logging.info("*** Writing to output files ***")
......@@ -464,7 +475,8 @@ def main(_):
logging.info(" %s", output_file)
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files)
FLAGS.max_predictions_per_seq, output_files,
FLAGS.gzip_compress)
if __name__ == "__main__":
......
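A standalone illustration (not the library code) of what threading do_whole_word_mask down to create_masked_lm_predictions changes: WordPiece continuations ('##...') are grouped with the preceding piece, so a whole word is masked as one unit. Token values are illustrative.

def candidate_indexes(tokens, do_whole_word_mask):
  cand_indexes = []
  for i, token in enumerate(tokens):
    if token in ("[CLS]", "[SEP]"):
      continue
    if do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##"):
      cand_indexes[-1].append(i)
    else:
      cand_indexes.append([i])
  return cand_indexes

tokens = ["[CLS]", "un", "##believ", "##able", "news", "[SEP]"]
print(candidate_indexes(tokens, do_whole_word_mask=False))  # [[1], [2], [3], [4]]
print(candidate_indexes(tokens, do_whole_word_mask=True))   # [[1, 2, 3], [4]]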
......@@ -14,9 +14,15 @@ If `from_tensor` and `to_tensor` are the same, then this is self-attention.
* [CachedAttention](attention.py) implements an attention layer with cache used
for auto-regressive decoding.
* [TalkingHeadsAttention](talking_heads_attention.py) implements the talking
heads attention, as described in ["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436).
* [Transformer](transformer.py) implements an optionally masked transformer as
described in ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [ReZeroTransformer](rezero_transformer.py) implements Transformer with ReZero
described in ["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).
* [OnDeviceEmbedding](on_device_embedding.py) implements efficient embedding lookups designed for TPU-based models.
* [PositionalEmbedding](position_embedding.py) creates a positional embedding
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,6 +18,8 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.transformer import Transformer
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
......@@ -186,15 +186,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
class CachedAttention(MultiHeadAttention):
"""Attention layer with cache used for auto-regressive decoding.
Arguments:
num_heads: Number of attention heads.
head_size: Size of each attention head.
**kwargs: Other keyword arguments inherit from `Attention` class.
Arguments are the same as for the `MultiHeadAttention` layer.
"""
def __init__(self, num_heads, head_size, **kwargs):
super(CachedAttention, self).__init__(num_heads, head_size, **kwargs)
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values.
......
......@@ -147,6 +147,8 @@ class DenseEinsum(tf.keras.layers.Layer):
config = {
"output_shape":
self._output_shape,
"num_summed_dimensions":
self._num_summed_dimensions,
"activation":
tf.keras.activations.serialize(self._activation),
"use_bias":
......
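Why num_summed_dimensions belongs in get_config(): Keras recreates layers from their config when a model is cloned or deserialized, and any constructor argument missing from the config silently falls back to its default. A hedged round-trip sketch; the import path and private attribute name follow the code above, treat them as assumptions.

from official.nlp.modeling.layers import dense_einsum

layer = dense_einsum.DenseEinsum(output_shape=64, num_summed_dimensions=2)
clone = dense_einsum.DenseEinsum.from_config(layer.get_config())
# Without the new config entry the clone would revert to the default value.
assert clone._num_summed_dimensions == 2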
......@@ -21,8 +21,6 @@ from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package="Text")
class OnDeviceEmbedding(tf.keras.layers.Layer):
......@@ -78,8 +76,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
super(OnDeviceEmbedding, self).build(input_shape)
def call(self, inputs):
input_shape = tf_utils.get_shape_list(inputs, expected_rank=2)
input_shape.append(self._embedding_width)
flat_inputs = tf.reshape(inputs, [-1])
if self._use_one_hot:
one_hot_data = tf.one_hot(
......@@ -87,6 +83,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
embeddings = tf.matmul(one_hot_data, self.embeddings)
else:
embeddings = tf.gather(self.embeddings, flat_inputs)
embeddings = tf.reshape(embeddings, input_shape)
embeddings = tf.reshape(
embeddings,
# Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
return embeddings
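For context on the reshape change above, an eager, standalone illustration of the pattern: the target shape is built as a tensor (tf.shape of the inputs concatenated with the static embedding width) instead of a Python list, which keeps dynamic batch/sequence dimensions intact and sidesteps the issue referenced in b/142213824. Values are illustrative.

import tensorflow as tf

embedding_width = 8
table = tf.random.normal([100, embedding_width])  # stands in for self.embeddings
inputs = tf.constant([[1, 2, 3], [4, 5, 6]])      # [batch=2, seq_length=3]

flat_inputs = tf.reshape(inputs, [-1])
embeddings = tf.gather(table, flat_inputs)
embeddings = tf.reshape(
    embeddings, tf.concat([tf.shape(inputs), [embedding_width]], axis=0))
print(embeddings.shape)  # (2, 3, 8)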
......@@ -111,15 +111,10 @@ class PositionEmbedding(tf.keras.layers.Layer):
def call(self, inputs):
"""Implements call() for the layer."""
input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
if self._use_dynamic_slicing:
input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
seq_length = input_shape[1]
width = input_shape[2]
position_embeddings = tf.expand_dims(
tf.slice(self._position_embeddings, [0, 0], [seq_length, width]),
axis=0)
position_embeddings = self._position_embeddings[:input_shape[1], :]
else:
position_embeddings = tf.expand_dims(self._position_embeddings, axis=0)
position_embeddings = self._position_embeddings
return position_embeddings
return tf.broadcast_to(position_embeddings, input_shape)
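A standalone sketch of the new return path in PositionEmbedding.call(): the stored table is sliced to the input's sequence length and then broadcast to the full (dynamic) input shape, so the batch dimension stays symbolic. Shapes below are illustrative.

import tensorflow as tf

width, max_length = 8, 512
table = tf.random.normal([max_length, width])  # stored position embeddings
inputs = tf.random.normal([4, 16, width])      # [batch, seq_length, width]

position_embeddings = table[:16, :]            # slice to seq_length
out = tf.broadcast_to(position_embeddings, tf.shape(inputs))
print(out.shape)  # (4, 16, 8)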
......@@ -40,7 +40,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch.
expected_output_shape = [1, sequence_length, width]
expected_output_shape = [None, sequence_length, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float32, output_tensor.dtype)
......@@ -55,7 +55,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch.
expected_output_shape = [1, sequence_length, width]
expected_output_shape = [None, sequence_length, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float16, output_tensor.dtype)
......@@ -72,7 +72,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# When using dynamic positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions - but may be None if
# the input shape is None there.
expected_output_shape = [1, None, width]
expected_output_shape = [None, None, width]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
def test_dynamic_layer_slicing(self):
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based rezero-transformer block layer (Transformer with ReZero)."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class ReZeroTransformer(tf.keras.layers.Layer):
"""Transformer layer with ReZero.
This layer implements the Transformer from "Attention Is All You Need".
(https://arxiv.org/abs/1706.03762).
The residual connection implements the ReZero method.
(https://arxiv.org/abs/2003.04887)
Arguments:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
use_layer_norm: Whether to add a layer norm on top of the ReZero residual connection.
"""
def __init__(self,
num_attention_heads,
intermediate_size,
intermediate_activation,
dropout_rate=0.0,
attention_dropout_rate=0.0,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_layer_norm=False,
**kwargs):
super(ReZeroTransformer, self).__init__(**kwargs)
self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_layer_norm = use_layer_norm
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
input_tensor_shape = tf.TensorShape(input_tensor)
if len(input_tensor_shape) != 3:
raise ValueError("TransformerLayer expects a three-dimensional input of "
"shape [batch, sequence, width].")
batch_size, sequence_length, hidden_size = input_tensor_shape
if len(input_shape) == 2:
mask_tensor_shape = tf.TensorShape(input_shape[1])
expected_mask_tensor_shape = tf.TensorShape(
[batch_size, sequence_length, sequence_length])
if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
raise ValueError("When passing a mask tensor to TransformerLayer, the "
"mask tensor must be of shape [batch, "
"sequence_length, sequence_length] (here %s). Got a "
"mask tensor of shape %s." %
(expected_mask_tensor_shape, mask_tensor_shape))
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
head_size=self._attention_head_size,
dropout_rate=self._attention_dropout_rate,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
self._attention_output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention_output")
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum(
output_shape=self._intermediate_size,
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation)
self._output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
# Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
self._rezero_a = self.add_weight(
name="rezero_alpha",
initializer=tf.keras.initializers.Zeros(),
trainable=True, dtype=tf.float32)
super(ReZeroTransformer, self).build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
self._attention_dropout_rate,
"use_layer_norm":
self._use_layer_norm,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
}
base_config = super(ReZeroTransformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def reset_rezero(self):
self._rezero_a.assign(0.)
def call(self, inputs):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
input_tensor, attention_mask = inputs
else:
input_tensor, attention_mask = (inputs, None)
attention_inputs = [input_tensor, input_tensor]
if attention_mask is not None:
attention_inputs.append(attention_mask)
attention_output = self._attention_layer(attention_inputs)
attention_output = self._attention_output_dense(attention_output)
attention_output = self._attention_dropout(attention_output)
attention_output = input_tensor + self._rezero_a * attention_output
if self._use_layer_norm:
attention_output = self._attention_layer_norm(attention_output)
else:
attention_output = tf.cast(attention_output, tf.float32)
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
# is always fp32 for now. Cast layer_output to fp32 for the subsequent add.
layer_output = attention_output + tf.cast(self._rezero_a * layer_output,
tf.float32)
if self._use_layer_norm:
layer_output = self._output_layer_norm(layer_output)
return layer_output
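For reference, the block computed in call() above can be written compactly: with the single trainable scalar \alpha (rezero_alpha, initialized to zero and shared by both residual adds),

y = x + \alpha \cdot \mathrm{Attn}(x), \qquad z = y + \alpha \cdot \mathrm{FFN}(y)

where Attn denotes the attention sub-block (including its output projection and dropout) and FFN the feed-forward sub-block. A freshly initialized layer therefore acts as the identity, and the optional layer norm is applied on top of each residual sum when use_layer_norm=True.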
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based rezero-transformer block layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import rezero_transformer
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
def tearDown(self):
super(TransformerWithReZeroLayerTest, self).tearDown()
tf.keras.mixed_precision.experimental.set_policy('float32')
def test_layer_invocation_with_float16_dtype(self):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
test_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu')
sequence_length = 21
width = 80
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Create a 2-dimensional input (the first dimension is implicit).
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length))
output_tensor = test_layer([data_tensor, mask_tensor])
# Create a model from the test layer.
model = tf.keras.Model([data_tensor, mask_tensor], output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (10 * np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
_ = model.predict([input_data, mask_data])
def test_rezero_without_layer_norm(self):
test_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
use_layer_norm=False)
input_length, width = 16, 30
input_tensor = tf.keras.Input(shape=(input_length, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_data = np.random.rand(2, input_length, width)
test_layer._rezero_a.assign(1.0)
test_layer.reset_rezero()
output_data = model.predict(input_data)
self.assertAllClose(input_data, output_data)
def test_rezero_with_layer_norm(self):
test_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
use_layer_norm=True)
input_length, width = 16, 30
input_tensor = tf.keras.Input(shape=(input_length, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_data = np.random.rand(2, input_length, width) + 2.0
output_data = model.predict(input_data)
input_data_normed = (
input_data - np.mean(input_data, axis=-1, keepdims=True)) / (
np.std(input_data, axis=-1, keepdims=True))
self.assertAllClose(input_data_normed, output_data)
if __name__ == '__main__':
tf.test.main()