Commit afd5579f authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

Merge remote-tracking branch 'upstream/master' into context_tf2

parents dcd96e02 567bd18d
...@@ -271,6 +271,44 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -271,6 +271,44 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
FLAGS.strategy_type = 'tpu' FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0) self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_2x2_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_2x2_tpu_spinenet_coco(self): def benchmark_2x2_tpu_spinenet_coco(self):
"""Run SpineNet with RetinaNet model accuracy test with 4 TPUs.""" """Run SpineNet with RetinaNet model accuracy test with 4 TPUs."""
......
...@@ -29,6 +29,8 @@ from official.nlp.transformer import misc ...@@ -29,6 +29,8 @@ from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main as transformer_main from official.nlp.transformer import transformer_main as transformer_main
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
TPU_DATA_DIR = 'gs://mlcompass-data/transformer'
GPU_DATA_DIR = os.getenv('TMPDIR')
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official' TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014' EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark):
Code under test for the Transformer Keras models report the same data and Code under test for the Transformer Keras models report the same data and
require the same FLAG setup. require the same FLAG setup.
""" """
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None, def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None, tpu=None): flag_methods=None, tpu=None):
self._set_data_files(root_data_dir=root_data_dir)
if default_flags is None:
default_flags = {}
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
def _set_data_files(self, root_data_dir=None, tpu_run=False):
"""Sets train_data_dir, vocab_file, bleu_source and bleu_ref."""
# Use remote storage for TPU, remote storage for GPU if defined, else
# use environment provided root_data_dir.
if tpu_run:
root_data_dir = TPU_DATA_DIR
elif GPU_DATA_DIR is not None:
root_data_dir = GPU_DATA_DIR
root_data_dir = root_data_dir if root_data_dir else '' root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir, self.train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME) TRANSFORMER_EN2DE_DATA_DIR_NAME)
self.vocab_file = os.path.join(root_data_dir, self.vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME, TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768') 'vocab.ende.32768')
self.bleu_source = os.path.join(root_data_dir, self.bleu_source = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME, EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.en') 'newstest2014.en')
self.bleu_ref = os.path.join(root_data_dir, self.bleu_ref = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME, EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.de') 'newstest2014.de')
if default_flags is None: def _set_data_file_flags(self):
default_flags = {} """Sets the FLAGS for the data files."""
default_flags['data_dir'] = self.train_data_dir FLAGS.data_dir = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
super(TransformerBenchmark, self).__init__( FLAGS['bleu_source'].value = self.bleu_source
output_dir=output_dir, FLAGS['bleu_ref'].value = self.bleu_ref
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
...@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA. not converge to the 27.3 BLEU (uncased) SOTA.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 2048 FLAGS.batch_size = 2048
FLAGS.train_steps = 1000 FLAGS.train_steps = 1000
...@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA. not converge to the 27.3 BLEU (uncased) SOTA.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet. Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8 FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet. Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8 FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Iterations are not epochs, an iteration is a number of steps between evals. Iterations are not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
not epochs, an iteration is a number of steps between evals. not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.enable_xla = True FLAGS.enable_xla = True
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
root_data_dir=root_data_dir, batch_per_gpu=3072, root_data_dir=root_data_dir, batch_per_gpu=3072,
tpu=tpu) tpu=tpu)
def benchmark_2x2_tpu(self): def _set_df_common(self):
"""Port of former snaggletooth transformer_big model on 2x2.""" self._set_data_files(tpu_run=True)
self._setup() FLAGS.data_dir = self.train_data_dir
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu') FLAGS.vocab_file = self.vocab_file
FLAGS.distribution_strategy = 'tpu'
FLAGS.padded_decode = True
FLAGS.train_steps = 300 FLAGS.train_steps = 300
FLAGS.log_steps = 150 FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150 FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True FLAGS.static_batch = True
FLAGS.use_ctl = True FLAGS.use_ctl = True
FLAGS.batch_size = 6144 FLAGS.enable_checkpointing = False
FLAGS.max_length = 64 FLAGS.max_length = 64
FLAGS.decode_batch_size = 32 FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97 FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False def benchmark_2x2_tpu(self):
"""Port of former snaggletooth transformer_big model on 2x2."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
FLAGS.batch_size = 6144
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_mlir(self):
"""Run transformer_big model on 2x2 with the MLIR Bridge enabled."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mlir')
FLAGS.batch_size = 6144
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark( self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
...@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu(self): def benchmark_4x4_tpu(self):
"""Port of former GCP transformer_big model on 4x4.""" """Port of former GCP transformer_big model on 4x4."""
self._setup() self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu') FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576 FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
self._run_and_report_benchmark( self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
...@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu_mlir(self): def benchmark_4x4_tpu_mlir(self):
"""Run transformer_big model on 4x4 with the MLIR Bridge enabled.""" """Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
self._setup() self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu') self._set_df_common()
FLAGS.train_steps = 300 FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_mlir')
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576 FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
tf.config.experimental.enable_mlir_bridge() tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark( self._run_and_report_benchmark(
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"colab": {}, "colab": {},
...@@ -104,7 +104,7 @@ ...@@ -104,7 +104,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -128,7 +128,7 @@ ...@@ -128,7 +128,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -185,7 +185,7 @@ ...@@ -185,7 +185,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -204,12 +204,12 @@ ...@@ -204,12 +204,12 @@
"id": "9uFskufsR2LT" "id": "9uFskufsR2LT"
}, },
"source": [ "source": [
"You can get a pre-trained BERT encoder from TensorFlow Hub here:" "You can get a pre-trained BERT encoder from [TensorFlow Hub](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2):"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -252,7 +252,7 @@ ...@@ -252,7 +252,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -267,7 +267,7 @@ ...@@ -267,7 +267,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -290,7 +290,7 @@ ...@@ -290,7 +290,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -313,7 +313,7 @@ ...@@ -313,7 +313,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -336,7 +336,7 @@ ...@@ -336,7 +336,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -376,7 +376,7 @@ ...@@ -376,7 +376,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -404,7 +404,7 @@ ...@@ -404,7 +404,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -446,7 +446,7 @@ ...@@ -446,7 +446,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -469,7 +469,7 @@ ...@@ -469,7 +469,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -490,7 +490,7 @@ ...@@ -490,7 +490,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -514,7 +514,7 @@ ...@@ -514,7 +514,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -562,7 +562,7 @@ ...@@ -562,7 +562,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -587,7 +587,7 @@ ...@@ -587,7 +587,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -617,7 +617,7 @@ ...@@ -617,7 +617,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -661,7 +661,7 @@ ...@@ -661,7 +661,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -691,7 +691,7 @@ ...@@ -691,7 +691,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -737,7 +737,7 @@ ...@@ -737,7 +737,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -769,7 +769,7 @@ ...@@ -769,7 +769,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -793,7 +793,7 @@ ...@@ -793,7 +793,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -816,7 +816,7 @@ ...@@ -816,7 +816,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -845,7 +845,7 @@ ...@@ -845,7 +845,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -870,7 +870,7 @@ ...@@ -870,7 +870,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -908,7 +908,7 @@ ...@@ -908,7 +908,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -943,7 +943,7 @@ ...@@ -943,7 +943,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -986,7 +986,7 @@ ...@@ -986,7 +986,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1023,7 +1023,7 @@ ...@@ -1023,7 +1023,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1055,7 +1055,7 @@ ...@@ -1055,7 +1055,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1071,7 +1071,7 @@ ...@@ -1071,7 +1071,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1096,7 +1096,7 @@ ...@@ -1096,7 +1096,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1110,7 +1110,7 @@ ...@@ -1110,7 +1110,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1176,7 +1176,7 @@ ...@@ -1176,7 +1176,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1201,7 +1201,7 @@ ...@@ -1201,7 +1201,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1240,7 +1240,7 @@ ...@@ -1240,7 +1240,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1273,7 +1273,7 @@ ...@@ -1273,7 +1273,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1306,7 +1306,7 @@ ...@@ -1306,7 +1306,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1351,7 +1351,7 @@ ...@@ -1351,7 +1351,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1379,7 +1379,7 @@ ...@@ -1379,7 +1379,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1406,17 +1406,44 @@ ...@@ -1406,17 +1406,44 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "lo6479At4sP1" "id": "GDWrHm0BGpbX"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Note: 350MB download.\n", "# Note: 350MB download.\n",
"import tensorflow_hub as hub\n", "import tensorflow_hub as hub"
"hub_encoder = hub.KerasLayer(hub_url_bert, trainable=True)\n", ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "Y29meH0qGq_5"
},
"outputs": [],
"source": [
"hub_model_name = \"bert_en_uncased_L-12_H-768_A-12\" #@param [\"bert_en_uncased_L-24_H-1024_A-16\", \"bert_en_wwm_cased_L-24_H-1024_A-16\", \"bert_en_uncased_L-12_H-768_A-12\", \"bert_en_wwm_uncased_L-24_H-1024_A-16\", \"bert_en_cased_L-24_H-1024_A-16\", \"bert_en_cased_L-12_H-768_A-12\", \"bert_zh_L-12_H-768_A-12\", \"bert_multi_cased_L-12_H-768_A-12\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}\",\n",
" trainable=True)\n",
"\n", "\n",
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")" "print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
] ]
...@@ -1433,7 +1460,7 @@ ...@@ -1433,7 +1460,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1466,7 +1493,7 @@ ...@@ -1466,7 +1493,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1491,7 +1518,7 @@ ...@@ -1491,7 +1518,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1504,7 +1531,7 @@ ...@@ -1504,7 +1531,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1545,7 +1572,7 @@ ...@@ -1545,7 +1572,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1569,7 +1596,7 @@ ...@@ -1569,7 +1596,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1592,7 +1619,7 @@ ...@@ -1592,7 +1619,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1617,7 +1644,7 @@ ...@@ -1617,7 +1644,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1643,7 +1670,7 @@ ...@@ -1643,7 +1670,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1661,7 +1688,7 @@ ...@@ -1661,7 +1688,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1688,7 +1715,7 @@ ...@@ -1688,7 +1715,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1714,7 +1741,7 @@ ...@@ -1714,7 +1741,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1733,7 +1760,7 @@ ...@@ -1733,7 +1760,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1761,7 +1788,7 @@ ...@@ -1761,7 +1788,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1795,7 +1822,7 @@ ...@@ -1795,7 +1822,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
......
...@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text, ...@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text,
raise ValueError('model must be a tf.keras.Model object.') raise ValueError('model must be a tf.keras.Model object.')
if checkpoint_dir: if checkpoint_dir:
# Keras compile/fit() was used to save checkpoint using
# model.save_weights().
if restore_model_using_load_weights: if restore_model_using_load_weights:
model_weight_path = os.path.join(checkpoint_dir, 'checkpoint') model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
assert tf.io.gfile.exists(model_weight_path) assert tf.io.gfile.exists(model_weight_path)
model.load_weights(model_weight_path) model.load_weights(model_weight_path)
# tf.train.Checkpoint API was used via custom training loop logic.
else: else:
checkpoint = tf.train.Checkpoint(model=model) checkpoint = tf.train.Checkpoint(model=model)
......
...@@ -152,10 +152,10 @@ class ColaProcessor(DataProcessor): ...@@ -152,10 +152,10 @@ class ColaProcessor(DataProcessor):
return "COLA" return "COLA"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
# Only the test set has a header # Only the test set has a header.
if set_type == "test" and i == 0: if set_type == "test" and i == 0:
continue continue
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -173,6 +173,14 @@ class ColaProcessor(DataProcessor): ...@@ -173,6 +173,14 @@ class ColaProcessor(DataProcessor):
class MnliProcessor(DataProcessor): class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version).""" """Processor for the MultiNLI data set (GLUE version)."""
def __init__(self,
mnli_type="matched",
process_text_fn=tokenization.convert_to_unicode):
super(MnliProcessor, self).__init__(process_text_fn)
if mnli_type not in ("matched", "mismatched"):
raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
self.mnli_type = mnli_type
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
...@@ -180,14 +188,23 @@ class MnliProcessor(DataProcessor): ...@@ -180,14 +188,23 @@ class MnliProcessor(DataProcessor):
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( if self.mnli_type == "matched":
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), return self._create_examples(
"dev_matched") self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
"dev_mismatched")
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( if self.mnli_type == "matched":
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -199,9 +216,9 @@ class MnliProcessor(DataProcessor): ...@@ -199,9 +216,9 @@ class MnliProcessor(DataProcessor):
return "MNLI" return "MNLI"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0])) guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
...@@ -244,9 +261,9 @@ class MrpcProcessor(DataProcessor): ...@@ -244,9 +261,9 @@ class MrpcProcessor(DataProcessor):
return "MRPC" return "MRPC"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -290,7 +307,7 @@ class PawsxProcessor(DataProcessor): ...@@ -290,7 +307,7 @@ class PawsxProcessor(DataProcessor):
self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:]) self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "train-%d" % i guid = "train-%d" % i
text_a = self.process_text_fn(line[1]) text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2]) text_b = self.process_text_fn(line[2])
...@@ -307,7 +324,7 @@ class PawsxProcessor(DataProcessor): ...@@ -307,7 +324,7 @@ class PawsxProcessor(DataProcessor):
self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:]) self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "dev-%d" % i guid = "dev-%d" % i
text_a = self.process_text_fn(line[1]) text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2]) text_b = self.process_text_fn(line[2])
...@@ -321,7 +338,7 @@ class PawsxProcessor(DataProcessor): ...@@ -321,7 +338,7 @@ class PawsxProcessor(DataProcessor):
examples_by_lang = {k: [] for k in self.supported_languages} examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages: for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:] lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "test-%d" % i guid = "test-%d" % i
text_a = self.process_text_fn(line[1]) text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2]) text_b = self.process_text_fn(line[2])
...@@ -368,9 +385,9 @@ class QnliProcessor(DataProcessor): ...@@ -368,9 +385,9 @@ class QnliProcessor(DataProcessor):
return "QNLI" return "QNLI"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, 1) guid = "%s-%s" % (set_type, 1)
...@@ -415,9 +432,9 @@ class QqpProcessor(DataProcessor): ...@@ -415,9 +432,9 @@ class QqpProcessor(DataProcessor):
return "QQP" return "QQP"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, line[0]) guid = "%s-%s" % (set_type, line[0])
...@@ -462,7 +479,7 @@ class RteProcessor(DataProcessor): ...@@ -462,7 +479,7 @@ class RteProcessor(DataProcessor):
return "RTE" return "RTE"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for i, line in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
...@@ -507,9 +524,9 @@ class SstProcessor(DataProcessor): ...@@ -507,9 +524,9 @@ class SstProcessor(DataProcessor):
return "SST-2" return "SST-2"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -558,7 +575,7 @@ class StsBProcessor(DataProcessor): ...@@ -558,7 +575,7 @@ class StsBProcessor(DataProcessor):
return "STS-B" return "STS-B"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for i, line in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
...@@ -671,7 +688,7 @@ class TfdsProcessor(DataProcessor): ...@@ -671,7 +688,7 @@ class TfdsProcessor(DataProcessor):
return "TFDS_" + self.dataset_name return "TFDS_" + self.dataset_name
def _create_examples(self, split_name, set_type): def _create_examples(self, split_name, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
if split_name not in self.dataset: if split_name not in self.dataset:
raise ValueError("Split {} not available.".format(split_name)) raise ValueError("Split {} not available.".format(split_name))
dataset = self.dataset[split_name].as_numpy_iterator() dataset = self.dataset[split_name].as_numpy_iterator()
...@@ -731,7 +748,7 @@ class WnliProcessor(DataProcessor): ...@@ -731,7 +748,7 @@ class WnliProcessor(DataProcessor):
return "WNLI" return "WNLI"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training/dev/test sets."""
examples = [] examples = []
for i, line in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
...@@ -777,7 +794,7 @@ class XnliProcessor(DataProcessor): ...@@ -777,7 +794,7 @@ class XnliProcessor(DataProcessor):
"multinli.train.%s.tsv" % language))[1:]) "multinli.train.%s.tsv" % language))[1:])
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "train-%d" % i guid = "train-%d" % i
text_a = self.process_text_fn(line[0]) text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1]) text_b = self.process_text_fn(line[1])
...@@ -792,7 +809,7 @@ class XnliProcessor(DataProcessor): ...@@ -792,7 +809,7 @@ class XnliProcessor(DataProcessor):
"""See base class.""" """See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "dev-%d" % i guid = "dev-%d" % i
...@@ -807,7 +824,7 @@ class XnliProcessor(DataProcessor): ...@@ -807,7 +824,7 @@ class XnliProcessor(DataProcessor):
"""See base class.""" """See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv")) lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
examples_by_lang = {k: [] for k in XnliProcessor.supported_languages} examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "test-%d" % i guid = "test-%d" % i
...@@ -837,7 +854,7 @@ class XtremePawsxProcessor(DataProcessor): ...@@ -837,7 +854,7 @@ class XtremePawsxProcessor(DataProcessor):
"""See base class.""" """See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv")) lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "train-%d" % i guid = "train-%d" % i
text_a = self.process_text_fn(line[0]) text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1]) text_b = self.process_text_fn(line[1])
...@@ -851,7 +868,7 @@ class XtremePawsxProcessor(DataProcessor): ...@@ -851,7 +868,7 @@ class XtremePawsxProcessor(DataProcessor):
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv")) lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "dev-%d" % i guid = "dev-%d" % i
text_a = self.process_text_fn(line[0]) text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1]) text_b = self.process_text_fn(line[1])
...@@ -865,7 +882,7 @@ class XtremePawsxProcessor(DataProcessor): ...@@ -865,7 +882,7 @@ class XtremePawsxProcessor(DataProcessor):
examples_by_lang = {k: [] for k in self.supported_languages} examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages: for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv")) lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "test-%d" % i guid = "test-%d" % i
text_a = self.process_text_fn(line[0]) text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1]) text_b = self.process_text_fn(line[1])
...@@ -896,7 +913,7 @@ class XtremeXnliProcessor(DataProcessor): ...@@ -896,7 +913,7 @@ class XtremeXnliProcessor(DataProcessor):
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv")) lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "train-%d" % i guid = "train-%d" % i
text_a = self.process_text_fn(line[0]) text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1]) text_b = self.process_text_fn(line[1])
...@@ -909,7 +926,7 @@ class XtremeXnliProcessor(DataProcessor): ...@@ -909,7 +926,7 @@ class XtremeXnliProcessor(DataProcessor):
"""See base class.""" """See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv")) lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = "dev-%d" % i guid = "dev-%d" % i
text_a = self.process_text_fn(line[0]) text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1]) text_b = self.process_text_fn(line[1])
...@@ -923,7 +940,7 @@ class XtremeXnliProcessor(DataProcessor): ...@@ -923,7 +940,7 @@ class XtremeXnliProcessor(DataProcessor):
examples_by_lang = {k: [] for k in self.supported_languages} examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages: for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv")) lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
guid = f"test-{i}" guid = f"test-{i}"
text_a = self.process_text_fn(line[0]) text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1]) text_b = self.process_text_fn(line[1])
...@@ -1052,7 +1069,7 @@ def file_based_convert_examples_to_features(examples, ...@@ -1052,7 +1069,7 @@ def file_based_convert_examples_to_features(examples,
tf.io.gfile.makedirs(os.path.dirname(output_file)) tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file) writer = tf.io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples): for ex_index, example in enumerate(examples):
if ex_index % 10000 == 0: if ex_index % 10000 == 0:
logging.info("Writing example %d of %d", ex_index, len(examples)) logging.info("Writing example %d of %d", ex_index, len(examples))
......
...@@ -59,27 +59,32 @@ flags.DEFINE_enum("classification_task_name", "MNLI", ...@@ -59,27 +59,32 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
"only and for XNLI is all languages combined. Same for " "only and for XNLI is all languages combined. Same for "
"PAWS-X.") "PAWS-X.")
# XNLI task specific flag. # MNLI task-specific flag.
flags.DEFINE_enum(
"mnli_type", "matched", ["matched", "mismatched"],
"The type of MNLI dataset.")
# XNLI task-specific flag.
flags.DEFINE_string( flags.DEFINE_string(
"xnli_language", "en", "xnli_language", "en",
"Language of training data for XNIL task. If the value is 'all', the data " "Language of training data for XNLI task. If the value is 'all', the data "
"of all languages will be used for training.") "of all languages will be used for training.")
# PAWS-X task specific flag. # PAWS-X task-specific flag.
flags.DEFINE_string( flags.DEFINE_string(
"pawsx_language", "en", "pawsx_language", "en",
"Language of trainig data for PAWS-X task. If the value is 'all', the data " "Language of training data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training.") "of all languages will be used for training.")
# Retrieva task specific flags # Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"], flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
"The name of sentence retrieval task for scoring") "The name of sentence retrieval task for scoring")
# Tagging task specific flags # Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"], flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
"The name of BERT tagging (token classification) task.") "The name of BERT tagging (token classification) task.")
# BERT Squad task specific flags. # BERT Squad task-specific flags.
flags.DEFINE_string( flags.DEFINE_string(
"squad_data_file", None, "squad_data_file", None,
"The input data file in for generating training data for BERT squad task.") "The input data file in for generating training data for BERT squad task.")
...@@ -179,7 +184,8 @@ def generate_classifier_dataset(): ...@@ -179,7 +184,8 @@ def generate_classifier_dataset():
"cola": "cola":
classifier_data_lib.ColaProcessor, classifier_data_lib.ColaProcessor,
"mnli": "mnli":
classifier_data_lib.MnliProcessor, functools.partial(classifier_data_lib.MnliProcessor,
mnli_type=FLAGS.mnli_type),
"mrpc": "mrpc":
classifier_data_lib.MrpcProcessor, classifier_data_lib.MrpcProcessor,
"qnli": "qnli":
......
...@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense ...@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX = string.ascii_lowercase _CHR_IDX = string.ascii_lowercase
def _build_attention_equation(qkv_rank, attn_axes): def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation. """Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as: Query, key, value inputs after projection are expected to have the shape as:
...@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes): ...@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes):
<query attention dims>, num_heads, channels) <query attention dims>, num_heads, channels)
Args: Args:
qkv_rank: the rank of query, key, value tensors. rank: the rank of query, key, value tensors.
attn_axes: a list/tuple of axes, [1, rank), that will do attention. attn_axes: a list/tuple of axes, [1, rank), that will do attention.
Returns: Returns:
Einsum equations. Einsum equations.
""" """
target_notation = _CHR_IDX[:qkv_rank] target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim. # `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,))) batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = qkv_rank letter_offset = rank
source_notation = "" source_notation = ""
for i in range(qkv_rank): for i in range(rank):
if i in batch_dims or i == qkv_rank - 1: if i in batch_dims or i == rank - 1:
source_notation += target_notation[i] source_notation += target_notation[i]
else: else:
source_notation += _CHR_IDX[letter_offset] source_notation += _CHR_IDX[letter_offset]
...@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
sequence dims. If not specified, projects back to the key feature dim. sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features. attention over all axes, but batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head return_attention_scores: bool, if `True`, returns the multi-head attention
attention scores as an additional output argument. scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels. kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases. bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels. kernel_regularizer: Regularizer for dense layer kernels.
...@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity. activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels. kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels. bias_constraint: Constraint for dense layer kernels.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
""" """
def __init__(self, def __init__(self,
...@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
self._attention_axes = (attention_axes,) self._attention_axes = (attention_axes,)
else: else:
self._attention_axes = attention_axes self._attention_axes = attention_axes
self._built_from_signature = False
def get_config(self): def get_config(self):
config = { config = {
...@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
base_config = super(MultiHeadAttention, self).get_config() base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape): def _build_from_signature(self, query, value, key=None):
inputs_len = len(input_shape) """Builds layers and variables.
if inputs_len > 3 or inputs_len < 2:
raise ValueError( Once the method is called, self._built_from_signature will be set to True.
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. " Args:
"Given length: %d" % inputs_len) query: query tensor or TensorShape.
tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape) value: value tensor or TensorShape.
query_shape = tensor_shapes[0] key: key tensor or TensorShape.
value_shape = tensor_shapes[1] """
key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape self._built_from_signature = True
if hasattr(query, "shape"):
query_shape = tf.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tf.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict( common_kwargs = dict(
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
...@@ -271,84 +293,79 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -271,84 +293,79 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint) bias_constraint=self._bias_constraint)
with tf.init_scope():
free_dims = query_shape.rank - 1 free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2) free_dims, bound_dims=1, output_dims=2)
self._query_dense = EinsumDense( self._query_dense = EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]), [self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="query", name="query",
**common_kwargs) **common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2) key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = EinsumDense( self._key_dense = EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]), [self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="key", name="key",
**common_kwargs) **common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_shape.rank - 1, bound_dims=1, output_dims=2) value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = EinsumDense( self._value_dense = EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_size]), [self._num_heads, self._value_size]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="value", name="value",
**common_kwargs) **common_kwargs)
# Builds the attention computations for multi-head dot product attention. # Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once it # These computations could be wrapped into the keras attention layer once
# support mult-head einsum computations. # it support mult-head einsum computations.
self._build_attention(output_rank) self.build_attention(output_rank)
if self._output_shape: if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized): if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape] output_shape = [self._output_shape]
else:
output_shape = self._output_shape
else: else:
output_shape = self._output_shape output_shape = [query_shape[-1]]
else: einsum_equation, bias_axes, output_rank = _build_proj_equation(
output_shape = [query_shape[-1]] free_dims, bound_dims=2, output_dims=len(output_shape))
einsum_equation, bias_axes, output_rank = _build_proj_equation( self._output_dense = EinsumDense(
free_dims, bound_dims=2, output_dims=len(output_shape)) einsum_equation,
self._output_dense = EinsumDense( output_shape=_get_output_shape(output_rank - 1, output_shape),
einsum_equation, bias_axes=bias_axes if self._use_bias else None,
output_shape=_get_output_shape(output_rank - 1, output_shape), name="attention_output",
bias_axes=bias_axes if self._use_bias else None, **common_kwargs)
name="attention_output",
**common_kwargs) def build_attention(self, rank):
super(MultiHeadAttention, self).build(input_shape)
def _build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations. """Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to This function builds attributes necessary for `compute_attention` to
costomize attention computation to replace the default dot-product costomize attention computation to replace the default dot-product
attention. attention.
Args: Args:
qkv_rank: the rank of query, key, value tensors. rank: the rank of query, key, value tensors.
""" """
if self._attention_axes is None: if self._attention_axes is None:
self._attention_axes = tuple(range(1, qkv_rank - 2)) self._attention_axes = tuple(range(1, rank - 2))
else: else:
self._attention_axes = tuple(self._attention_axes) self._attention_axes = tuple(self._attention_axes)
self._dot_product_equation, self._combine_equation, attn_scores_rank = ( self._dot_product_equation, self._combine_equation, attn_scores_rank = (
_build_attention_equation(qkv_rank, attn_axes=self._attention_axes)) _build_attention_equation(rank, attn_axes=self._attention_axes))
norm_axes = tuple( norm_axes = tuple(
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank)) range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
self._masked_softmax = masked_softmax.MaskedSoftmax( self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=norm_axes) mask_expansion_axes=[1], normalization_axes=norm_axes)
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout) self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
def _compute_attention(self, def compute_attention(self, query, key, value, attention_mask=None):
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors. """Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected This function defines the computation inside `call` with projected
...@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention implementation. attention implementation.
Args: Args:
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`. query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`. key: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`. value: Projected value `Tensor` of shape `[B, T, N, value_size]`.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions. attention to certain positions.
...@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention_output: Multi-headed outputs of attention computation. attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights. attention_scores: Multi-headed attention weights.
""" """
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor, attention_scores = tf.einsum(self._dot_product_equation, key, query)
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, T, S] # `attention_scores` = [B, N, T, S]
...@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
# `context_layer` = [B, T, N, H] # `context_layer` = [B, T, N, H]
attention_output = tf.einsum(self._combine_equation, attention_output = tf.einsum(self._combine_equation,
attention_scores_dropout, value_tensor) attention_scores_dropout, value)
return attention_output, attention_scores return attention_output, attention_scores
def call(self, inputs, attention_mask=None): def call(self, query, value, key=None, attention_mask=None):
"""Implements the forward pass. """Implements the forward pass.
Size glossary: Size glossary:
...@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
* Value (source) attention axes shape (S), the rank must match the target. * Value (source) attention axes shape (S), the rank must match the target.
Args: Args:
inputs: List of the following tensors: query: Query `Tensor` of shape `[B, T, dim]`.
* query: Query `Tensor` of shape `[B, T, dim]`. value: Value `Tensor` of shape `[B, S, dim]`.
* value: Value `Tensor` of shape `[B, S, dim]`. key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
* key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will `value` for both `key` and `value`, which is the most common case.
use `value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions. attention to certain positions.
...@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention attention
axes. axes.
""" """
inputs_len = len(inputs) if not self._built_from_signature:
if inputs_len > 3 or inputs_len < 2: self._build_from_signature(query=query, value=value, key=key)
raise ValueError( if key is None:
"Expects inputs list of length 2 or 3, namely [query, value] or " key = value
"[query, value, key]. "
"Given length: %d" % inputs_len)
query = inputs[0]
value = inputs[1]
key = inputs[2] if inputs_len == 3 else value
# N = `num_attention_heads` # N = `num_attention_heads`
# H = `size_per_head` # H = `size_per_head`
# `query_tensor` = [B, T, N ,H] # `query` = [B, T, N ,H]
query_tensor = self._query_dense(query) query = self._query_dense(query)
# `key_tensor` = [B, S, N, H] # `key` = [B, S, N, H]
key_tensor = self._key_dense(key) key = self._key_dense(key)
# `value_tensor` = [B, S, N, H] # `value` = [B, S, N, H]
value_tensor = self._value_dense(value) value = self._value_dense(value)
attention_output, attention_scores = self._compute_attention( attention_output, attention_scores = self.compute_attention(
query_tensor, key_tensor, value_tensor, attention_mask) query, key, value, attention_mask)
attention_output = self._output_dense(attention_output) attention_output = self._output_dense(attention_output)
if self._return_attention_scores: if self._return_attention_scores:
...@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention): ...@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention):
Arguments are the same as `MultiHeadAttention` layer. Arguments are the same as `MultiHeadAttention` layer.
""" """
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step): def _update_cache(self, key, value, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors.""" """Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values. # Combines cached keys and values with new keys and values.
if decode_loop_step is not None: if decode_loop_step is not None:
# TPU special case. # TPU special case.
key_seq_dim = cache["key"].shape.as_list()[1] key_seq_dim = cache["key"].shape.as_list()[1]
indices = tf.reshape( indices = tf.reshape(
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype), tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
[1, key_seq_dim, 1, 1]) [1, key_seq_dim, 1, 1])
key_tensor = cache["key"] + key_tensor * indices key = cache["key"] + key * indices
value_seq_dim = cache["value"].shape.as_list()[1] value_seq_dim = cache["value"].shape.as_list()[1]
indices = tf.reshape( indices = tf.reshape(
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype), tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
[1, value_seq_dim, 1, 1]) [1, value_seq_dim, 1, 1])
value_tensor = cache["value"] + value_tensor * indices value = cache["value"] + value * indices
else: else:
key_tensor = tf.concat( key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
[tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1) value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)
value_tensor = tf.concat(
[tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
# Update cache # Update cache
cache["key"] = key_tensor cache["key"] = key
cache["value"] = value_tensor cache["value"] = value
return key_tensor, value_tensor return key, value
def call(self, def call(self,
inputs, query,
value,
key=None,
attention_mask=None, attention_mask=None,
cache=None, cache=None,
decode_loop_step=None): decode_loop_step=None):
from_tensor = inputs[0] if not self._built_from_signature:
to_tensor = inputs[1] self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here: # Scalar dimensions referenced here:
# B = batch size (number of sequences) # B = batch size (number of sequences)
...@@ -494,23 +509,21 @@ class CachedAttention(MultiHeadAttention): ...@@ -494,23 +509,21 @@ class CachedAttention(MultiHeadAttention):
# T = `to_tensor` sequence length # T = `to_tensor` sequence length
# N = `num_attention_heads` # N = `num_attention_heads`
# H = `size_per_head` # H = `size_per_head`
# `query_tensor` = [B, F, N ,H] # `query` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor) query = self._query_dense(query)
# `key_tensor` = [B, T, N, H] # `key` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor) key = self._key_dense(key)
# `value_tensor` = [B, T, N, H] # `value` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor) value = self._value_dense(value)
if cache: if cache:
key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor, key, value = self._update_cache(key, value, cache, decode_loop_step)
cache, decode_loop_step)
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor, attention_scores = tf.einsum(self._dot_product_equation, key, query)
query_tensor)
attention_scores = tf.multiply(attention_scores, attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size))) 1.0 / math.sqrt(float(self._key_size)))
...@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention): ...@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention):
attention_scores = self._dropout_layer(attention_scores) attention_scores = self._dropout_layer(attention_scores)
# `context_layer` = [B, F, N, H] # `context_layer` = [B, F, N, H]
attention_output = tf.einsum(self._combine_equation, attention_scores, attention_output = tf.einsum(self._combine_equation, attention_scores,
value_tensor) value)
attention_output = self._output_dense(attention_output) attention_output = self._output_dense(attention_output)
if self._return_attention_scores: if self._return_attention_scores:
return attention_output, attention_scores, cache return attention_output, attention_scores, cache
......
...@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80)) value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value]) output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims) self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self): def test_non_masked_self_attention(self):
...@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64) test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self): def test_attention_scores(self):
...@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True) num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query]) output, coef = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40]) self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
...@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8)) query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8)) value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2)) mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor) output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer. # Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output) model = tf.keras.Model([query, value, mask_tensor], output)
...@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V. # Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8)) key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor) output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output) model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data]) masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
...@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)) kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters( @parameterized.named_parameters(
...@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least # Invoke the data with a random set of mask data. This should mask at least
# one element. # one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool") mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data) output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked). # Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape) null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data) unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the # Because one data is masked and one is not, the outputs should not be the
# same. # same.
self.assertNotAllClose(output, unmasked_output) self.assertNotAllClose(output, unmasked_output)
...@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase): ...@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
key_size=64) key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
...@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase): ...@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# one element. # one element.
mask_data = np.random.randint( mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length)) 2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data], mask_data, cache) masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
self.assertEqual(masked_output_data.shape, (3, 4, 8)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2)) self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
# Tests inputs without cache. # Tests inputs without cache.
masked_output_data, cache = layer([from_data, from_data, mask_data]) masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data)
self.assertEqual(masked_output_data.shape, (3, 4, 8)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertIsNone(cache) self.assertIsNone(cache)
...@@ -243,10 +246,12 @@ class CachedAttentionTest(keras_parameterized.TestCase): ...@@ -243,10 +246,12 @@ class CachedAttentionTest(keras_parameterized.TestCase):
mask_data = np.random.randint( mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32) 2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
# Testing the invocation directly as Keras cannot consume inputs correctly. # Testing the invocation directly as Keras cannot consume inputs correctly.
masked_output_data, cache = layer([from_data, from_data], masked_output_data, cache = layer(
mask_data, query=from_data,
cache, value=from_data,
decode_loop_step=decode_loop_step) attention_mask=mask_data,
cache=cache,
decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 8)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2)) self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
......
...@@ -110,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer): ...@@ -110,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer):
class MultiChannelAttention(attention.MultiHeadAttention): class MultiChannelAttention(attention.MultiHeadAttention):
"""Multi-channel Attention layer. """Multi-channel Attention layer.
Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple Introduced in, [Generating Representative Headlines for News Stories
cross-attention target sequences. ](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
target sequences.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
is the number of attention heads. Combines multi-channel sources
context tensors according to the distribution among channels.
key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
""" """
def _build_attention(self, qkv_rank): def build_attention(self, rank):
super(MultiChannelAttention, self)._build_attention(qkv_rank) super(MultiChannelAttention, self).build_attention(rank)
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2]) self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self, inputs, attention_mask=None): def call(self,
from_tensor = inputs[0] query,
to_tensor = inputs[1] value,
doc_attention_probs = inputs[2] key=None,
context_attention_weights=None,
attention_mask=None):
if not self._built_from_signature:
self._build_from_signature(query, value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here: # Scalar dimensions referenced here:
# B = batch size (number of stories) # B = batch size (number of stories)
# A = num_docs (number of docs) # A = num_docs (number of docs)
# F = `from_tensor` sequence length # F = target sequence length
# T = `to_tensor` sequence length # T = source sequence length
# N = `num_attention_heads` # N = `num_attention_heads`
# H = `size_per_head` # H = `size_per_head`
# `query_tensor` = [B, F, N ,H] # `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor) query_tensor = self._query_dense(query)
# `key_tensor` = [B, A, T, N, H] # `key_tensor` = [B, A, T, N, H]
key_tensor = self._key_dense(to_tensor) key_tensor = self._key_dense(key)
# `value_tensor` = [B, A, T, N, H] # `value_tensor` = [B, A, T, N, H]
value_tensor = self._value_dense(to_tensor) value_tensor = self._value_dense(value)
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
...@@ -156,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention): ...@@ -156,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention):
# `context_layer` = [B, F, N, H] # `context_layer` = [B, F, N, H]
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs, context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
value_tensor) value_tensor)
attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs, attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
context_layer) context_layer)
attention_output = self._output_dense(attention_output) attention_output = self._output_dense(attention_output)
return attention_output return attention_output
...@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase): ...@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
mask_data = np.random.randint(2, size=(3, num_docs, 4, 2)) mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
doc_probs = np.random.randint( doc_probs = np.random.randint(
2, size=(3, num_heads, 4, num_docs)).astype(float) 2, size=(3, num_heads, 4, num_docs)).astype(float)
outputs = attention_layer([from_data, to_data, doc_probs], mask_data) outputs = attention_layer(
query=from_data,
value=to_data,
context_attention_weights=doc_probs,
attention_mask=mask_data)
self.assertEqual(outputs.shape, (3, 4, 8)) self.assertEqual(outputs.shape, (3, 4, 8))
......
...@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer): ...@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
"hidden_size": self._hidden_size, "hidden_size": self._hidden_size,
"min_timescale": self._min_timescale, "min_timescale": self._min_timescale,
"max_timescale": self._max_timescale, "max_timescale": self._max_timescale,
"length": self._length,
} }
base_config = super(RelativePositionEmbedding, self).get_config() base_config = super(RelativePositionEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
......
...@@ -213,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -213,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_mask = attention_mask[:, 0:self._output_range, :] attention_mask = attention_mask[:, 0:self._output_range, :]
else: else:
target_tensor = input_tensor target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask) attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = target_tensor + self._rezero_a * attention_output attention_output = target_tensor + self._rezero_a * attention_output
if self._use_layer_norm: if self._use_layer_norm:
......
...@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention): ...@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
bias_constraint: Constraint for dense layer kernels. bias_constraint: Constraint for dense layer kernels.
""" """
def _build_attention(self, qkv_rank): def build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations. """Builds multi-head dot-product attention computations.
This function overrides base class to create additional linear projection This function overrides base class to create additional linear projection
...@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention): ...@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
Args: Args:
qkv_rank: the rank of query, key, value tensors after projection. qkv_rank: the rank of query, key, value tensors after projection.
""" """
super(TalkingHeadsAttention, self)._build_attention(qkv_rank) super(TalkingHeadsAttention, self).build_attention(qkv_rank)
# Build an equation: # Build an equation:
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) -> # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
...@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention): ...@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
dtype=self.dtype, dtype=self.dtype,
trainable=True) trainable=True)
def _compute_attention(self, def compute_attention(self,
query_tensor, query_tensor,
key_tensor, key_tensor,
value_tensor, value_tensor,
attention_mask=None): attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors. """Applies Dot-product attention with query, key, value tensors.
This function overrides base class to apply additional linear projection This function overrides base class to apply additional linear projection
......
...@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80)) value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value]) output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims) self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self): def test_non_masked_self_attention(self):
...@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64) num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self): def test_attention_scores(self):
...@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True) num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query]) output, coef = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40]) self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
...@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8)) query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8)) value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2)) mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor) output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer. # Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output) model = tf.keras.Model([query, value, mask_tensor], output)
...@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V. # Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8)) key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor) output = test_layer(
query=query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output) model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data]) masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
...@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)) kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters( @parameterized.named_parameters(
...@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least # Invoke the data with a random set of mask data. This should mask at least
# one element. # one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool") mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data) output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked). # Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape) null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data) unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the # Because one data is masked and one is not, the outputs should not be the
# same. # same.
self.assertNotAllClose(output, unmasked_output) self.assertNotAllClose(output, unmasked_output)
......
...@@ -120,7 +120,9 @@ class Transformer(tf.keras.layers.Layer): ...@@ -120,7 +120,9 @@ class Transformer(tf.keras.layers.Layer):
name="self_attention", name="self_attention",
**common_kwargs) **common_kwargs)
# pylint: disable=protected-access # pylint: disable=protected-access
self._attention_layer.build([input_tensor_shape] * 3) # Temporarily handling for checkpoint compatible changes.
self._attention_layer._build_from_signature(
query=input_tensor_shape, value=input_tensor_shape)
self._attention_output_dense = self._attention_layer._output_dense self._attention_output_dense = self._attention_layer._output_dense
# pylint: enable=protected-access # pylint: enable=protected-access
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
...@@ -202,9 +204,9 @@ class Transformer(tf.keras.layers.Layer): ...@@ -202,9 +204,9 @@ class Transformer(tf.keras.layers.Layer):
attention_mask = attention_mask[:, 0:self._output_range, :] attention_mask = attention_mask[:, 0:self._output_range, :]
else: else:
target_tensor = input_tensor target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask) attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(target_tensor + attention_output = self._attention_layer_norm(target_tensor +
attention_output) attention_output)
...@@ -382,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -382,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
"TransformerDecoderLayer must have 4 inputs, but it got: %d" % "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
len(inputs)) len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4] input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
self_attention_inputs = [input_tensor, input_tensor]
self_attention_output, cache = self.self_attention( self_attention_output, cache = self.self_attention(
self_attention_inputs, query=input_tensor,
value=input_tensor,
attention_mask=self_attention_mask, attention_mask=self_attention_mask,
cache=cache, cache=cache,
decode_loop_step=decode_loop_step) decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output) self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm( self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output) input_tensor + self_attention_output)
cross_attn_inputs = dict(
cross_attn_inputs = [self_attention_output, memory] query=self_attention_output,
value=memory,
attention_mask=attention_mask)
if self.multi_channel_cross_attention: if self.multi_channel_cross_attention:
# Accesses the 5-th input tensor for the doc-attention probabilities. # Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs.append(inputs[-1]) cross_attn_inputs["context_attention_weights"] = inputs[-1]
attention_output = self.encdec_attention(cross_attn_inputs, attention_mask) attention_output = self.encdec_attention(**cross_attn_inputs)
attention_output = self.encdec_attention_dropout(attention_output) attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output + attention_output = self.encdec_attention_layer_norm(self_attention_output +
attention_output) attention_output)
......
...@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
else: else:
input_tensor, attention_mask = (inputs, None) input_tensor, attention_mask = (inputs, None)
attention_inputs = [input_tensor, input_tensor] attention_output = self._attention_layer(
query=input_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor + attention_output = self._attention_layer_norm(input_tensor +
attention_output) attention_output)
......
...@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention): ...@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
super(ValidatedAttentionLayer, self).__init__(**kwargs) super(ValidatedAttentionLayer, self).__init__(**kwargs)
self.list = call_list self.list = call_list
def call(self, inputs, attention_mask=None): def call(self, query, value, attention_mask=None):
self.list.append(True) self.list.append(True)
return super(ValidatedAttentionLayer, self).call( return super(ValidatedAttentionLayer, self).call(
inputs, attention_mask=attention_mask) query, value, attention_mask=attention_mask)
def get_config(self): def get_config(self):
config = super(ValidatedAttentionLayer, self).get_config() config = super(ValidatedAttentionLayer, self).get_config()
......
...@@ -152,7 +152,10 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -152,7 +152,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
_ = new_layer([input_data, mask_data]) _ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights()) new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data]) new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :]) self.assertAllClose(new_output_tensor,
output_tensor[:, 0:1, :],
atol=5e-5,
rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls): def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16') tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
......
...@@ -323,6 +323,28 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase): ...@@ -323,6 +323,28 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
self.assertAllEqual(network.get_config(), new_network.get_config()) self.assertAllEqual(network.get_config(), new_network.get_config())
class Embeddings(tf.keras.Model):
  """Minimal embedding network: word embeddings plus a self-attention mask.

  Takes two inputs (word ids and a padding mask) and returns the word
  embeddings together with the 3D attention mask derived from the padding
  mask. No positional or token-type embeddings are applied.
  """

  def __init__(self, vocab_size, hidden_size):
    super().__init__()
    # Maps token ids to dense vectors of width `hidden_size`.
    self.embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=hidden_size,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        name="word_embeddings")
    # Expands the 2D padding mask into a 3D self-attention mask.
    self.attention_mask = layers.SelfAttentionMask()
    # Declared Keras inputs: token ids and the padding mask.
    self.inputs = [
        tf.keras.layers.Input(
            shape=(None,), dtype=tf.int32, name="input_word_ids"),
        tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name="input_mask")
    ]

  def call(self, inputs):
    """Returns `(word_embeddings, attention_mask)` for `(word_ids, mask)`."""
    ids, padding_mask = inputs
    embeddings = self.embedding_layer(ids)
    mask_3d = self.attention_mask([embeddings, padding_mask])
    return embeddings, mask_3d
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase): class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
...@@ -334,20 +356,7 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase): ...@@ -334,20 +356,7 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
# Build an embedding network to swap in for the default network. This one # Build an embedding network to swap in for the default network. This one
# will have 2 inputs (mask and word_ids) instead of 3, and won't use # will have 2 inputs (mask and word_ids) instead of 3, and won't use
# positional embeddings. # positional embeddings.
network = Embeddings(vocab_size, hidden_size)
word_ids = tf.keras.layers.Input(
shape=(sequence_length,), dtype=tf.int32, name="input_word_ids")
mask = tf.keras.layers.Input(
shape=(sequence_length,), dtype=tf.int32, name="input_mask")
embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=hidden_size,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
name="word_embeddings")
word_embeddings = embedding_layer(word_ids)
attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
network = tf.keras.Model([word_ids, mask],
[word_embeddings, attention_mask])
hidden_cfg = { hidden_cfg = {
"num_attention_heads": "num_attention_heads":
...@@ -371,8 +380,7 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase): ...@@ -371,8 +380,7 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cfg=hidden_cfg, hidden_cfg=hidden_cfg,
embedding_cls=network, embedding_cls=network)
embedding_data=embedding_layer.embeddings)
# Create the inputs (note that the first dimension is implicit). # Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...@@ -390,11 +398,6 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase): ...@@ -390,11 +398,6 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
mask_data = np.random.randint(2, size=(batch_size, sequence_length)) mask_data = np.random.randint(2, size=(batch_size, sequence_length))
_ = model.predict([word_id_data, mask_data]) _ = model.predict([word_id_data, mask_data])
# Test that we can get the embedding data that we passed to the object. This
# is necessary to support standard language model training.
self.assertIs(embedding_layer.embeddings,
test_network.get_embedding_table())
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
hidden_size = 32 hidden_size = 32
sequence_length = 21 sequence_length = 21
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment