ModelZoo / ResNet50_tensorflow

Commit afd5579f, authored Jul 22, 2020 by Kaushik Shivakumar

    Merge remote-tracking branch 'upstream/master' into context_tf2

Parents: dcd96e02, 567bd18d
Changes: 89 files in total; showing 20 changed files with 537 additions and 424 deletions.
official/benchmark/retinanet_benchmark.py (+38, -0)
official/benchmark/transformer_benchmark.py (+72, -94)
official/colab/fine_tuning_bert.ipynb (+89, -62)
official/nlp/bert/model_saving_utils.py (+0, -4)
official/nlp/data/classifier_data_lib.py (+52, -35)
official/nlp/data/create_finetuning_data.py (+14, -8)
official/nlp/modeling/layers/attention.py (+150, -137)
official/nlp/modeling/layers/attention_test.py (+20, -15)
official/nlp/modeling/layers/multi_channel_attention.py (+32, -14)
official/nlp/modeling/layers/multi_channel_attention_test.py (+5, -1)
official/nlp/modeling/layers/position_embedding.py (+0, -1)
official/nlp/modeling/layers/rezero_transformer.py (+2, -2)
official/nlp/modeling/layers/talking_heads_attention.py (+7, -7)
official/nlp/modeling/layers/talking_heads_attention_test.py (+10, -8)
official/nlp/modeling/layers/transformer.py (+13, -9)
official/nlp/modeling/layers/transformer_scaffold.py (+2, -3)
official/nlp/modeling/layers/transformer_scaffold_test.py (+2, -2)
official/nlp/modeling/layers/transformer_test.py (+4, -1)
official/nlp/modeling/networks/encoder_scaffold_test.py (+24, -21)
official/nlp/modeling/ops/__init__.py (+1, -0)
official/benchmark/retinanet_benchmark.py
...
@@ -271,6 +271,44 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
    FLAGS.strategy_type = 'tpu'
    self._run_and_report_benchmark(params, do_eval=False, warmup=0)

+  @flagsaver.flagsaver
+  def benchmark_4x4_tpu_coco(self):
+    """Run RetinaNet model accuracy test with 4 TPUs."""
+    self._setup()
+    params = self._params()
+    params['train']['batch_size'] = 256
+    params['train']['total_steps'] = 469  # One epoch.
+    params['train']['iterations_per_loop'] = 500
+    FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco')
+    FLAGS.strategy_type = 'tpu'
+    self._run_and_report_benchmark(params, do_eval=False, warmup=0)
+
+  @flagsaver.flagsaver
+  def benchmark_2x2_tpu_coco_mlir(self):
+    """Run RetinaNet model accuracy test with 4 TPUs."""
+    self._setup()
+    params = self._params()
+    params['train']['batch_size'] = 64
+    params['train']['total_steps'] = 1875  # One epoch.
+    params['train']['iterations_per_loop'] = 500
+    FLAGS.model_dir = self._get_model_dir('real_benchmark_2x2_tpu_coco_mlir')
+    FLAGS.strategy_type = 'tpu'
+    tf.config.experimental.enable_mlir_bridge()
+    self._run_and_report_benchmark(params, do_eval=False, warmup=0)
+
+  @flagsaver.flagsaver
+  def benchmark_4x4_tpu_coco_mlir(self):
+    """Run RetinaNet model accuracy test with 4 TPUs."""
+    self._setup()
+    params = self._params()
+    params['train']['batch_size'] = 256
+    params['train']['total_steps'] = 469  # One epoch.
+    params['train']['iterations_per_loop'] = 500
+    FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco_mlir')
+    FLAGS.strategy_type = 'tpu'
+    tf.config.experimental.enable_mlir_bridge()
+    self._run_and_report_benchmark(params, do_eval=False, warmup=0)
+
   @flagsaver.flagsaver
   def benchmark_2x2_tpu_spinenet_coco(self):
     """Run SpineNet with RetinaNet model accuracy test with 4 TPUs."""
...
official/benchmark/transformer_benchmark.py
...
@@ -29,6 +29,8 @@ from official.nlp.transformer import misc
 from official.nlp.transformer import transformer_main as transformer_main
 from official.utils.flags import core as flags_core

+TPU_DATA_DIR = 'gs://mlcompass-data/transformer'
+GPU_DATA_DIR = os.getenv('TMPDIR')
 TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
 EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'

 FLAGS = flags.FLAGS
...
@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark):
  Code under test for the Transformer Keras models report the same data and
  require the same FLAG setup.
  """

  def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
               flag_methods=None, tpu=None):
+    self._set_data_files(root_data_dir=root_data_dir)
    if default_flags is None:
      default_flags = {}
    default_flags['data_dir'] = self.train_data_dir
    default_flags['vocab_file'] = self.vocab_file

    super(TransformerBenchmark, self).__init__(
        output_dir=output_dir,
        default_flags=default_flags,
        flag_methods=flag_methods,
        tpu=tpu)

+  def _set_data_files(self, root_data_dir=None, tpu_run=False):
+    """Sets train_data_dir, vocab_file, bleu_source and bleu_ref."""
+    # Use remote storage for TPU, remote storage for GPU if defined, else
+    # use environment provided root_data_dir.
+    if tpu_run:
+      root_data_dir = TPU_DATA_DIR
+    elif GPU_DATA_DIR is not None:
+      root_data_dir = GPU_DATA_DIR
    root_data_dir = root_data_dir if root_data_dir else ''
    self.train_data_dir = os.path.join(root_data_dir,
                                       TRANSFORMER_EN2DE_DATA_DIR_NAME)
    self.vocab_file = os.path.join(root_data_dir,
                                   TRANSFORMER_EN2DE_DATA_DIR_NAME,
                                   'vocab.ende.32768')
    self.bleu_source = os.path.join(root_data_dir,
                                    EN2DE_2014_BLEU_DATA_DIR_NAME,
                                    'newstest2014.en')
    self.bleu_ref = os.path.join(root_data_dir,
                                 EN2DE_2014_BLEU_DATA_DIR_NAME,
                                 'newstest2014.de')

+  def _set_data_file_flags(self):
+    """Sets the FLAGS for the data files."""
+    FLAGS.data_dir = self.train_data_dir
+    FLAGS.vocab_file = self.vocab_file
+    # Sets values directly to avoid validation check.
+    FLAGS['bleu_source'].value = self.bleu_source
+    FLAGS['bleu_ref'].value = self.bleu_ref

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
...
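For readers skimming the refactor: `_set_data_files` resolves the data root with a fixed priority (the remote TPU bucket, then the `TMPDIR`-backed GPU directory, then the caller-provided `root_data_dir`). A minimal standalone sketch of just that priority logic, reusing the constants from the hunk above; `resolve_root` is a hypothetical name for illustration, not part of the commit:

    import os

    TPU_DATA_DIR = 'gs://mlcompass-data/transformer'
    GPU_DATA_DIR = os.getenv('TMPDIR')

    def resolve_root(root_data_dir=None, tpu_run=False):
      # TPU runs always read from remote GCS storage.
      if tpu_run:
        return TPU_DATA_DIR
      # GPU runs prefer the TMPDIR-backed directory when the variable is set.
      if GPU_DATA_DIR is not None:
        return GPU_DATA_DIR
      # Otherwise fall back to the caller-provided root (or '').
      return root_data_dir if root_data_dir else ''

    print(resolve_root(tpu_run=True))  # gs://mlcompass-data/transformer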
@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
      not converge to the 27.3 BLEU (uncased) SOTA.
    """
    self._setup()
+    self._set_data_file_flags()
    FLAGS.num_gpus = 1
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'base'
    FLAGS.batch_size = 2048
    FLAGS.train_steps = 1000
...
@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
      not converge to the 27.3 BLEU (uncased) SOTA.
    """
    self._setup()
+    self._set_data_file_flags()
    FLAGS.num_gpus = 1
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'base'
    FLAGS.batch_size = 4096
    FLAGS.train_steps = 100000
...
@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
      Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
    """
    self._setup()
+    self._set_data_file_flags()
    FLAGS.num_gpus = 8
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'base'
    FLAGS.batch_size = 4096 * 8
    FLAGS.train_steps = 100000
...
@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
      Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
    """
    self._setup()
+    self._set_data_file_flags()
    FLAGS.num_gpus = 8
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'base'
    FLAGS.batch_size = 4096 * 8
    FLAGS.train_steps = 100000
...
@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
      Iterations are not epochs, an iteration is a number of steps between evals.
    """
    self._setup()
+    self._set_data_file_flags()
    FLAGS.num_gpus = 8
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'big'
    FLAGS.batch_size = 3072 * 8
    FLAGS.train_steps = 20000 * 12
...
@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
      Should converge to 28.4 BLEU (uncased). This has not be verified yet."
    """
    self._setup()
+    self._set_data_file_flags()
    FLAGS.num_gpus = 8
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'big'
    FLAGS.batch_size = 3072 * 8
    FLAGS.static_batch = True
...
@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
      not epochs, an iteration is a number of steps between evals.
    """
    self._setup()
+    self._set_data_file_flags()
    FLAGS.num_gpus = 8
    FLAGS.dtype = 'fp16'
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'big'
    FLAGS.batch_size = 3072 * 8
    FLAGS.train_steps = 20000 * 12
...
@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
      Should converge to 28.4 BLEU (uncased). This has not be verified yet."
    """
    self._setup()
+    self._set_data_file_flags()
    FLAGS.num_gpus = 8
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'big'
    FLAGS.batch_size = 3072 * 8
    FLAGS.train_steps = 20000 * 12
...
@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
      Should converge to 28.4 BLEU (uncased). This has not be verified yet."
    """
    self._setup()
+    self._set_data_file_flags()
    FLAGS.num_gpus = 8
    FLAGS.dtype = 'fp16'
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'big'
    FLAGS.batch_size = 3072 * 8
    FLAGS.static_batch = True
...
@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
      Should converge to 28.4 BLEU (uncased). This has not be verified yet."
    """
    self._setup()
+    self._set_data_file_flags()
    FLAGS.num_gpus = 8
    FLAGS.dtype = 'fp16'
    FLAGS.enable_xla = True
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'big'
    FLAGS.batch_size = 3072 * 8
    FLAGS.static_batch = True
...
@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
        root_data_dir=root_data_dir,
        batch_per_gpu=3072,
        tpu=tpu)

+  def _set_df_common(self):
+    self._set_data_files(tpu_run=True)
+    FLAGS.data_dir = self.train_data_dir
+    FLAGS.vocab_file = self.vocab_file
+    FLAGS.train_steps = 300
+    FLAGS.log_steps = 150
+    FLAGS.steps_between_evals = 150
+    FLAGS.distribution_strategy = 'tpu'
+    FLAGS.static_batch = True
+    FLAGS.use_ctl = True
+    FLAGS.max_length = 64
+    FLAGS.decode_batch_size = 32
+    FLAGS.decode_max_length = 97
+    FLAGS.padded_decode = True
+    FLAGS.enable_checkpointing = False
+
  def benchmark_2x2_tpu(self):
    """Port of former snaggletooth transformer_big model on 2x2."""
    self._setup()
+    self._set_df_common()
    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
-    FLAGS.train_steps = 300
-    FLAGS.log_steps = 150
-    FLAGS.steps_between_evals = 150
-    FLAGS.distribution_strategy = 'tpu'
-    FLAGS.static_batch = True
-    FLAGS.use_ctl = True
    FLAGS.batch_size = 6144
-    FLAGS.max_length = 64
-    FLAGS.decode_batch_size = 32
-    FLAGS.decode_max_length = 97
-    FLAGS.padded_decode = True
-    FLAGS.enable_checkpointing = False
    self._run_and_report_benchmark(
        total_batch_size=FLAGS.batch_size,
        log_steps=FLAGS.log_steps)

+  @owner_utils.Owner('tf-graph-compiler')
+  def benchmark_2x2_tpu_mlir(self):
+    """Run transformer_big model on 2x2 with the MLIR Bridge enabled."""
+    self._setup()
+    self._set_df_common()
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mlir')
+    FLAGS.batch_size = 6144
+    tf.config.experimental.enable_mlir_bridge()
+    self._run_and_report_benchmark(
+        total_batch_size=FLAGS.batch_size,
...
@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
  def benchmark_4x4_tpu(self):
    """Port of former GCP transformer_big model on 4x4."""
    self._setup()
+    self._set_df_common()
    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
-    FLAGS.train_steps = 300
-    FLAGS.log_steps = 150
-    FLAGS.steps_between_evals = 150
-    FLAGS.distribution_strategy = 'tpu'
-    FLAGS.static_batch = True
-    FLAGS.use_ctl = True
    FLAGS.batch_size = 24576
-    FLAGS.max_length = 64
-    FLAGS.decode_batch_size = 32
-    FLAGS.decode_max_length = 97
-    FLAGS.padded_decode = True
-    FLAGS.enable_checkpointing = False
    self._run_and_report_benchmark(
        total_batch_size=FLAGS.batch_size,
...
@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
  def benchmark_4x4_tpu_mlir(self):
    """Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
    self._setup()
-    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
-    FLAGS.train_steps = 300
-    FLAGS.log_steps = 150
-    FLAGS.steps_between_evals = 150
-    FLAGS.distribution_strategy = 'tpu'
-    FLAGS.static_batch = True
-    FLAGS.use_ctl = True
+    self._set_df_common()
+    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_mlir')
    FLAGS.batch_size = 24576
-    FLAGS.max_length = 64
-    FLAGS.decode_batch_size = 32
-    FLAGS.decode_max_length = 97
-    FLAGS.padded_decode = True
-    FLAGS.enable_checkpointing = False
    tf.config.experimental.enable_mlir_bridge()
    self._run_and_report_benchmark(
...
official/colab/fine_tuning_bert.ipynb
...
@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
    "cellView": "form",
    "colab": {},
...
(The remaining hunks of this form make the same one-line change, "execution_count": 0 → "execution_count": null, in every other code cell of the notebook.)
...
@@ -204,12 +204,12 @@
    "id": "9uFskufsR2LT"
   },
   "source": [
-    "You can get a pre-trained BERT encoder from TensorFlow Hub here:"
+    "You can get a pre-trained BERT encoder from [TensorFlow Hub](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2):"
   ]
  },
...
@@ -1406,17 +1406,44 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab": {},
     "colab_type": "code",
-    "id": "lo6479At4sP1"
+    "id": "GDWrHm0BGpbX"
    },
    "outputs": [],
    "source": [
-    "# Note: 350MB download.\n",
-    "import tensorflow_hub as hub\n",
-    "hub_encoder = hub.KerasLayer(hub_url_bert, trainable=True)\n",
+    "import tensorflow_hub as hub"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "cellView": "form",
+    "colab": {},
+    "colab_type": "code",
+    "id": "Y29meH0qGq_5"
+   },
+   "outputs": [],
+   "source": [
+    "hub_model_name = \"bert_en_uncased_L-12_H-768_A-12\" #@param [\"bert_en_uncased_L-24_H-1024_A-16\", \"bert_en_wwm_cased_L-24_H-1024_A-16\", \"bert_en_uncased_L-12_H-768_A-12\", \"bert_en_wwm_uncased_L-24_H-1024_A-16\", \"bert_en_cased_L-24_H-1024_A-16\", \"bert_en_cased_L-12_H-768_A-12\", \"bert_zh_L-12_H-768_A-12\", \"bert_multi_cased_L-12_H-768_A-12\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "lo6479At4sP1"
+   },
+   "outputs": [],
+   "source": [
+    "hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}\",\n",
+    "                             trainable=True)\n",
+    "\n",
+    "print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
+   ]
...
official/nlp/bert/model_saving_utils.py
...
@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text,
    raise ValueError('model must be a tf.keras.Model object.')

  if checkpoint_dir:
    # Keras compile/fit() was used to save checkpoint using
    # model.save_weights().
    if restore_model_using_load_weights:
      model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
      assert tf.io.gfile.exists(model_weight_path)
      model.load_weights(model_weight_path)

    # tf.train.Checkpoint API was used via custom training loop logic.
    else:
      checkpoint = tf.train.Checkpoint(model=model)
...
official/nlp/data/classifier_data_lib.py
...
@@ -152,10 +152,10 @@ class ColaProcessor(DataProcessor):
    return "COLA"

  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
-    for (i, line) in enumerate(lines):
-      # Only the test set has a header
+    for i, line in enumerate(lines):
+      # Only the test set has a header.
      if set_type == "test" and i == 0:
        continue
      guid = "%s-%s" % (set_type, i)
...
@@ -173,6 +173,14 @@ class ColaProcessor(DataProcessor):
 class MnliProcessor(DataProcessor):
   """Processor for the MultiNLI data set (GLUE version)."""

+  def __init__(self,
+               mnli_type="matched",
+               process_text_fn=tokenization.convert_to_unicode):
+    super(MnliProcessor, self).__init__(process_text_fn)
+    if mnli_type not in ("matched", "mismatched"):
+      raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
+    self.mnli_type = mnli_type
+
  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
...
@@ -180,14 +188,23 @@ class MnliProcessor(DataProcessor):
  def get_dev_examples(self, data_dir):
    """See base class."""
    if self.mnli_type == "matched":
      return self._create_examples(
          self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
          "dev_matched")
    else:
      return self._create_examples(
          self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
          "dev_mismatched")

  def get_test_examples(self, data_dir):
    """See base class."""
    if self.mnli_type == "matched":
      return self._create_examples(
          self._read_tsv(os.path.join(data_dir, "test_matched.tsv")),
          "test")
    else:
      return self._create_examples(
          self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")),
          "test")

  def get_labels(self):
    """See base class."""
...
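A brief usage sketch of the new constructor argument (the data directory path below is a placeholder, not from the commit):

    from official.nlp.data import classifier_data_lib

    processor = classifier_data_lib.MnliProcessor(mnli_type="mismatched")
    dev = processor.get_dev_examples("/path/to/MNLI")    # reads dev_mismatched.tsv
    test = processor.get_test_examples("/path/to/MNLI")  # reads test_mismatched.tsv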
@@ -199,9 +216,9 @@ class MnliProcessor(DataProcessor):
    return "MNLI"

  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
...
(The remaining hunks in this file — @@ -244 MrpcProcessor, @@ -290/-307/-321 PawsxProcessor, @@ -368 QnliProcessor, @@ -415 QqpProcessor, @@ -462 RteProcessor, @@ -507 SstProcessor, @@ -558 StsBProcessor, @@ -671 TfdsProcessor, @@ -731 WnliProcessor, @@ -777/-792/-807 XnliProcessor, @@ -837/-851/-865 XtremePawsxProcessor, @@ -896/-909/-923 XtremeXnliProcessor, and @@ -1052 file_based_convert_examples_to_features — apply the same two mechanical changes: the docstring """Creates examples for the training and dev sets.""" becomes """Creates examples for the training/dev/test sets.""", and loop headers drop the redundant tuple parentheses, e.g. `for (i, line) in enumerate(lines):` becomes `for i, line in enumerate(lines):`.)
...
official/nlp/data/create_finetuning_data.py
...
@@ -59,27 +59,32 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
                  "only and for XNLI is all languages combined. Same for "
                  "PAWS-X.")

-# XNLI task specific flag.
+# MNLI task-specific flag.
+flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
+                  "The type of MNLI dataset.")
+
+# XNLI task-specific flag.
 flags.DEFINE_string(
    "xnli_language", "en",
-    "Language of training data for XNIL task. If the value is 'all', the data "
+    "Language of training data for XNLI task. If the value is 'all', the data "
    "of all languages will be used for training.")

-# PAWS-X task specific flag.
+# PAWS-X task-specific flag.
 flags.DEFINE_string(
    "pawsx_language", "en",
-    "Language of trainig data for PAWS-X task. If the value is 'all', the data "
+    "Language of training data for PAWS-X task. If the value is 'all', the data "
    "of all languages will be used for training.")

-# Retrieva task specific flags
+# Retrieval task-specific flags.
 flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
                  "The name of sentence retrieval task for scoring")

-# Tagging task specific flags
+# Tagging task-specific flags.
 flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
                  "The name of BERT tagging (token classification) task.")

-# BERT Squad task specific flags.
+# BERT Squad task-specific flags.
 flags.DEFINE_string(
    "squad_data_file", None,
    "The input data file in for generating training data for BERT squad task.")
...
@@ -179,7 +184,8 @@ def generate_classifier_dataset():
          "cola":
              classifier_data_lib.ColaProcessor,
-          "mnli":
-              classifier_data_lib.MnliProcessor,
+          "mnli":
+              functools.partial(classifier_data_lib.MnliProcessor,
+                                mnli_type=FLAGS.mnli_type),
          "mrpc":
              classifier_data_lib.MrpcProcessor,
          "qnli":
...
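To see how the flag reaches the processor, here is the dispatch pattern in isolation — a simplified sketch of the `functools.partial` wiring above, with the flag value hard-coded instead of read from FLAGS:

    import functools
    from official.nlp.data import classifier_data_lib

    mnli_type = "mismatched"  # in the script this comes from FLAGS.mnli_type
    factory = functools.partial(
        classifier_data_lib.MnliProcessor, mnli_type=mnli_type)
    processor = factory()  # same as MnliProcessor(mnli_type="mismatched")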
official/nlp/modeling/layers/attention.py
...
@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
 _CHR_IDX = string.ascii_lowercase


-def _build_attention_equation(qkv_rank, attn_axes):
+def _build_attention_equation(rank, attn_axes):
   """Builds einsum equations for the attention computation.

   Query, key, value inputs after projection are expected to have the shape as:
...
@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes):
   <query attention dims>, num_heads, channels)

   Args:
-    qkv_rank: the rank of query, key, value tensors.
+    rank: the rank of query, key, value tensors.
     attn_axes: a list/tuple of axes, [1, rank), that will do attention.

   Returns:
     Einsum equations.
   """
-  target_notation = _CHR_IDX[:qkv_rank]
+  target_notation = _CHR_IDX[:rank]
   # `batch_dims` includes the head dim.
-  batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,)))
-  letter_offset = qkv_rank
+  batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
+  letter_offset = rank
   source_notation = ""
-  for i in range(qkv_rank):
-    if i in batch_dims or i == qkv_rank - 1:
+  for i in range(rank):
+    if i in batch_dims or i == rank - 1:
       source_notation += target_notation[i]
     else:
       source_notation += _CHR_IDX[letter_offset]
...
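For orientation, a sketch of what the builder produces in the common rank-4 case (query/key/value projected to `[B, T or S, N, H]`). The concrete equation string below is inferred from the notation loop in this hunk rather than quoted from the commit:

    import tensorflow as tf

    # Assumed dims: B=2, T=3 (target), S=5 (source), N=4 heads, H=8 per head.
    query = tf.random.normal([2, 3, 4, 8])   # projected query [B, T, N, H]
    key = tf.random.normal([2, 5, 4, 8])     # projected key   [B, S, N, H]

    # With rank=4 and attn_axes=(1,), the target notation is 'abcd' and the
    # source notation built by the loop is 'aecd'; the dot-product equation
    # then takes the form below, yielding scores of shape [B, N, T, S].
    scores = tf.einsum('aecd,abcd->acbe', key, query)
    print(scores.shape)  # (2, 4, 3, 5)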
@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
      sequence dims. If not specified, projects back to the key feature dim.
    attention_axes: axes over which the attention is applied. `None` means
      attention over all axes, but batch, heads, and features.
-    return_attention_scores: bool, if `True`, returns the multi-head attention
-      scores as an additional output argument.
+    return_attention_scores: bool, if `True`, returns the multi-head
+      attention scores as an additional output argument.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
...
@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer kernels.
+
+  Call args:
+    query: Query `Tensor` of shape `[B, T, dim]`.
+    value: Value `Tensor` of shape `[B, S, dim]`.
+    key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
+      `value` for both `key` and `value`, which is the most common case.
+    attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
+      to certain positions.
  """

  def __init__(self,
...
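The new `Call args` section reflects the layer's move from a single `inputs` list to explicit keyword arguments; a minimal usage sketch, mirroring the calls in `attention_test.py` later in this diff:

    import tensorflow as tf
    from official.nlp.modeling.layers import attention

    layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
    query = tf.keras.Input(shape=(40, 80))
    value = tf.keras.Input(shape=(20, 80))

    # New style: explicit keyword arguments (key defaults to value).
    output = layer(query=query, value=value)
    print(output.shape)  # (None, 40, 80)

    # Old style, removed by this commit: layer([query, value])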
@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
      self._attention_axes = (attention_axes,)
    else:
      self._attention_axes = attention_axes
+    self._built_from_signature = False

  def get_config(self):
    config = {
...
@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    base_config = super(MultiHeadAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

-  def build(self, input_shape):
-    inputs_len = len(input_shape)
-    if inputs_len > 3 or inputs_len < 2:
-      raise ValueError(
-          "Expects inputs list of length 2 or 3, namely [query, value] or "
-          "[query, value, key]. "
-          "Given length: %d" % inputs_len)
-    tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape)
-    query_shape = tensor_shapes[0]
-    value_shape = tensor_shapes[1]
-    key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape
+  def _build_from_signature(self, query, value, key=None):
+    """Builds layers and variables.
+
+    Once the method is called, self._built_from_signature will be set to True.
+
+    Args:
+      query: query tensor or TensorShape.
+      value: value tensor or TensorShape.
+      key: key tensor or TensorShape.
+    """
+    self._built_from_signature = True
+    if hasattr(query, "shape"):
+      query_shape = tf.TensorShape(query.shape)
+    else:
+      query_shape = query
+    if hasattr(value, "shape"):
+      value_shape = tf.TensorShape(value.shape)
+    else:
+      value_shape = value
+    if key is None:
+      key_shape = value_shape
+    elif hasattr(key, "shape"):
+      key_shape = tf.TensorShape(key.shape)
+    else:
+      key_shape = key

    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
...
@@ -271,7 +293,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    with tf.init_scope():
      free_dims = query_shape.rank - 1
      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          free_dims, bound_dims=1, output_dims=2)
...
@@ -302,9 +324,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
        **common_kwargs)
    # Builds the attention computations for multi-head dot product attention.
-    # These computations could be wrapped into the keras attention layer once it
-    # support mult-head einsum computations.
-    self._build_attention(output_rank)
+    # These computations could be wrapped into the keras attention layer once
+    # it support mult-head einsum computations.
+    self.build_attention(output_rank)
    if self._output_shape:
      if not isinstance(self._output_shape, collections.abc.Sized):
        output_shape = [self._output_shape]
...
@@ -320,35 +342,30 @@ class MultiHeadAttention(tf.keras.layers.Layer):
        bias_axes=bias_axes if self._use_bias else None,
        name="attention_output",
        **common_kwargs)
-    super(MultiHeadAttention, self).build(input_shape)

-  def _build_attention(self, qkv_rank):
+  def build_attention(self, rank):
    """Builds multi-head dot-product attention computations.

-    This function builds attributes necessary for `_compute_attention` to
+    This function builds attributes necessary for `compute_attention` to
    costomize attention computation to replace the default dot-product
    attention.

    Args:
-      qkv_rank: the rank of query, key, value tensors.
+      rank: the rank of query, key, value tensors.
    """
    if self._attention_axes is None:
-      self._attention_axes = tuple(range(1, qkv_rank - 2))
+      self._attention_axes = tuple(range(1, rank - 2))
    else:
      self._attention_axes = tuple(self._attention_axes)
    self._dot_product_equation, self._combine_equation, attn_scores_rank = (
-        _build_attention_equation(qkv_rank, attn_axes=self._attention_axes))
+        _build_attention_equation(rank, attn_axes=self._attention_axes))
    norm_axes = tuple(
        range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
    self._masked_softmax = masked_softmax.MaskedSoftmax(
        mask_expansion_axes=[1], normalization_axes=norm_axes)
    self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)

-  def _compute_attention(self,
-                         query_tensor,
-                         key_tensor,
-                         value_tensor,
-                         attention_mask=None):
+  def compute_attention(self, query, key, value, attention_mask=None):
    """Applies Dot-product attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
...
@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    attention implementation.

    Args:
-      query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
-      key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
-      value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
+      query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
+      key: Projected key `Tensor` of shape `[B, T, N, key_size]`.
+      value: Projected value `Tensor` of shape `[B, T, N, value_size]`.
      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
        attention to certain positions.
...
@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer):
      attention_output: Multi-headed outputs of attention computation.
      attention_scores: Multi-headed attention weights.
    """
+    # Note: Applying scalar multiply at the smaller end of einsum improves
+    # XLA performance, but may introduce slight numeric differences in
+    # the Transformer attention head.
+    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
+
    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
-    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
-                                 query_tensor)
-    attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._key_size)))
+    attention_scores = tf.einsum(self._dot_product_equation, key, query)

    # Normalize the attention scores to probabilities.
    # `attention_scores` = [B, N, T, S]
...
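Why the reordering is safe: einsum is linear in each operand, so scaling `query` by `1/sqrt(key_size)` before the product equals scaling the scores afterwards, up to float rounding (which is what the new comment warns about). A quick numeric check under assumed rank-4 shapes:

    import math
    import tensorflow as tf

    q = tf.random.normal([2, 3, 4, 8])  # [B, T, N, H]
    k = tf.random.normal([2, 5, 4, 8])  # [B, S, N, H]
    scale = 1.0 / math.sqrt(8.0)        # 1 / sqrt(key_size)

    pre = tf.einsum('aecd,abcd->acbe', k, q * scale)   # new: scale query first
    post = scale * tf.einsum('aecd,abcd->acbe', k, q)  # old: scale the scores
    print(float(tf.reduce_max(tf.abs(pre - post))))    # ~0, up to rounding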
@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    # `context_layer` = [B, T, N, H]
    attention_output = tf.einsum(self._combine_equation,
-                                 attention_scores_dropout, value_tensor)
+                                 attention_scores_dropout, value)
    return attention_output, attention_scores

-  def call(self, inputs, attention_mask=None):
+  def call(self, query, value, key=None, attention_mask=None):
    """Implements the forward pass.

    Size glossary:
...
@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    * Value (source) attention axes shape (S), the rank must match the target.

    Args:
-      inputs: List of the following tensors:
-        * query: Query `Tensor` of shape `[B, T, dim]`.
-        * value: Value `Tensor` of shape `[B, S, dim]`.
-        * key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
-          use `value` for both `key` and `value`, which is the most common case.
+      query: Query `Tensor` of shape `[B, T, dim]`.
+      value: Value `Tensor` of shape `[B, S, dim]`.
+      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
+        `value` for both `key` and `value`, which is the most common case.
      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
        attention to certain positions.
...
@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer):
        attention axes.
    """
-    inputs_len = len(inputs)
-    if inputs_len > 3 or inputs_len < 2:
-      raise ValueError(
-          "Expects inputs list of length 2 or 3, namely [query, value] or "
-          "[query, value, key]. "
-          "Given length: %d" % inputs_len)
-    query = inputs[0]
-    value = inputs[1]
-    key = inputs[2] if inputs_len == 3 else value
+    if not self._built_from_signature:
+      self._build_from_signature(query=query, value=value, key=key)
+    if key is None:
+      key = value

    # N = `num_attention_heads`
    # H = `size_per_head`
-    # `query_tensor` = [B, T, N ,H]
-    query_tensor = self._query_dense(query)
+    # `query` = [B, T, N ,H]
+    query = self._query_dense(query)

-    # `key_tensor` = [B, S, N, H]
-    key_tensor = self._key_dense(key)
+    # `key` = [B, S, N, H]
+    key = self._key_dense(key)

-    # `value_tensor` = [B, S, N, H]
-    value_tensor = self._value_dense(value)
+    # `value` = [B, S, N, H]
+    value = self._value_dense(value)

-    attention_output, attention_scores = self._compute_attention(
-        query_tensor, key_tensor, value_tensor, attention_mask)
+    attention_output, attention_scores = self.compute_attention(
+        query, key, value, attention_mask)
    attention_output = self._output_dense(attention_output)

    if self._return_attention_scores:
...
@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention):
  Arguments are the same as `MultiHeadAttention` layer.
  """

-  def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
+  def _update_cache(self, key, value, cache, decode_loop_step):
    """Updates cache states and gets full-length key/value tensors."""
    # Combines cached keys and values with new keys and values.
    if decode_loop_step is not None:
      # TPU special case.
      key_seq_dim = cache["key"].shape.as_list()[1]
      indices = tf.reshape(
-          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
+          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
          [1, key_seq_dim, 1, 1])
-      key_tensor = cache["key"] + key_tensor * indices
+      key = cache["key"] + key * indices
      value_seq_dim = cache["value"].shape.as_list()[1]
      indices = tf.reshape(
-          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
+          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
          [1, value_seq_dim, 1, 1])
-      value_tensor = cache["value"] + value_tensor * indices
+      value = cache["value"] + value * indices
    else:
-      key_tensor = tf.concat(
-          [tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
-      value_tensor = tf.concat(
-          [tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
+      key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
+      value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)

    # Update cache
-    cache["key"] = key_tensor
-    cache["value"] = value_tensor
+    cache["key"] = key
+    cache["value"] = value

-    return key_tensor, value_tensor
+    return key, value

  def call(self,
-           inputs,
+           query,
+           value,
+           key=None,
           attention_mask=None,
           cache=None,
           decode_loop_step=None):
-    from_tensor = inputs[0]
-    to_tensor = inputs[1]
+    if not self._built_from_signature:
+      self._build_from_signature(query=query, value=value, key=key)
+    if key is None:
+      key = value

    # Scalar dimensions referenced here:
    # B = batch size (number of sequences)
...
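The TPU branch of `_update_cache` is effectively a scatter: a one-hot vector over the sequence axis selects position `decode_loop_step`, so `cache + new * indices` writes a single step without dynamic-shape ops. A tiny standalone demonstration (shapes assumed for illustration):

    import tensorflow as tf

    seq_len, step = 4, 2
    cache_key = tf.zeros([1, seq_len, 1, 1])    # cached keys for all steps
    new_key = tf.fill([1, seq_len, 1, 1], 7.0)  # this step's key, broadcast
    indices = tf.reshape(tf.one_hot(step, seq_len), [1, seq_len, 1, 1])

    updated = cache_key + new_key * indices     # writes only position 2
    print(tf.squeeze(updated).numpy())          # [0. 0. 7. 0.]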
@@ -494,23 +509,21 @@ class CachedAttention(MultiHeadAttention):
    # T = `to_tensor` sequence length
    # N = `num_attention_heads`
    # H = `size_per_head`
-    # `query_tensor` = [B, F, N ,H]
-    query_tensor = self._query_dense(from_tensor)
+    # `query` = [B, F, N ,H]
+    query = self._query_dense(query)

-    # `key_tensor` = [B, T, N, H]
-    key_tensor = self._key_dense(to_tensor)
+    # `key` = [B, T, N, H]
+    key = self._key_dense(key)

-    # `value_tensor` = [B, T, N, H]
-    value_tensor = self._value_dense(to_tensor)
+    # `value` = [B, T, N, H]
+    value = self._value_dense(value)

    if cache:
-      key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
-                                                    cache, decode_loop_step)
+      key, value = self._update_cache(key, value, cache, decode_loop_step)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
-    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
-                                 query_tensor)
+    attention_scores = tf.einsum(self._dot_product_equation, key, query)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))
...
@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention):
    attention_scores = self._dropout_layer(attention_scores)

    # `context_layer` = [B, F, N, H]
    attention_output = tf.einsum(self._combine_equation, attention_scores,
-                                 value_tensor)
+                                 value)
    attention_output = self._output_dense(attention_output)

    if self._return_attention_scores:
      return attention_output, attention_scores, cache
...
official/nlp/modeling/layers/attention_test.py
View file @
afd5579f
...
...
@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit).
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
value
=
tf
.
keras
.
Input
(
shape
=
(
20
,
80
))
output
=
test_layer
(
[
query
,
value
]
)
output
=
test_layer
(
query
=
query
,
value
=
value
)
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
]
+
output_dims
)
def
test_non_masked_self_attention
(
self
):
...
...
@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
test_layer
=
attention
.
MultiHeadAttention
(
num_heads
=
12
,
key_size
=
64
)
# Create a 3-dimensional input (the first dimension is implicit).
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
output
=
test_layer
(
[
query
,
query
]
)
output
=
test_layer
(
query
,
query
)
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
,
40
,
80
])
def
test_attention_scores
(
self
):
...
...
@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
num_heads
=
12
,
key_size
=
64
,
return_attention_scores
=
True
)
# Create a 3-dimensional input (the first dimension is implicit).
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
output
,
coef
=
test_layer
(
[
query
,
query
]
)
output
,
coef
=
test_layer
(
query
,
query
)
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
,
40
,
80
])
self
.
assertEqual
(
coef
.
shape
.
as_list
(),
[
None
,
12
,
40
,
40
])
...
...
@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
query
=
tf
.
keras
.
Input
(
shape
=
(
4
,
8
))
value
=
tf
.
keras
.
Input
(
shape
=
(
2
,
8
))
mask_tensor
=
tf
.
keras
.
Input
(
shape
=
(
4
,
2
))
output
=
test_layer
(
[
query
,
value
],
mask_tensor
)
output
=
test_layer
(
query
=
query
,
value
=
value
,
attention_mask
=
mask_tensor
)
# Create a model containing the test layer.
model
=
tf
.
keras
.
Model
([
query
,
value
,
mask_tensor
],
output
)
...
...
@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V.
key
=
tf
.
keras
.
Input
(
shape
=
(
2
,
8
))
output
=
test_layer
(
[
query
,
value
,
key
],
mask_tensor
)
output
=
test_layer
(
query
,
value
=
value
,
key
=
key
,
attention_mask
=
mask_tensor
)
model
=
tf
.
keras
.
Model
([
query
,
value
,
key
,
mask_tensor
],
output
)
masked_output_data
=
model
.
predict
([
from_data
,
to_data
,
to_data
,
mask_data
])
...
...
@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
kernel_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
))
# Create a 3-dimensional input (the first dimension is implicit).
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
output
=
test_layer
(
[
query
,
query
]
)
output
=
test_layer
(
query
,
query
)
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
,
40
,
80
])
@
parameterized
.
named_parameters
(
...
...
@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     # Invoke the data with a random set of mask data. This should mask at least
     # one element.
     mask_data = np.random.randint(2, size=mask_shape).astype("bool")
-    output = test_layer([query, value], mask_data)
+    output = test_layer(query=query, value=value, attention_mask=mask_data)

     # Invoke the same data, but with a null mask (where no elements are masked).
     null_mask_data = np.ones(mask_shape)
-    unmasked_output = test_layer([query, value], null_mask_data)
+    unmasked_output = test_layer(
+        query=query, value=value, attention_mask=null_mask_data)

     # Because one data is masked and one is not, the outputs should not be the
     # same.
     self.assertNotAllClose(output, unmasked_output)
...
...
@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
         key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
...
...
@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase):
     # one element.
     mask_data = np.random.randint(
         2, size=(batch_size, from_seq_length, from_seq_length))
-    masked_output_data, cache = layer([from_data, from_data], mask_data, cache)
+    masked_output_data, cache = layer(
+        query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
     self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertEqual(cache["value"].shape, (3, 4, 2, 2))

     # Tests inputs without cache.
-    masked_output_data, cache = layer([from_data, from_data, mask_data])
+    masked_output_data, cache = layer(
+        query=from_data, value=from_data, attention_mask=mask_data)
     self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertIsNone(cache)
...
...
@@ -243,9 +246,11 @@ class CachedAttentionTest(keras_parameterized.TestCase):
     mask_data = np.random.randint(
         2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
     # Testing the invocation directly as Keras cannot consume inputs correctly.
-    masked_output_data, cache = layer([from_data, from_data], mask_data, cache,
-                                      decode_loop_step=decode_loop_step)
+    masked_output_data, cache = layer(
+        query=from_data,
+        value=from_data,
+        attention_mask=mask_data,
+        cache=cache,
+        decode_loop_step=decode_loop_step)
     self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
...
...
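The cache assertions in `CachedAttentionTest` pin down the cache layout without showing its construction. A hedged sketch that infers the shapes purely from the `assertEqual` calls above (batch=3, seq=4, heads=2, head_size=2):

import numpy as np

batch_size, from_seq_length, num_heads, head_size = 3, 4, 2, 2
# One buffer per projection; the layer fills these as decoding progresses.
cache = {
    "key": np.zeros(
        (batch_size, from_seq_length, num_heads, head_size), dtype="float32"),
    "value": np.zeros(
        (batch_size, from_seq_length, num_heads, head_size), dtype="float32"),
}
# layer(query=..., value=..., attention_mask=..., cache=cache) then returns
# the attention output together with the updated cache.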
official/nlp/modeling/layers/multi_channel_attention.py  View file @ afd5579f
...
...
@@ -110,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer):
 class MultiChannelAttention(attention.MultiHeadAttention):
   """Multi-channel Attention layer.

-  Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
-  cross-attention target sequences.
+  Introduced in [Generating Representative Headlines for News Stories](
+  https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
+  target sequences.
+
+  Call args:
+    query: Query `Tensor` of shape `[B, T, dim]`.
+    value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
+      number of channels.
+    context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
+      is the number of attention heads. Combines multi-channel source
+      context tensors according to the distribution among channels.
+    key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will
+      use `value` for both `key` and `value`, which is the most common case.
+    attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention
+      to certain positions.
   """

-  def _build_attention(self, qkv_rank):
-    super(MultiChannelAttention, self)._build_attention(qkv_rank)
+  def build_attention(self, rank):
+    super(MultiChannelAttention, self).build_attention(rank)
     self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])

-  def call(self, inputs, attention_mask=None):
-    from_tensor = inputs[0]
-    to_tensor = inputs[1]
-    doc_attention_probs = inputs[2]
+  def call(self,
+           query,
+           value,
+           key=None,
+           context_attention_weights=None,
+           attention_mask=None):
+    if not self._built_from_signature:
+      self._build_from_signature(query, value, key=key)
+    if key is None:
+      key = value

     # Scalar dimensions referenced here:
     # B = batch size (number of stories)
     # A = num_docs (number of docs)
-    # F = `from_tensor` sequence length
-    # T = `to_tensor` sequence length
+    # F = target sequence length
+    # T = source sequence length
     # N = `num_attention_heads`
     # H = `size_per_head`
     # `query_tensor` = [B, F, N, H]
-    query_tensor = self._query_dense(from_tensor)
+    query_tensor = self._query_dense(query)

     # `key_tensor` = [B, A, T, N, H]
-    key_tensor = self._key_dense(to_tensor)
+    key_tensor = self._key_dense(key)

     # `value_tensor` = [B, A, T, N, H]
-    value_tensor = self._value_dense(to_tensor)
+    value_tensor = self._value_dense(value)

     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
...
...
@@ -156,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention):
     # `context_layer` = [B, F, N, H]
     context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
                               value_tensor)
-    attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
+    attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
                                  context_layer)
     attention_output = self._output_dense(attention_output)
     return attention_output
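The closing einsum is the heart of the layer: per-head weights over the A channels ("BNFA") collapse the per-channel contexts ("BAFNH") into a single context per target position ("BFNH"). A small numpy check of just that contraction, with illustrative dimension sizes:

import numpy as np

B, A, F, N, H = 2, 3, 4, 2, 5   # batch, channels, target len, heads, head size
context_attention_weights = np.random.rand(B, N, F, A)
context_layer = np.random.rand(B, A, F, N, H)

combined = np.einsum("BNFA,BAFNH->BFNH",
                     context_attention_weights, context_layer)
assert combined.shape == (B, F, N, H)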
official/nlp/modeling/layers/multi_channel_attention_test.py  View file @ afd5579f
...
...
@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
     mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
     doc_probs = np.random.randint(
         2, size=(3, num_heads, 4, num_docs)).astype(float)
-    outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
+    outputs = attention_layer(
+        query=from_data,
+        value=to_data,
+        context_attention_weights=doc_probs,
+        attention_mask=mask_data)
     self.assertEqual(outputs.shape, (3, 4, 8))
...
...
official/nlp/modeling/layers/position_embedding.py  View file @ afd5579f
...
...
@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
         "hidden_size": self._hidden_size,
         "min_timescale": self._min_timescale,
         "max_timescale": self._max_timescale,
-        "length": self._length,
     }
     base_config = super(RelativePositionEmbedding, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
...
...
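For reference, the surrounding `get_config` idiom merges subclass fields into the base layer's config so the layer survives a serialize/deserialize round trip. A minimal sketch with a hypothetical layer that is not part of this commit:

import tensorflow as tf

class ScaledLayer(tf.keras.layers.Layer):  # hypothetical example layer

  def __init__(self, scale=1.0, **kwargs):
    super().__init__(**kwargs)
    self._scale = scale

  def call(self, inputs):
    return inputs * self._scale

  def get_config(self):
    # Subclass fields first, then the base config, merged into one dict.
    config = {"scale": self._scale}
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))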
official/nlp/modeling/layers/rezero_transformer.py  View file @ afd5579f
...
...
@@ -213,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
-    attention_inputs = [target_tensor, input_tensor]

-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = target_tensor + self._rezero_a * attention_output
     if self._use_layer_norm:
...
...
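For context on the unchanged lines around this hunk: ReZero gates each sublayer with one trainable scalar (`_rezero_a` here) before the residual add, so the block starts out as the identity. A hedged sketch of that update rule, assuming the zero initialization from the ReZero paper:

import tensorflow as tf

rezero_a = tf.Variable(0.0, trainable=True)  # assumed zero init

def rezero_residual(target_tensor, sublayer_output):
  # With rezero_a == 0 this returns target_tensor unchanged.
  return target_tensor + rezero_a * sublayer_output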
official/nlp/modeling/layers/talking_heads_attention.py  View file @ afd5579f
...
...
@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
     bias_constraint: Constraint for dense layer kernels.
   """

-  def _build_attention(self, qkv_rank):
+  def build_attention(self, qkv_rank):
     """Builds multi-head dot-product attention computations.

     This function overrides base class to create additional linear projection
...
...
@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
     Args:
       qkv_rank: the rank of query, key, value tensors after projection.
     """
-    super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
+    super(TalkingHeadsAttention, self).build_attention(qkv_rank)

     # Build an equation:
     # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
...
...
@@ -103,7 +103,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
         dtype=self.dtype,
         trainable=True)

-  def _compute_attention(self,
+  def compute_attention(self,
                          query_tensor,
                          key_tensor,
                          value_tensor,
...
...
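The equation comment above describes the talking-heads projection: attention scores are linearly mixed across the head axis by a (num_heads, num_heads) matrix, applied around the softmax in talking-heads attention. A numpy illustration of that contraction, with illustrative shapes:

import numpy as np

B, N, F, T = 2, 4, 5, 5             # batch, heads, target len, source len
scores = np.random.rand(B, N, F, T)
head_mixing = np.random.rand(N, N)  # (num_heads_a, num_heads_b)

mixed = np.einsum("BNFT,NM->BMFT", scores, head_mixing)
assert mixed.shape == scores.shape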
official/nlp/modeling/layers/talking_heads_attention_test.py  View file @ afd5579f
...
...
@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
     value = tf.keras.Input(shape=(20, 80))
-    output = test_layer([query, value])
+    output = test_layer(query=query, value=value)
     self.assertEqual(output.shape.as_list(), [None] + output_dims)

   def test_non_masked_self_attention(self):
...
...
@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

   def test_attention_scores(self):
...
...
@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64, return_attention_scores=True)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output, coef = test_layer([query, query])
+    output, coef = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
     self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
...
...
@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     query = tf.keras.Input(shape=(4, 8))
     value = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
-    output = test_layer([query, value], mask_tensor)
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)

     # Create a model containing the test layer.
     model = tf.keras.Model([query, value, mask_tensor], output)
...
...
@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Tests the layer with three inputs: Q, K, V.
     key = tf.keras.Input(shape=(2, 8))
-    output = test_layer([query, value, key], mask_tensor)
+    output = test_layer(
+        query=query, value=value, key=key, attention_mask=mask_tensor)
     model = tf.keras.Model([query, value, key, mask_tensor], output)
     masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
...
...
@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

   @parameterized.named_parameters(
...
...
@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Invoke the data with a random set of mask data. This should mask at least
     # one element.
     mask_data = np.random.randint(2, size=mask_shape).astype("bool")
-    output = test_layer([query, value], mask_data)
+    output = test_layer(query=query, value=value, attention_mask=mask_data)

     # Invoke the same data, but with a null mask (where no elements are masked).
     null_mask_data = np.ones(mask_shape)
-    unmasked_output = test_layer([query, value], null_mask_data)
+    unmasked_output = test_layer(
+        query=query, value=value, attention_mask=null_mask_data)

     # Because one data is masked and one is not, the outputs should not be the
     # same.
     self.assertNotAllClose(output, unmasked_output)
...
...
official/nlp/modeling/layers/transformer.py  View file @ afd5579f
...
...
@@ -120,7 +120,9 @@ class Transformer(tf.keras.layers.Layer):
         name="self_attention",
         **common_kwargs)
     # pylint: disable=protected-access
-    self._attention_layer.build([input_tensor_shape] * 3)
+    # Temporarily handling for checkpoint compatible changes.
+    self._attention_layer._build_from_signature(
+        query=input_tensor_shape, value=input_tensor_shape)
     self._attention_output_dense = self._attention_layer._output_dense
     # pylint: enable=protected-access
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
...
...
@@ -202,9 +204,9 @@ class Transformer(tf.keras.layers.Layer):
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
-    attention_inputs = [target_tensor, input_tensor]

-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(target_tensor +
                                                   attention_output)
...
...
@@ -382,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
           len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
-    self_attention_inputs = [input_tensor, input_tensor]
     self_attention_output, cache = self.self_attention(
-        self_attention_inputs,
+        query=input_tensor,
+        value=input_tensor,
         attention_mask=self_attention_mask,
         cache=cache,
         decode_loop_step=decode_loop_step)
     self_attention_output = self.self_attention_dropout(self_attention_output)
     self_attention_output = self.self_attention_layer_norm(
         input_tensor + self_attention_output)

-    cross_attn_inputs = [self_attention_output, memory]
+    cross_attn_inputs = dict(
+        query=self_attention_output,
+        value=memory,
+        attention_mask=attention_mask)
     if self.multi_channel_cross_attention:
       # Accesses the 5-th input tensor for the doc-attention probabilities.
-      cross_attn_inputs.append(inputs[-1])
-    attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
+      cross_attn_inputs["context_attention_weights"] = inputs[-1]
+    attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
     attention_output = self.encdec_attention_layer_norm(
         self_attention_output + attention_output)
...
...
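The decoder hunk swaps a positional-list convention for a keyword dict expanded with `**` at the call site, so the optional multi-channel weight tensor becomes just another keyword. The pattern in isolation, with stand-in names and values:

def attend(query, value, attention_mask=None, context_attention_weights=None):
  # Stand-in for self.encdec_attention; only the call shape matters here.
  return (query, value, attention_mask, context_attention_weights)

kwargs = dict(query="q", value="v", attention_mask="m")
multi_channel = True
if multi_channel:
  kwargs["context_attention_weights"] = "w"
result = attend(**kwargs)
assert result == ("q", "v", "m", "w")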
official/nlp/modeling/layers/transformer_scaffold.py  View file @ afd5579f
...
...
@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
     else:
       input_tensor, attention_mask = (inputs, None)

-    attention_inputs = [input_tensor, input_tensor]
-
-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=input_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(input_tensor +
                                                   attention_output)
...
...
official/nlp/modeling/layers/transformer_scaffold_test.py  View file @ afd5579f
...
...
@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
     super(ValidatedAttentionLayer, self).__init__(**kwargs)
     self.list = call_list

-  def call(self, inputs, attention_mask=None):
+  def call(self, query, value, attention_mask=None):
     self.list.append(True)
-    return super(ValidatedAttentionLayer, self).call(
-        inputs, attention_mask=attention_mask)
+    return super(ValidatedAttentionLayer, self).call(
+        query, value, attention_mask=attention_mask)

   def get_config(self):
     config = super(ValidatedAttentionLayer, self).get_config()
...
...
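The `ValidatedAttentionLayer` change keeps the test's call-recording trick while adopting the new `query`/`value` signature. The trick reduced to plain Python, as a sketch rather than the test's actual classes:

class Base:

  def call(self, query, value, attention_mask=None):
    return query

class Recorded(Base):

  def __init__(self, call_list):
    self.list = call_list

  def call(self, query, value, attention_mask=None):
    self.list.append(True)  # proof of invocation for the test to assert on
    return super().call(query, value, attention_mask=attention_mask)

calls = []
layer = Recorded(calls)
layer.call("q", "v")
assert calls == [True]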
official/nlp/modeling/layers/transformer_test.py  View file @ afd5579f
...
...
@@ -152,7 +152,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     _ = new_layer([input_data, mask_data])
     new_layer.set_weights(test_layer.get_weights())
     new_output_tensor = new_layer([input_data, mask_data])
-    self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
+    self.assertAllClose(
+        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)

   def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
...
...
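For reference, `assertAllClose` follows numpy's combined-tolerance rule, |actual - desired| <= atol + rtol * |desired|, so with atol=5e-5 and rtol=0.003 a deviation of roughly 0.3% still passes. A quick numeric check:

import numpy as np

desired = np.float32(1.0)
actual = np.float32(1.0029)  # within rtol=0.003 of desired
assert abs(actual - desired) <= 5e-5 + 0.003 * abs(desired)
assert np.allclose(actual, desired, atol=5e-5, rtol=0.003)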
official/nlp/modeling/networks/encoder_scaffold_test.py  View file @ afd5579f
...
...
@@ -323,6 +323,28 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
     self.assertAllEqual(network.get_config(), new_network.get_config())


+class Embeddings(tf.keras.Model):
+
+  def __init__(self, vocab_size, hidden_size):
+    super().__init__()
+    self.inputs = [
+        tf.keras.layers.Input(
+            shape=(None,), dtype=tf.int32, name="input_word_ids"),
+        tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name="input_mask")
+    ]
+    self.attention_mask = layers.SelfAttentionMask()
+    self.embedding_layer = layers.OnDeviceEmbedding(
+        vocab_size=vocab_size,
+        embedding_width=hidden_size,
+        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
+        name="word_embeddings")
+
+  def call(self, inputs):
+    word_ids, mask = inputs
+    word_embeddings = self.embedding_layer(word_ids)
+    return word_embeddings, self.attention_mask([word_embeddings, mask])
+
+
 @keras_parameterized.run_all_keras_modes
 class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
...
...
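A hedged sketch of how the new `Embeddings` subnetwork defined above is consumed: it maps `(word_ids, mask)` to `(word_embeddings, attention_mask)`, the two-output contract the scaffold expects. This assumes the `Embeddings` class and the repository's `layers` module are in scope; sizes are illustrative:

import numpy as np

network = Embeddings(vocab_size=100, hidden_size=32)
word_ids = np.random.randint(100, size=(2, 16))
mask = np.ones((2, 16), dtype="int32")
word_embeddings, attention_mask = network([word_ids, mask])
# word_embeddings: [2, 16, 32]; attention_mask: [2, 16, 16]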
@@ -334,20 +356,7 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
     # Build an embedding network to swap in for the default network. This one
     # will have 2 inputs (mask and word_ids) instead of 3, and won't use
     # positional embeddings.
-    word_ids = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name="input_word_ids")
-    mask = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name="input_mask")
-    embedding_layer = layers.OnDeviceEmbedding(
-        vocab_size=vocab_size,
-        embedding_width=hidden_size,
-        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
-        name="word_embeddings")
-    word_embeddings = embedding_layer(word_ids)
-    attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
-    network = tf.keras.Model([word_ids, mask],
-                             [word_embeddings, attention_mask])
+    network = Embeddings(vocab_size, hidden_size)

     hidden_cfg = {
         "num_attention_heads":
...
...
@@ -371,8 +380,7 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=0.02),
         hidden_cfg=hidden_cfg,
-        embedding_cls=network,
-        embedding_data=embedding_layer.embeddings)
+        embedding_cls=network)

     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...
...
@@ -390,11 +398,6 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
     mask_data = np.random.randint(2, size=(batch_size, sequence_length))
     _ = model.predict([word_id_data, mask_data])

-    # Test that we can get the embedding data that we passed to the object. This
-    # is necessary to support standard language model training.
-    self.assertIs(embedding_layer.embeddings,
-                  test_network.get_embedding_table())

   def test_serialize_deserialize(self):
     hidden_size = 32
     sequence_length = 21
...
...
official/nlp/modeling/ops/__init__.py  0 → 100644  View file @ afd5579f