ModelZoo / ResNet50_tensorflow

Commit afd5579f, authored Jul 22, 2020 by Kaushik Shivakumar

    Merge remote-tracking branch 'upstream/master' into context_tf2

Parents: dcd96e02, 567bd18d
Changes: 89 files; showing 20 changed files with 537 additions and 424 deletions.
official/benchmark/retinanet_benchmark.py (+38, -0)
official/benchmark/transformer_benchmark.py (+72, -94)
official/colab/fine_tuning_bert.ipynb (+89, -62)
official/nlp/bert/model_saving_utils.py (+0, -4)
official/nlp/data/classifier_data_lib.py (+52, -35)
official/nlp/data/create_finetuning_data.py (+14, -8)
official/nlp/modeling/layers/attention.py (+150, -137)
official/nlp/modeling/layers/attention_test.py (+20, -15)
official/nlp/modeling/layers/multi_channel_attention.py (+32, -14)
official/nlp/modeling/layers/multi_channel_attention_test.py (+5, -1)
official/nlp/modeling/layers/position_embedding.py (+0, -1)
official/nlp/modeling/layers/rezero_transformer.py (+2, -2)
official/nlp/modeling/layers/talking_heads_attention.py (+7, -7)
official/nlp/modeling/layers/talking_heads_attention_test.py (+10, -8)
official/nlp/modeling/layers/transformer.py (+13, -9)
official/nlp/modeling/layers/transformer_scaffold.py (+2, -3)
official/nlp/modeling/layers/transformer_scaffold_test.py (+2, -2)
official/nlp/modeling/layers/transformer_test.py (+4, -1)
official/nlp/modeling/networks/encoder_scaffold_test.py (+24, -21)
official/nlp/modeling/ops/__init__.py (+1, -0)
official/benchmark/retinanet_benchmark.py

@@ -271,6 +271,44 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
     FLAGS.strategy_type = 'tpu'
     self._run_and_report_benchmark(params, do_eval=False, warmup=0)
 
+  @flagsaver.flagsaver
+  def benchmark_4x4_tpu_coco(self):
+    """Run RetinaNet model accuracy test with a 4x4 TPU slice."""
+    self._setup()
+    params = self._params()
+    params['train']['batch_size'] = 256
+    params['train']['total_steps'] = 469  # One epoch.
+    params['train']['iterations_per_loop'] = 500
+    FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco')
+    FLAGS.strategy_type = 'tpu'
+    self._run_and_report_benchmark(params, do_eval=False, warmup=0)
+
+  @flagsaver.flagsaver
+  def benchmark_2x2_tpu_coco_mlir(self):
+    """Run RetinaNet accuracy test on a 2x2 TPU slice with the MLIR bridge."""
+    self._setup()
+    params = self._params()
+    params['train']['batch_size'] = 64
+    params['train']['total_steps'] = 1875  # One epoch.
+    params['train']['iterations_per_loop'] = 500
+    FLAGS.model_dir = self._get_model_dir('real_benchmark_2x2_tpu_coco_mlir')
+    FLAGS.strategy_type = 'tpu'
+    tf.config.experimental.enable_mlir_bridge()
+    self._run_and_report_benchmark(params, do_eval=False, warmup=0)
+
+  @flagsaver.flagsaver
+  def benchmark_4x4_tpu_coco_mlir(self):
+    """Run RetinaNet accuracy test on a 4x4 TPU slice with the MLIR bridge."""
+    self._setup()
+    params = self._params()
+    params['train']['batch_size'] = 256
+    params['train']['total_steps'] = 469  # One epoch.
+    params['train']['iterations_per_loop'] = 500
+    FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco_mlir')
+    FLAGS.strategy_type = 'tpu'
+    tf.config.experimental.enable_mlir_bridge()
+    self._run_and_report_benchmark(params, do_eval=False, warmup=0)
+
   @flagsaver.flagsaver
   def benchmark_2x2_tpu_spinenet_coco(self):
     """Run SpineNet with RetinaNet model accuracy test with 4 TPUs."""
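A note on the pattern used throughout this file: flagsaver.flagsaver (from
absl-py) snapshots the global FLAGS object when the decorated method is
entered and restores it on exit, which is why each benchmark can overwrite
flags freely without leaking state into the next one. A minimal sketch of
that behavior, assuming only absl-py (the flag and function names here are
illustrative, not from the diff):

    from absl import flags
    from absl.testing import flagsaver

    flags.DEFINE_string('strategy_type', 'mirrored', 'Distribution strategy.')
    FLAGS = flags.FLAGS

    @flagsaver.flagsaver
    def run_tpu_case():
      FLAGS.strategy_type = 'tpu'  # visible only inside this call
      # ... run the benchmark here ...

    FLAGS(['prog'])  # parse flags once, as absl.app.run() would
    run_tpu_case()
    assert FLAGS.strategy_type == 'mirrored'  # restored after the call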
official/benchmark/transformer_benchmark.py

@@ -29,6 +29,8 @@ from official.nlp.transformer import misc
 from official.nlp.transformer import transformer_main as transformer_main
 from official.utils.flags import core as flags_core
 
+TPU_DATA_DIR = 'gs://mlcompass-data/transformer'
+GPU_DATA_DIR = os.getenv('TMPDIR')
 TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
 EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
 
 FLAGS = flags.FLAGS

@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark):
   Code under test for the Transformer Keras models report the same data and
   require the same FLAG setup.
   """
 
   def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
                flag_methods=None, tpu=None):
+    self._set_data_files(root_data_dir=root_data_dir)
+
+    if default_flags is None:
+      default_flags = {}
+    default_flags['data_dir'] = self.train_data_dir
+    default_flags['vocab_file'] = self.vocab_file
+
+    super(TransformerBenchmark, self).__init__(
+        output_dir=output_dir,
+        default_flags=default_flags,
+        flag_methods=flag_methods,
+        tpu=tpu)
+
+  def _set_data_files(self, root_data_dir=None, tpu_run=False):
+    """Sets train_data_dir, vocab_file, bleu_source and bleu_ref."""
+    # Use remote storage for TPU, remote storage for GPU if defined, else
+    # use environment provided root_data_dir.
+    if tpu_run:
+      root_data_dir = TPU_DATA_DIR
+    elif GPU_DATA_DIR is not None:
+      root_data_dir = GPU_DATA_DIR
+
     root_data_dir = root_data_dir if root_data_dir else ''
     self.train_data_dir = os.path.join(root_data_dir,
                                        TRANSFORMER_EN2DE_DATA_DIR_NAME)
     self.vocab_file = os.path.join(root_data_dir,
                                    TRANSFORMER_EN2DE_DATA_DIR_NAME,
                                    'vocab.ende.32768')
     self.bleu_source = os.path.join(root_data_dir,
                                     EN2DE_2014_BLEU_DATA_DIR_NAME,
                                     'newstest2014.en')
     self.bleu_ref = os.path.join(root_data_dir,
                                  EN2DE_2014_BLEU_DATA_DIR_NAME,
                                  'newstest2014.de')
 
-    if default_flags is None:
-      default_flags = {}
-    default_flags['data_dir'] = self.train_data_dir
-    default_flags['vocab_file'] = self.vocab_file
-
-    super(TransformerBenchmark, self).__init__(
-        output_dir=output_dir,
-        default_flags=default_flags,
-        flag_methods=flag_methods,
-        tpu=tpu)
+  def _set_data_file_flags(self):
+    """Sets the FLAGS for the data files."""
+    FLAGS.data_dir = self.train_data_dir
+    FLAGS.vocab_file = self.vocab_file
+    # Sets values directly to avoid validation check.
+    FLAGS['bleu_source'].value = self.bleu_source
+    FLAGS['bleu_ref'].value = self.bleu_ref
 
   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self,

The ten accuracy benchmarks in TransformerBaseKerasAccuracy and
TransformerBigKerasAccuracy (hunks at @@ -164,12 +183,8 @@,
@@ -189,12 +204,8 @@, @@ -215,12 +226,8 @@, @@ -237,12 +244,8 @@,
@@ -284,12 +287,8 @@, @@ -306,12 +305,8 @@, @@ -337,13 +332,9 @@,
@@ -360,14 +351,10 @@, @@ -384,13 +371,9 @@ and @@ -409,14 +392,10 @@)
all receive the same substitution: the five inline data-file FLAG
assignments are replaced by one call to the new helper. Representative hunk:

@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
       not converge to the 27.3 BLEU (uncased) SOTA.
     """
     self._setup()
+    self._set_data_file_flags()
     FLAGS.num_gpus = 1
-    FLAGS.data_dir = self.train_data_dir
-    FLAGS.vocab_file = self.vocab_file
-    # Sets values directly to avoid validation check.
-    FLAGS['bleu_source'].value = self.bleu_source
-    FLAGS['bleu_ref'].value = self.bleu_ref
     FLAGS.param_set = 'base'
     FLAGS.batch_size = 2048
     FLAGS.train_steps = 1000

@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
         root_data_dir=root_data_dir, batch_per_gpu=3072,
         tpu=tpu)
 
-  def benchmark_2x2_tpu(self):
-    """Port of former snaggletooth transformer_big model on 2x2."""
-    self._setup()
-    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
+  def _set_df_common(self):
+    self._set_data_files(tpu_run=True)
+    FLAGS.data_dir = self.train_data_dir
+    FLAGS.vocab_file = self.vocab_file
+    FLAGS.distribution_strategy = 'tpu'
+    FLAGS.padded_decode = True
     FLAGS.train_steps = 300
     FLAGS.log_steps = 150
     FLAGS.steps_between_evals = 150
-    FLAGS.distribution_strategy = 'tpu'
     FLAGS.static_batch = True
     FLAGS.use_ctl = True
-    FLAGS.batch_size = 6144
+    FLAGS.enable_checkpointing = False
     FLAGS.max_length = 64
     FLAGS.decode_batch_size = 32
     FLAGS.decode_max_length = 97
-    FLAGS.padded_decode = True
-    FLAGS.enable_checkpointing = False
+
+  def benchmark_2x2_tpu(self):
+    """Port of former snaggletooth transformer_big model on 2x2."""
+    self._setup()
+    self._set_df_common()
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
+    FLAGS.batch_size = 6144
+    self._run_and_report_benchmark(
+        total_batch_size=FLAGS.batch_size,
+        log_steps=FLAGS.log_steps)
+
+  @owner_utils.Owner('tf-graph-compiler')
+  def benchmark_2x2_tpu_mlir(self):
+    """Run transformer_big model on 2x2 with the MLIR Bridge enabled."""
+    self._setup()
+    self._set_df_common()
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mlir')
+    FLAGS.batch_size = 6144
+    tf.config.experimental.enable_mlir_bridge()
     self._run_and_report_benchmark(
         total_batch_size=FLAGS.batch_size,

@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
   def benchmark_4x4_tpu(self):
     """Port of former GCP transformer_big model on 4x4."""
     self._setup()
+    self._set_df_common()
     FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
-    FLAGS.train_steps = 300
-    FLAGS.log_steps = 150
-    FLAGS.steps_between_evals = 150
-    FLAGS.distribution_strategy = 'tpu'
-    FLAGS.static_batch = True
-    FLAGS.use_ctl = True
     FLAGS.batch_size = 24576
-    FLAGS.max_length = 64
-    FLAGS.decode_batch_size = 32
-    FLAGS.decode_max_length = 97
-    FLAGS.padded_decode = True
-    FLAGS.enable_checkpointing = False
     self._run_and_report_benchmark(
         total_batch_size=FLAGS.batch_size,

@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
   def benchmark_4x4_tpu_mlir(self):
     """Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
     self._setup()
-    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
-    FLAGS.train_steps = 300
-    FLAGS.log_steps = 150
-    FLAGS.steps_between_evals = 150
-    FLAGS.distribution_strategy = 'tpu'
-    FLAGS.static_batch = True
-    FLAGS.use_ctl = True
+    self._set_df_common()
+    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_mlir')
     FLAGS.batch_size = 24576
-    FLAGS.max_length = 64
-    FLAGS.decode_batch_size = 32
-    FLAGS.decode_max_length = 97
-    FLAGS.padded_decode = True
-    FLAGS.enable_checkpointing = False
     tf.config.experimental.enable_mlir_bridge()
     self._run_and_report_benchmark(
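One detail worth calling out in this refactor: _set_data_file_flags assigns
FLAGS['bleu_source'].value = ... rather than FLAGS.bleu_source = ..., and the
in-code comment says why: attribute assignment on absl's FLAGS runs any
registered validators, while setting the flag object's .value directly skips
them. An illustrative sketch, assuming absl-py (the validator below is
hypothetical, standing in for whatever checks the real flags register):

    from absl import flags

    flags.DEFINE_string('bleu_source', None, 'Path to the BLEU source file.')
    flags.register_validator('bleu_source',
                             lambda p: p is None or p.startswith('gs://'),
                             message='bleu_source must be a GCS path')
    FLAGS = flags.FLAGS
    FLAGS(['prog'])

    FLAGS['bleu_source'].value = '/local/newstest2014.en'  # skips the validator
    # FLAGS.bleu_source = '/local/newstest2014.en'  # would raise
    # flags.IllegalFlagValueError, because the validator rejects the path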
official/colab/fine_tuning_bert.ipynb

Fifty-eight hunks (from @@ -12,7 +12,7 @@ through @@ -1795,7 +1822,7 @@) apply
the same one-line cleanup to every code cell, resetting the stale execution
counter:

@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "cellView": "form",
     "colab": {},

Two hunks carry substantive changes. The markdown cell at line 204 now links
directly to the TF Hub model:

@@ -204,12 +204,12 @@
     "id": "9uFskufsR2LT"
    },
    "source": [
-    "You can get a pre-trained BERT encoder from TensorFlow Hub here:"
+    "You can get a pre-trained BERT encoder from [TensorFlow Hub](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2):"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab": {},
     "colab_type": "code",

And the cell that loads the hub encoder is split into three cells: the
import, a form-style cell that selects the model name, and the encoder
construction from the chosen name:

@@ -1406,17 +1406,44 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab": {},
     "colab_type": "code",
-    "id": "lo6479At4sP1"
+    "id": "GDWrHm0BGpbX"
    },
    "outputs": [],
    "source": [
     "# Note: 350MB download.\n",
-    "import tensorflow_hub as hub\n",
-    "hub_encoder = hub.KerasLayer(hub_url_bert, trainable=True)\n",
+    "import tensorflow_hub as hub"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "cellView": "form",
+    "colab": {},
+    "colab_type": "code",
+    "id": "Y29meH0qGq_5"
+   },
+   "outputs": [],
+   "source": [
+    "hub_model_name = \"bert_en_uncased_L-12_H-768_A-12\" #@param [\"bert_en_uncased_L-24_H-1024_A-16\", \"bert_en_wwm_cased_L-24_H-1024_A-16\", \"bert_en_uncased_L-12_H-768_A-12\", \"bert_en_wwm_uncased_L-24_H-1024_A-16\", \"bert_en_cased_L-24_H-1024_A-16\", \"bert_en_cased_L-12_H-768_A-12\", \"bert_zh_L-12_H-768_A-12\", \"bert_multi_cased_L-12_H-768_A-12\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "lo6479At4sP1"
+   },
+   "outputs": [],
+   "source": [
+    "hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}\",\n",
+    "                             trainable=True)\n",
     "\n",
     "print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
    ]
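Outside Colab, the new cells amount to parameterizing the TF Hub URL and
loading the encoder as a Keras layer. Below is a sketch of wiring such an
encoder into a model, assuming the /2 TF2 SavedModels on tfhub.dev keep their
documented interface (inputs [input_word_ids, input_mask, input_type_ids],
outputs a pooled tensor and a sequence tensor); the two-class head is
hypothetical and not part of the notebook:

    import tensorflow as tf
    import tensorflow_hub as hub

    seq_len = 128
    word_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.Input(shape=(seq_len,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32, name='input_type_ids')

    hub_model_name = 'bert_en_uncased_L-12_H-768_A-12'
    encoder = hub.KerasLayer(
        f'https://tfhub.dev/tensorflow/{hub_model_name}/2',
        trainable=True)  # fine-tune the encoder weights, as in the notebook

    pooled_output, sequence_output = encoder([word_ids, mask, type_ids])
    logits = tf.keras.layers.Dense(2)(pooled_output)  # hypothetical head
    model = tf.keras.Model([word_ids, mask, type_ids], logits)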
official/nlp/bert/model_saving_utils.py

@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text,
     raise ValueError('model must be a tf.keras.Model object.')
 
   if checkpoint_dir:
-    # Keras compile/fit() was used to save checkpoint using
-    # model.save_weights().
     if restore_model_using_load_weights:
       model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
       assert tf.io.gfile.exists(model_weight_path)
       model.load_weights(model_weight_path)
-
-    # tf.train.Checkpoint API was used via custom training loop logic.
     else:
       checkpoint = tf.train.Checkpoint(model=model)
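The deleted comments described the two restore paths that remain in the code:
weights written by Keras' model.save_weights() versus objects tracked by
tf.train.Checkpoint in a custom training loop. A minimal sketch of both,
assuming a built tf.keras model (the paths are illustrative):

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    model.build(input_shape=(None, 8))

    # Path 1: checkpoints written by model.save_weights() are restored
    # with model.load_weights().
    model.save_weights('/tmp/ckpt_keras/checkpoint')
    model.load_weights('/tmp/ckpt_keras/checkpoint')

    # Path 2: checkpoints written through tf.train.Checkpoint (typical in
    # custom training loops) are restored through a Checkpoint object.
    ckpt = tf.train.Checkpoint(model=model)
    save_path = ckpt.save('/tmp/ckpt_ctl/ckpt')
    ckpt.restore(save_path).assert_existing_objects_matched()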
official/nlp/data/classifier_data_lib.py

@@ -152,10 +152,10 @@ class ColaProcessor(DataProcessor):
     return "COLA"
 
   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
-      # Only the test set has a header
+    for i, line in enumerate(lines):
+      # Only the test set has a header.
       if set_type == "test" and i == 0:
         continue
       guid = "%s-%s" % (set_type, i)

@@ -173,6 +173,14 @@ class ColaProcessor(DataProcessor):
 class MnliProcessor(DataProcessor):
   """Processor for the MultiNLI data set (GLUE version)."""
 
+  def __init__(self,
+               mnli_type="matched",
+               process_text_fn=tokenization.convert_to_unicode):
+    super(MnliProcessor, self).__init__(process_text_fn)
+    if mnli_type not in ("matched", "mismatched"):
+      raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
+    self.mnli_type = mnli_type
+
   def get_train_examples(self, data_dir):
     """See base class."""
     return self._create_examples(

@@ -180,14 +188,23 @@ class MnliProcessor(DataProcessor):
   def get_dev_examples(self, data_dir):
     """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
-        "dev_matched")
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
+          "dev_matched")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
+          "dev_mismatched")
 
   def get_test_examples(self, data_dir):
     """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")),
-        "test")
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_matched.tsv")),
+          "test")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")),
+          "test")
 
   def get_labels(self):
     """See base class."""

The remaining hunks in this file apply the same two mechanical cleanups
throughout. Every _create_examples docstring is corrected from "Creates
examples for the training and dev sets." to "Creates examples for the
training/dev/test sets." (MnliProcessor at @@ -199,9 +216,9 @@, MrpcProcessor
at @@ -244,9 +261,9 @@, QnliProcessor at @@ -368,9 +385,9 @@, QqpProcessor
at @@ -415,9 +432,9 @@, RteProcessor at @@ -462,7 +479,7 @@, SstProcessor at
@@ -507,9 +524,9 @@, StsBProcessor at @@ -558,7 +575,7 @@, TfdsProcessor at
@@ -671,7 +688,7 @@, WnliProcessor at @@ -731,7 +748,7 @@). And every loop of
the form "for (i, line) in enumerate(lines):" drops the redundant parentheses
to become "for i, line in enumerate(lines):" (the processors above, plus
PawsxProcessor at @@ -290, -307 and -321, XnliProcessor at @@ -777, -792 and
-807, XtremePawsxProcessor at @@ -837, -851 and -865, XtremeXnliProcessor at
@@ -896, -909 and -923, and "for (ex_index, example) in enumerate(examples):"
in file_based_convert_examples_to_features at @@ -1052,7 +1069,7 @@).
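For downstream users, the practical effect of the new constructor argument is
a selectable dev/test split. A minimal usage sketch, assuming a GLUE-style
MNLI directory containing dev_matched.tsv and dev_mismatched.tsv (the path is
illustrative):

    from official.nlp.data import classifier_data_lib

    # "mismatched" routes get_dev_examples/get_test_examples to the
    # *_mismatched.tsv files; any value other than "matched" or "mismatched"
    # raises ValueError in the new __init__.
    processor = classifier_data_lib.MnliProcessor(mnli_type="mismatched")
    dev_examples = processor.get_dev_examples("/path/to/glue/MNLI")
    print(len(dev_examples), processor.get_labels())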
official/nlp/data/create_finetuning_data.py
View file @
afd5579f
...
@@ -59,27 +59,32 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
...
@@ -59,27 +59,32 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
"only and for XNLI is all languages combined. Same for "
"only and for XNLI is all languages combined. Same for "
"PAWS-X."
)
"PAWS-X."
)
# XNLI task specific flag.
# MNLI task-specific flag.
flags
.
DEFINE_enum
(
"mnli_type"
,
"matched"
,
[
"matched"
,
"mismatched"
],
"The type of MNLI dataset."
)
# XNLI task-specific flag.
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
"xnli_language"
,
"en"
,
"xnli_language"
,
"en"
,
"Language of training data for XN
I
L task. If the value is 'all', the data "
"Language of training data for XNL
I
task. If the value is 'all', the data "
"of all languages will be used for training."
)
"of all languages will be used for training."
)
# PAWS-X task
specific flag.
# PAWS-X task
-
specific flag.
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
"pawsx_language"
,
"en"
,
"pawsx_language"
,
"en"
,
"Language of trainig data for PAWS-X task. If the value is 'all', the data "
"Language of traini
n
g data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training."
)
"of all languages will be used for training."
)
# Retrieva task
specific flags
# Retrieva
l
task
-
specific flags
.
flags
.
DEFINE_enum
(
"retrieval_task_name"
,
"bucc"
,
[
"bucc"
,
"tatoeba"
],
flags
.
DEFINE_enum
(
"retrieval_task_name"
,
"bucc"
,
[
"bucc"
,
"tatoeba"
],
"The name of sentence retrieval task for scoring"
)
"The name of sentence retrieval task for scoring"
)
# Tagging task
specific flags
# Tagging task
-
specific flags
.
flags
.
DEFINE_enum
(
"tagging_task_name"
,
"panx"
,
[
"panx"
,
"udpos"
],
flags
.
DEFINE_enum
(
"tagging_task_name"
,
"panx"
,
[
"panx"
,
"udpos"
],
"The name of BERT tagging (token classification) task."
)
"The name of BERT tagging (token classification) task."
)
# BERT Squad task
specific flags.
# BERT Squad task
-
specific flags.
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
"squad_data_file"
,
None
,
"squad_data_file"
,
None
,
"The input data file in for generating training data for BERT squad task."
)
"The input data file in for generating training data for BERT squad task."
)
...
@@ -179,7 +184,8 @@ def generate_classifier_dataset():
       "cola":
           classifier_data_lib.ColaProcessor,
       "mnli":
-          classifier_data_lib.MnliProcessor,
+          functools.partial(classifier_data_lib.MnliProcessor,
+                            mnli_type=FLAGS.mnli_type),
       "mrpc":
           classifier_data_lib.MrpcProcessor,
       "qnli":
...
View file @
afd5579f
...
@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
...
@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX
=
string
.
ascii_lowercase
_CHR_IDX
=
string
.
ascii_lowercase
def
_build_attention_equation
(
qkv_
rank
,
attn_axes
):
def
_build_attention_equation
(
rank
,
attn_axes
):
"""Builds einsum equations for the attention computation.
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
Query, key, value inputs after projection are expected to have the shape as:
...
@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes):
...
@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes):
<query attention dims>, num_heads, channels)
<query attention dims>, num_heads, channels)
Args:
Args:
qkv_
rank: the rank of query, key, value tensors.
rank: the rank of query, key, value tensors.
attn_axes: a list/tuple of axes, [1, rank), that will do attention.
attn_axes: a list/tuple of axes, [1, rank), that will do attention.
Returns:
Returns:
Einsum equations.
Einsum equations.
"""
"""
target_notation
=
_CHR_IDX
[:
qkv_
rank
]
target_notation
=
_CHR_IDX
[:
rank
]
# `batch_dims` includes the head dim.
# `batch_dims` includes the head dim.
batch_dims
=
tuple
(
np
.
delete
(
range
(
qkv_
rank
),
attn_axes
+
(
qkv_
rank
-
1
,)))
batch_dims
=
tuple
(
np
.
delete
(
range
(
rank
),
attn_axes
+
(
rank
-
1
,)))
letter_offset
=
qkv_
rank
letter_offset
=
rank
source_notation
=
""
source_notation
=
""
for
i
in
range
(
qkv_
rank
):
for
i
in
range
(
rank
):
if
i
in
batch_dims
or
i
==
qkv_
rank
-
1
:
if
i
in
batch_dims
or
i
==
rank
-
1
:
source_notation
+=
target_notation
[
i
]
source_notation
+=
target_notation
[
i
]
else
:
else
:
source_notation
+=
_CHR_IDX
[
letter_offset
]
source_notation
+=
_CHR_IDX
[
letter_offset
]
...
@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
       sequence dims. If not specified, projects back to the key feature dim.
     attention_axes: axes over which the attention is applied. `None` means
       attention over all axes, but batch, heads, and features.
-    return_attention_scores: bool, if `True`, returns the multi-head
-      attention scores as an additional output argument.
+    return_attention_scores: bool, if `True`, returns the multi-head attention
+      scores as an additional output argument.
     kernel_initializer: Initializer for dense layer kernels.
     bias_initializer: Initializer for dense layer biases.
     kernel_regularizer: Regularizer for dense layer kernels.
...
@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     activity_regularizer: Regularizer for dense layer activity.
     kernel_constraint: Constraint for dense layer kernels.
     bias_constraint: Constraint for dense layer kernels.
+
+  Call args:
+    query: Query `Tensor` of shape `[B, T, dim]`.
+    value: Value `Tensor` of shape `[B, S, dim]`.
+    key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
+      `value` for both `key` and `value`, which is the most common case.
+    attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
+      to certain positions.
   """
   def __init__(self,
...
@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
       self._attention_axes = (attention_axes,)
     else:
       self._attention_axes = attention_axes
+    self._built_from_signature = False

   def get_config(self):
     config = {
...
@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     base_config = super(MultiHeadAttention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

-  def build(self, input_shape):
-    inputs_len = len(input_shape)
-    if inputs_len > 3 or inputs_len < 2:
-      raise ValueError(
-          "Expects inputs list of length 2 or 3, namely [query, value] or "
-          "[query, value, key]. "
-          "Given length: %d" % inputs_len)
-    tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape)
-    query_shape = tensor_shapes[0]
-    value_shape = tensor_shapes[1]
-    key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape
+  def _build_from_signature(self, query, value, key=None):
+    """Builds layers and variables.
+
+    Once the method is called, self._built_from_signature will be set to True.
+
+    Args:
+      query: query tensor or TensorShape.
+      value: value tensor or TensorShape.
+      key: key tensor or TensorShape.
+    """
+    self._built_from_signature = True
+    if hasattr(query, "shape"):
+      query_shape = tf.TensorShape(query.shape)
+    else:
+      query_shape = query
+    if hasattr(value, "shape"):
+      value_shape = tf.TensorShape(value.shape)
+    else:
+      value_shape = value
+    if key is None:
+      key_shape = value_shape
+    elif hasattr(key, "shape"):
+      key_shape = tf.TensorShape(key.shape)
+    else:
+      key_shape = key
     common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
...
@@ -271,7 +293,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint)
+    with tf.init_scope():
       free_dims = query_shape.rank - 1
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
           free_dims, bound_dims=1, output_dims=2)
...
@@ -302,9 +324,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
           **common_kwargs)
     # Builds the attention computations for multi-head dot product attention.
-    # These computations could be wrapped into the keras attention layer once it
-    # support mult-head einsum computations.
-    self._build_attention(output_rank)
+    # These computations could be wrapped into the keras attention layer once
+    # it support mult-head einsum computations.
+    self.build_attention(output_rank)
     if self._output_shape:
       if not isinstance(self._output_shape, collections.abc.Sized):
         output_shape = [self._output_shape]
...
@@ -320,35 +342,30 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         bias_axes=bias_axes if self._use_bias else None,
         name="attention_output",
         **common_kwargs)
-    super(MultiHeadAttention, self).build(input_shape)

-  def _build_attention(self, qkv_rank):
+  def build_attention(self, rank):
     """Builds multi-head dot-product attention computations.

-    This function builds attributes necessary for `_compute_attention` to
+    This function builds attributes necessary for `compute_attention` to
     costomize attention computation to replace the default dot-product
     attention.

     Args:
-      qkv_rank: the rank of query, key, value tensors.
+      rank: the rank of query, key, value tensors.
     """
     if self._attention_axes is None:
-      self._attention_axes = tuple(range(1, qkv_rank - 2))
+      self._attention_axes = tuple(range(1, rank - 2))
     else:
       self._attention_axes = tuple(self._attention_axes)
     self._dot_product_equation, self._combine_equation, attn_scores_rank = (
-        _build_attention_equation(qkv_rank, attn_axes=self._attention_axes))
+        _build_attention_equation(rank, attn_axes=self._attention_axes))
     norm_axes = tuple(
         range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
     self._masked_softmax = masked_softmax.MaskedSoftmax(
         mask_expansion_axes=[1], normalization_axes=norm_axes)
     self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)

-  def _compute_attention(self,
-                         query_tensor,
-                         key_tensor,
-                         value_tensor,
-                         attention_mask=None):
+  def compute_attention(self, query, key, value, attention_mask=None):
     """Applies Dot-product attention with query, key, value tensors.

     This function defines the computation inside `call` with projected
...
@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     attention implementation.

     Args:
-      query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
-      key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
-      value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
+      query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
+      key: Projected key `Tensor` of shape `[B, T, N, key_size]`.
+      value: Projected value `Tensor` of shape `[B, T, N, value_size]`.
       attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
         attention to certain positions.
...
@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer):
       attention_output: Multi-headed outputs of attention computation.
       attention_scores: Multi-headed attention weights.
     """
+    # Note: Applying scalar multiply at the smaller end of einsum improves
+    # XLA performance, but may introduce slight numeric differences in
+    # the Transformer attention head.
+    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
-    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
-                                 query_tensor)
-    attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._key_size)))
+    attention_scores = tf.einsum(self._dot_product_equation, key, query)

     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, T, S]
...
@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     # `context_layer` = [B, T, N, H]
     attention_output = tf.einsum(self._combine_equation,
-                                 attention_scores_dropout, value_tensor)
+                                 attention_scores_dropout, value)
     return attention_output, attention_scores

-  def call(self, inputs, attention_mask=None):
+  def call(self, query, value, key=None, attention_mask=None):
     """Implements the forward pass.

     Size glossary:
...
@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     * Value (source) attention axes shape (S), the rank must match the target.

     Args:
-      inputs: List of the following tensors:
-        * query: Query `Tensor` of shape `[B, T, dim]`.
-        * value: Value `Tensor` of shape `[B, S, dim]`.
-        * key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
-          use `value` for both `key` and `value`, which is the most common case.
+      query: Query `Tensor` of shape `[B, T, dim]`.
+      value: Value `Tensor` of shape `[B, S, dim]`.
+      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
+        `value` for both `key` and `value`, which is the most common case.
       attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
         attention to certain positions.
...
@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         attention axes.
     """
-    inputs_len = len(inputs)
-    if inputs_len > 3 or inputs_len < 2:
-      raise ValueError(
-          "Expects inputs list of length 2 or 3, namely [query, value] or "
-          "[query, value, key]. "
-          "Given length: %d" % inputs_len)
-    query = inputs[0]
-    value = inputs[1]
-    key = inputs[2] if inputs_len == 3 else value
+    if not self._built_from_signature:
+      self._build_from_signature(query=query, value=value, key=key)
+    if key is None:
+      key = value

     # N = `num_attention_heads`
     # H = `size_per_head`
-    # `query_tensor` = [B, T, N ,H]
-    query_tensor = self._query_dense(query)
-    # `key_tensor` = [B, S, N, H]
-    key_tensor = self._key_dense(key)
-    # `value_tensor` = [B, S, N, H]
-    value_tensor = self._value_dense(value)
+    # `query` = [B, T, N ,H]
+    query = self._query_dense(query)
+    # `key` = [B, S, N, H]
+    key = self._key_dense(key)
+    # `value` = [B, S, N, H]
+    value = self._value_dense(value)

-    attention_output, attention_scores = self._compute_attention(
-        query_tensor, key_tensor, value_tensor, attention_mask)
+    attention_output, attention_scores = self.compute_attention(
+        query, key, value, attention_mask)
     attention_output = self._output_dense(attention_output)

     if self._return_attention_scores:
...
@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention):
   Arguments are the same as `MultiHeadAttention` layer.
   """

-  def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
+  def _update_cache(self, key, value, cache, decode_loop_step):
     """Updates cache states and gets full-length key/value tensors."""
     # Combines cached keys and values with new keys and values.
     if decode_loop_step is not None:
       # TPU special case.
       key_seq_dim = cache["key"].shape.as_list()[1]
       indices = tf.reshape(
-          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
+          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
          [1, key_seq_dim, 1, 1])
-      key_tensor = cache["key"] + key_tensor * indices
+      key = cache["key"] + key * indices
       value_seq_dim = cache["value"].shape.as_list()[1]
       indices = tf.reshape(
-          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
+          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
          [1, value_seq_dim, 1, 1])
-      value_tensor = cache["value"] + value_tensor * indices
+      value = cache["value"] + value * indices
     else:
-      key_tensor = tf.concat(
-          [tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
-      value_tensor = tf.concat(
-          [tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
+      key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
+      value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)

     # Update cache
-    cache["key"] = key_tensor
-    cache["value"] = value_tensor
+    cache["key"] = key
+    cache["value"] = value

-    return key_tensor, value_tensor
+    return key, value

   def call(self,
-           inputs,
+           query,
+           value,
+           key=None,
            attention_mask=None,
            cache=None,
            decode_loop_step=None):
-    from_tensor = inputs[0]
-    to_tensor = inputs[1]
+    if not self._built_from_signature:
+      self._build_from_signature(query=query, value=value, key=key)
+    if key is None:
+      key = value

     # Scalar dimensions referenced here:
     # B = batch size (number of sequences)
...
@@ -494,23 +509,21 @@ class CachedAttention(MultiHeadAttention):
     # T = `to_tensor` sequence length
     # N = `num_attention_heads`
     # H = `size_per_head`
-    # `query_tensor` = [B, F, N ,H]
-    query_tensor = self._query_dense(from_tensor)
-    # `key_tensor` = [B, T, N, H]
-    key_tensor = self._key_dense(to_tensor)
-    # `value_tensor` = [B, T, N, H]
-    value_tensor = self._value_dense(to_tensor)
+    # `query` = [B, F, N ,H]
+    query = self._query_dense(query)
+    # `key` = [B, T, N, H]
+    key = self._key_dense(key)
+    # `value` = [B, T, N, H]
+    value = self._value_dense(value)

     if cache:
-      key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
-                                                    cache, decode_loop_step)
+      key, value = self._update_cache(key, value, cache, decode_loop_step)

     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
-    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
-                                 query_tensor)
+    attention_scores = tf.einsum(self._dot_product_equation, key, query)
     attention_scores = tf.multiply(attention_scores,
                                    1.0 / math.sqrt(float(self._key_size)))
...
@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention):
     attention_scores = self._dropout_layer(attention_scores)
     # `context_layer` = [B, F, N, H]
     attention_output = tf.einsum(self._combine_equation, attention_scores,
-                                 value_tensor)
+                                 value)
     attention_output = self._output_dense(attention_output)

     if self._return_attention_scores:
       return attention_output, attention_scores, cache
...
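Note: this file's change replaces the list-style call (`layer([query, value])`) with explicit keyword arguments and defers variable creation to the first call via `_build_from_signature`. A minimal eager sketch of the new calling convention, using shapes from the tests below:

import tensorflow as tf
from official.nlp.modeling.layers import attention

layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
query = tf.random.uniform((2, 40, 80))  # [B, T, dim]
value = tf.random.uniform((2, 20, 80))  # [B, S, dim]
# `key` defaults to `value`; `attention_mask` would be a [B, T, S] boolean.
output = layer(query=query, value=value)
print(output.shape)  # (2, 40, 80): query length, projected back to the
                     # key feature dim.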
official/nlp/modeling/layers/attention_test.py  View file @ afd5579f
...
@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
     value = tf.keras.Input(shape=(20, 80))
-    output = test_layer([query, value])
+    output = test_layer(query=query, value=value)
     self.assertEqual(output.shape.as_list(), [None] + output_dims)

   def test_non_masked_self_attention(self):
...
@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

   def test_attention_scores(self):
...
@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64, return_attention_scores=True)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output, coef = test_layer([query, query])
+    output, coef = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
     self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
...
@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     query = tf.keras.Input(shape=(4, 8))
     value = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
-    output = test_layer([query, value], mask_tensor)
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)

     # Create a model containing the test layer.
     model = tf.keras.Model([query, value, mask_tensor], output)
...
@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     # Tests the layer with three inputs: Q, K, V.
     key = tf.keras.Input(shape=(2, 8))
-    output = test_layer([query, value, key], mask_tensor)
+    output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
     model = tf.keras.Model([query, value, key, mask_tensor], output)

     masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
...
@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

   @parameterized.named_parameters(
...
@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     # Invoke the data with a random set of mask data. This should mask at least
     # one element.
     mask_data = np.random.randint(2, size=mask_shape).astype("bool")
-    output = test_layer([query, value], mask_data)
+    output = test_layer(query=query, value=value, attention_mask=mask_data)

     # Invoke the same data, but with a null mask (where no elements are masked).
     null_mask_data = np.ones(mask_shape)
-    unmasked_output = test_layer([query, value], null_mask_data)
+    unmasked_output = test_layer(
+        query=query, value=value, attention_mask=null_mask_data)
     # Because one data is masked and one is not, the outputs should not be the
     # same.
     self.assertNotAllClose(output, unmasked_output)
...
@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
         key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query, query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
...
@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase):
     # one element.
     mask_data = np.random.randint(
         2, size=(batch_size, from_seq_length, from_seq_length))
-    masked_output_data, cache = layer([from_data, from_data], mask_data, cache)
+    masked_output_data, cache = layer(
+        query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
     self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertEqual(cache["value"].shape, (3, 4, 2, 2))

     # Tests inputs without cache.
-    masked_output_data, cache = layer([from_data, from_data, mask_data])
+    masked_output_data, cache = layer(
+        query=from_data, value=from_data, attention_mask=mask_data)
     self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertIsNone(cache)
...
@@ -243,9 +246,11 @@ class CachedAttentionTest(keras_parameterized.TestCase):
     mask_data = np.random.randint(
         2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
     # Testing the invocation directly as Keras cannot consume inputs correctly.
-    masked_output_data, cache = layer([from_data, from_data],
-                                      mask_data,
-                                      cache,
-                                      decode_loop_step=decode_loop_step)
+    masked_output_data, cache = layer(
+        query=from_data,
+        value=from_data,
+        attention_mask=mask_data,
+        cache=cache,
+        decode_loop_step=decode_loop_step)
     self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
...
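Note: the cached-attention tests above pass the cache and `decode_loop_step` as keywords too. A sketch of the TPU-style decoding path, assuming a zero-initialized cache pre-allocated to the full target length (the `[B, seq, num_heads, key_size]` shape the tests assert):

import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import attention

layer = attention.CachedAttention(num_heads=2, key_size=2)
from_data = tf.random.uniform((3, 4, 8))  # [B, F, dim]
cache = {"key": tf.zeros((3, 4, 2, 2)), "value": tf.zeros((3, 4, 2, 2))}
mask = np.ones((3, 4, 4), dtype=np.int32)
# decode_loop_step picks the cache slot the new key/value is scattered into
# (the one-hot path in _update_cache above).
output, cache = layer(
    query=from_data, value=from_data, attention_mask=mask,
    cache=cache, decode_loop_step=2)
print(output.shape, cache["value"].shape)  # (3, 4, 8) (3, 4, 2, 2)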
official/nlp/modeling/layers/multi_channel_attention.py  View file @ afd5579f
...
@@ -110,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer):
 class MultiChannelAttention(attention.MultiHeadAttention):
   """Multi-channel Attention layer.

-  Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
-  cross-attention target sequences.
+  Introduced in, [Generating Representative Headlines for News Stories
+  ](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
+  target sequences.
+
+  Call args:
+    query: Query `Tensor` of shape `[B, T, dim]`.
+    value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
+      number of documents.
+    context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
+      is the number of attention heads. Combines multi-channel sources
+      context tensors according to the distribution among channels.
+    key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
+      `value` for both `key` and `value`, which is the most common case.
+    attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
+      to certain positions.
   """

-  def _build_attention(self, qkv_rank):
-    super(MultiChannelAttention, self)._build_attention(qkv_rank)
+  def build_attention(self, rank):
+    super(MultiChannelAttention, self).build_attention(rank)
     self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])

-  def call(self, inputs, attention_mask=None):
-    from_tensor = inputs[0]
-    to_tensor = inputs[1]
-    doc_attention_probs = inputs[2]
+  def call(self,
+           query,
+           value,
+           key=None,
+           context_attention_weights=None,
+           attention_mask=None):
+    if not self._built_from_signature:
+      self._build_from_signature(query, value, key=key)
+    if key is None:
+      key = value

     # Scalar dimensions referenced here:
     # B = batch size (number of stories)
     # A = num_docs (number of docs)
-    # F = `from_tensor` sequence length
-    # T = `to_tensor` sequence length
+    # F = target sequence length
+    # T = source sequence length
     # N = `num_attention_heads`
     # H = `size_per_head`
     # `query_tensor` = [B, F, N ,H]
-    query_tensor = self._query_dense(from_tensor)
+    query_tensor = self._query_dense(query)
     # `key_tensor` = [B, A, T, N, H]
-    key_tensor = self._key_dense(to_tensor)
+    key_tensor = self._key_dense(key)
     # `value_tensor` = [B, A, T, N, H]
-    value_tensor = self._value_dense(to_tensor)
+    value_tensor = self._value_dense(value)

     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
...
@@ -156,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention):
     # `context_layer` = [B, F, N, H]
     context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
                               value_tensor)
-    attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
+    attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
                                  context_layer)
     attention_output = self._output_dense(attention_output)
     return attention_output
official/nlp/modeling/layers/multi_channel_attention_test.py  View file @ afd5579f
...
@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
     mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
     doc_probs = np.random.randint(
         2, size=(3, num_heads, 4, num_docs)).astype(float)
-    outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
+    outputs = attention_layer(
+        query=from_data,
+        value=to_data,
+        context_attention_weights=doc_probs,
+        attention_mask=mask_data)
     self.assertEqual(outputs.shape, (3, 4, 8))
...
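Note: compared with plain cross-attention, the extra `context_attention_weights` input is what mixes the per-document contexts back into one sequence. A standalone sketch with the same data shapes as the test above (the layer hyperparameters and the softmax-normalized weights are illustrative choices, not taken from the test):

import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import multi_channel_attention

num_heads, num_docs = 2, 5
layer = multi_channel_attention.MultiChannelAttention(
    num_heads=num_heads, key_size=2)
query = np.random.rand(3, 4, 8).astype(np.float32)            # [B, T, dim]
value = np.random.rand(3, num_docs, 2, 8).astype(np.float32)  # [B, A, S, dim]
weights = tf.nn.softmax(
    tf.random.uniform((3, num_heads, 4, num_docs)), axis=-1)  # [B, N, T, A]
output = layer(query=query, value=value, context_attention_weights=weights)
print(output.shape)  # (3, 4, 8): channels collapsed back into one sequence.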
official/nlp/modeling/layers/position_embedding.py  View file @ afd5579f
...
@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
         "hidden_size": self._hidden_size,
         "min_timescale": self._min_timescale,
         "max_timescale": self._max_timescale,
-        "length": self._length,
     }
     base_config = super(RelativePositionEmbedding, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
...
official/nlp/modeling/layers/rezero_transformer.py  View file @ afd5579f
...
@@ -213,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
-    attention_inputs = [target_tensor, input_tensor]

-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = target_tensor + self._rezero_a * attention_output
     if self._use_layer_norm:
...
official/nlp/modeling/layers/talking_heads_attention.py  View file @ afd5579f
...
@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
     bias_constraint: Constraint for dense layer kernels.
   """

-  def _build_attention(self, qkv_rank):
+  def build_attention(self, qkv_rank):
     """Builds multi-head dot-product attention computations.

     This function overrides base class to create additional linear projection
...
@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
     Args:
       qkv_rank: the rank of query, key, value tensors after projection.
     """
-    super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
+    super(TalkingHeadsAttention, self).build_attention(qkv_rank)

     # Build an equation:
     # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
...
@@ -103,7 +103,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
         dtype=self.dtype,
         trainable=True)

-  def _compute_attention(self,
+  def compute_attention(self,
                          query_tensor,
                          key_tensor,
                          value_tensor,
...
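Note: with `build_attention` and `compute_attention` now public, subclasses such as `TalkingHeadsAttention` override them instead of private methods. A minimal illustrative subclass (the temperature rescaling is invented for this example, not something in this commit):

import tensorflow as tf
from official.nlp.modeling.layers import attention

class TemperatureAttention(attention.MultiHeadAttention):
  """Example subclass: cools the attention logits before the softmax."""

  def __init__(self, temperature=1.0, **kwargs):
    super().__init__(**kwargs)
    self._temperature = temperature

  def compute_attention(self, query, key, value, attention_mask=None):
    # Rescale the projected queries; the base class then applies the usual
    # dot-product attention, masked softmax, and dropout.
    query = query / self._temperature
    return super().compute_attention(
        query, key, value, attention_mask=attention_mask)

layer = TemperatureAttention(num_heads=4, key_size=16, temperature=2.0)
out = layer(query=tf.random.uniform((2, 10, 32)),
            value=tf.random.uniform((2, 10, 32)))
print(out.shape)  # (2, 10, 32)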
official/nlp/modeling/layers/talking_heads_attention_test.py  View file @ afd5579f
...
@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
     value = tf.keras.Input(shape=(20, 80))
-    output = test_layer([query, value])
+    output = test_layer(query=query, value=value)
     self.assertEqual(output.shape.as_list(), [None] + output_dims)

   def test_non_masked_self_attention(self):
...
@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

   def test_attention_scores(self):
...
@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64, return_attention_scores=True)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output, coef = test_layer([query, query])
+    output, coef = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
     self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
...
@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     query = tf.keras.Input(shape=(4, 8))
     value = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
-    output = test_layer([query, value], mask_tensor)
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)

     # Create a model containing the test layer.
     model = tf.keras.Model([query, value, mask_tensor], output)
...
@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Tests the layer with three inputs: Q, K, V.
     key = tf.keras.Input(shape=(2, 8))
-    output = test_layer([query, value, key], mask_tensor)
+    output = test_layer(
+        query=query, value=value, key=key, attention_mask=mask_tensor)
     model = tf.keras.Model([query, value, key, mask_tensor], output)

     masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
...
@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

   @parameterized.named_parameters(
...
@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Invoke the data with a random set of mask data. This should mask at least
     # one element.
     mask_data = np.random.randint(2, size=mask_shape).astype("bool")
-    output = test_layer([query, value], mask_data)
+    output = test_layer(query=query, value=value, attention_mask=mask_data)

     # Invoke the same data, but with a null mask (where no elements are masked).
     null_mask_data = np.ones(mask_shape)
-    unmasked_output = test_layer([query, value], null_mask_data)
+    unmasked_output = test_layer(
+        query=query, value=value, attention_mask=null_mask_data)
     # Because one data is masked and one is not, the outputs should not be the
     # same.
     self.assertNotAllClose(output, unmasked_output)
...
official/nlp/modeling/layers/transformer.py  View file @ afd5579f
...
@@ -120,7 +120,9 @@ class Transformer(tf.keras.layers.Layer):
         name="self_attention",
         **common_kwargs)
     # pylint: disable=protected-access
-    self._attention_layer.build([input_tensor_shape] * 3)
+    # Temporarily handling for checkpoint compatible changes.
+    self._attention_layer._build_from_signature(
+        query=input_tensor_shape, value=input_tensor_shape)
     self._attention_output_dense = self._attention_layer._output_dense
     # pylint: enable=protected-access
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
...
@@ -202,9 +204,9 @@ class Transformer(tf.keras.layers.Layer):
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
-    attention_inputs = [target_tensor, input_tensor]

-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(target_tensor +
                                                   attention_output)
...
@@ -382,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
           len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
-    self_attention_inputs = [input_tensor, input_tensor]
     self_attention_output, cache = self.self_attention(
-        self_attention_inputs,
+        query=input_tensor,
+        value=input_tensor,
         attention_mask=self_attention_mask,
         cache=cache,
         decode_loop_step=decode_loop_step)
     self_attention_output = self.self_attention_dropout(self_attention_output)
     self_attention_output = self.self_attention_layer_norm(
         input_tensor + self_attention_output)

-    cross_attn_inputs = [self_attention_output, memory]
+    cross_attn_inputs = dict(
+        query=self_attention_output,
+        value=memory,
+        attention_mask=attention_mask)
     if self.multi_channel_cross_attention:
       # Accesses the 5-th input tensor for the doc-attention probabilities.
-      cross_attn_inputs.append(inputs[-1])
-    attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
+      cross_attn_inputs["context_attention_weights"] = inputs[-1]
+    attention_output = self.encdec_attention(**cross_attn_inputs)
    attention_output = self.encdec_attention_dropout(attention_output)
     attention_output = self.encdec_attention_layer_norm(self_attention_output +
                                                         attention_output)
...
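Note: collecting the cross-attention arguments in a `dict` and splatting with `**` lets the decoder pass `context_attention_weights` only when multi-channel attention is enabled, without a second call site. The pattern in isolation (plain Python, with a stand-in function rather than the real layer):

def encdec_attention(query, value, attention_mask=None,
                     context_attention_weights=None):
  # Stand-in for the decoder's cross-attention layer call.
  return (query, value, attention_mask, context_attention_weights)

cross_attn_inputs = dict(query="q", value="memory", attention_mask="mask")
multi_channel_cross_attention = True
if multi_channel_cross_attention:
  # The optional keyword rides along only when the feature is on.
  cross_attn_inputs["context_attention_weights"] = "doc_weights"
output = encdec_attention(**cross_attn_inputs)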
official/nlp/modeling/layers/transformer_scaffold.py  View file @ afd5579f
...
@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
     else:
       input_tensor, attention_mask = (inputs, None)

-    attention_inputs = [input_tensor, input_tensor]

-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=input_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(input_tensor +
                                                   attention_output)
...
official/nlp/modeling/layers/transformer_scaffold_test.py  View file @ afd5579f
...
@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
     super(ValidatedAttentionLayer, self).__init__(**kwargs)
     self.list = call_list

-  def call(self, inputs, attention_mask=None):
+  def call(self, query, value, attention_mask=None):
     self.list.append(True)
     return super(ValidatedAttentionLayer, self).call(
-        inputs, attention_mask=attention_mask)
+        query, value, attention_mask=attention_mask)

   def get_config(self):
     config = super(ValidatedAttentionLayer, self).get_config()
...
official/nlp/modeling/layers/transformer_test.py  View file @ afd5579f
...
@@ -152,7 +152,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     _ = new_layer([input_data, mask_data])
     new_layer.set_weights(test_layer.get_weights())
     new_output_tensor = new_layer([input_data, mask_data])
-    self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
+    self.assertAllClose(
+        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)

   def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
...
official/nlp/modeling/networks/encoder_scaffold_test.py  View file @ afd5579f
...
@@ -323,6 +323,28 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
     self.assertAllEqual(network.get_config(), new_network.get_config())

+class Embeddings(tf.keras.Model):
+
+  def __init__(self, vocab_size, hidden_size):
+    super().__init__()
+    self.inputs = [
+        tf.keras.layers.Input(
+            shape=(None,), dtype=tf.int32, name="input_word_ids"),
+        tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name="input_mask")
+    ]
+    self.attention_mask = layers.SelfAttentionMask()
+    self.embedding_layer = layers.OnDeviceEmbedding(
+        vocab_size=vocab_size,
+        embedding_width=hidden_size,
+        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
+        name="word_embeddings")
+
+  def call(self, inputs):
+    word_ids, mask = inputs
+    word_embeddings = self.embedding_layer(word_ids)
+    return word_embeddings, self.attention_mask([word_embeddings, mask])
+
 @keras_parameterized.run_all_keras_modes
 class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
...
@@ -334,20 +356,7 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
     # Build an embedding network to swap in for the default network. This one
     # will have 2 inputs (mask and word_ids) instead of 3, and won't use
     # positional embeddings.
-    word_ids = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name="input_word_ids")
-    mask = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name="input_mask")
-    embedding_layer = layers.OnDeviceEmbedding(
-        vocab_size=vocab_size,
-        embedding_width=hidden_size,
-        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
-        name="word_embeddings")
-    word_embeddings = embedding_layer(word_ids)
-    attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
-    network = tf.keras.Model([word_ids, mask],
-                             [word_embeddings, attention_mask])
+    network = Embeddings(vocab_size, hidden_size)

     hidden_cfg = {
         "num_attention_heads":
...
@@ -371,8 +380,7 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=0.02),
         hidden_cfg=hidden_cfg,
-        embedding_cls=network,
-        embedding_data=embedding_layer.embeddings)
+        embedding_cls=network)

     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...
@@ -390,11 +398,6 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
     mask_data = np.random.randint(2, size=(batch_size, sequence_length))
     _ = model.predict([word_id_data, mask_data])

-    # Test that we can get the embedding data that we passed to the object. This
-    # is necessary to support standard language model training.
-    self.assertIs(embedding_layer.embeddings,
-                  test_network.get_embedding_table())

   def test_serialize_deserialize(self):
     hidden_size = 32
     sequence_length = 21
...
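Note: the new `Embeddings` model bundles the custom word-embedding sub-network the test previously wired up inline, so `EncoderScaffold` can take a ready-made `embedding_cls` instance whose `call` returns `(embeddings, attention_mask)`. A runnable analogue of that contract using only stock Keras pieces (`TinyEmbeddings` is illustrative; the real test uses `layers.OnDeviceEmbedding` and `layers.SelfAttentionMask`):

import tensorflow as tf

class TinyEmbeddings(tf.keras.Model):
  """Maps (word_ids, mask) to (word_embeddings, 3-D attention_mask)."""

  def __init__(self, vocab_size, hidden_size):
    super().__init__()
    self.embedding_layer = tf.keras.layers.Embedding(vocab_size, hidden_size)

  def call(self, inputs):
    word_ids, mask = inputs
    embeddings = self.embedding_layer(word_ids)
    # Broadcast the 2-D padding mask to [B, T, T], like SelfAttentionMask.
    mask = tf.cast(mask, embeddings.dtype)
    attention_mask = tf.einsum("bt,bs->bts", tf.ones_like(mask), mask)
    return embeddings, attention_mask

net = TinyEmbeddings(vocab_size=100, hidden_size=32)
ids = tf.random.uniform((2, 21), maxval=100, dtype=tf.int32)
mask = tf.ones((2, 21), dtype=tf.int32)
embeddings, attn_mask = net([ids, mask])
print(embeddings.shape, attn_mask.shape)  # (2, 21, 32) (2, 21, 21)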
official/nlp/modeling/ops/__init__.py  0 → 100644  View file @ afd5579f