Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a35e09d2
Unverified
Commit
a35e09d2
authored
Aug 28, 2019
by
Vinh Nguyen
Committed by
GitHub
Aug 28, 2019
Browse files
Merge branch 'master' into amp_resnet50
parents
d5722dcd
1f5a5e9d
Changes
55
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
390 additions
and
196 deletions
+390
-196
official/r1/utils/export_test.py
official/r1/utils/export_test.py
+1
-1
official/r1/wide_deep/movielens_dataset.py
official/r1/wide_deep/movielens_dataset.py
+1
-1
official/recommendation/create_ncf_data.py
official/recommendation/create_ncf_data.py
+2
-2
official/recommendation/ncf_common.py
official/recommendation/ncf_common.py
+3
-1
official/recommendation/ncf_input_pipeline.py
official/recommendation/ncf_input_pipeline.py
+3
-3
official/recommendation/ncf_keras_benchmark.py
official/recommendation/ncf_keras_benchmark.py
+20
-1
official/recommendation/ncf_keras_main.py
official/recommendation/ncf_keras_main.py
+11
-0
official/resnet/ctl/ctl_common.py
official/resnet/ctl/ctl_common.py
+3
-0
official/resnet/ctl/ctl_imagenet_benchmark.py
official/resnet/ctl/ctl_imagenet_benchmark.py
+4
-3
official/resnet/ctl/ctl_imagenet_main.py
official/resnet/ctl/ctl_imagenet_main.py
+32
-16
official/resnet/ctl/ctl_imagenet_test.py
official/resnet/ctl/ctl_imagenet_test.py
+3
-3
official/resnet/keras/__init__.py
official/resnet/keras/__init__.py
+0
-40
official/staging/shakespeare/shakespeare_benchmark.py
official/staging/shakespeare/shakespeare_benchmark.py
+2
-2
official/transformer/model/beam_search.py
official/transformer/model/beam_search.py
+126
-35
official/transformer/transformer_main.py
official/transformer/transformer_main.py
+2
-2
official/transformer/v2/attention_layer.py
official/transformer/v2/attention_layer.py
+45
-28
official/transformer/v2/beam_search.py
official/transformer/v2/beam_search.py
+36
-21
official/transformer/v2/embedding_layer.py
official/transformer/v2/embedding_layer.py
+2
-13
official/transformer/v2/misc.py
official/transformer/v2/misc.py
+23
-0
official/transformer/v2/transformer.py
official/transformer/v2/transformer.py
+71
-24
No files found.
official/utils/export
/export
_test.py
→
official/
r1/
utils/export_test.py
View file @
a35e09d2
...
@@ -20,7 +20,7 @@ from __future__ import print_function
...
@@ -20,7 +20,7 @@ from __future__ import print_function
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.utils
.export
import
export
from
official.
r1.
utils
import
export
class
ExportUtilsTest
(
tf
.
test
.
TestCase
):
class
ExportUtilsTest
(
tf
.
test
.
TestCase
):
...
...
official/r1/wide_deep/movielens_dataset.py
View file @
a35e09d2
...
@@ -29,7 +29,7 @@ import tensorflow as tf
...
@@ -29,7 +29,7 @@ import tensorflow as tf
# pylint: enable=wrong-import-order
# pylint: enable=wrong-import-order
from
official.datasets
import
movielens
from
official.datasets
import
movielens
from
official.utils.data
import
file_io
from
official.
r1.
utils.data
import
file_io
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
...
...
official/recommendation/create_ncf_data.py
View file @
a35e09d2
...
@@ -65,8 +65,8 @@ def prepare_raw_data(flag_obj):
...
@@ -65,8 +65,8 @@ def prepare_raw_data(flag_obj):
data_processing_params
=
{
data_processing_params
=
{
"train_epochs"
:
flag_obj
.
num_train_epochs
,
"train_epochs"
:
flag_obj
.
num_train_epochs
,
"batch_size"
:
flag_obj
.
prebatch_size
,
"batch_size"
:
flag_obj
.
train_
prebatch_size
,
"eval_batch_size"
:
flag_obj
.
prebatch_size
,
"eval_batch_size"
:
flag_obj
.
eval_
prebatch_size
,
"batches_per_step"
:
1
,
"batches_per_step"
:
1
,
"stream_files"
:
True
,
"stream_files"
:
True
,
"num_neg"
:
flag_obj
.
num_negative_samples
,
"num_neg"
:
flag_obj
.
num_negative_samples
,
...
...
official/recommendation/ncf_common.py
View file @
a35e09d2
...
@@ -154,8 +154,10 @@ def define_ncf_flags():
...
@@ -154,8 +154,10 @@ def define_ncf_flags():
intra_op
=
False
,
intra_op
=
False
,
synthetic_data
=
True
,
synthetic_data
=
True
,
max_train_steps
=
False
,
max_train_steps
=
False
,
dtype
=
Fals
e
,
dtype
=
Tru
e
,
all_reduce_alg
=
False
,
all_reduce_alg
=
False
,
loss_scale
=
True
,
dynamic_loss_scale
=
True
,
enable_xla
=
True
,
enable_xla
=
True
,
force_v2_in_keras_compile
=
True
force_v2_in_keras_compile
=
True
)
)
...
...
official/recommendation/ncf_input_pipeline.py
View file @
a35e09d2
...
@@ -21,7 +21,6 @@ from __future__ import print_function
...
@@ -21,7 +21,6 @@ from __future__ import print_function
import
functools
import
functools
# pylint: disable=g-bad-import-order
# pylint: disable=g-bad-import-order
import
numpy
as
np
import
tensorflow.compat.v2
as
tf
import
tensorflow.compat.v2
as
tf
# pylint: enable=g-bad-import-order
# pylint: enable=g-bad-import-order
...
@@ -42,6 +41,9 @@ def create_dataset_from_tf_record_files(input_file_pattern,
...
@@ -42,6 +41,9 @@ def create_dataset_from_tf_record_files(input_file_pattern,
def
make_dataset
(
files_dataset
,
shard_index
):
def
make_dataset
(
files_dataset
,
shard_index
):
"""Returns dataset for sharded tf record files."""
"""Returns dataset for sharded tf record files."""
if
pre_batch_size
!=
batch_size
:
raise
ValueError
(
"Pre-batch ({}) size is not equal to batch "
"size ({})"
.
format
(
pre_batch_size
,
batch_size
))
files_dataset
=
files_dataset
.
shard
(
NUM_SHARDS
,
shard_index
)
files_dataset
=
files_dataset
.
shard
(
NUM_SHARDS
,
shard_index
)
dataset
=
files_dataset
.
interleave
(
tf
.
data
.
TFRecordDataset
)
dataset
=
files_dataset
.
interleave
(
tf
.
data
.
TFRecordDataset
)
decode_fn
=
functools
.
partial
(
decode_fn
=
functools
.
partial
(
...
@@ -50,8 +52,6 @@ def create_dataset_from_tf_record_files(input_file_pattern,
...
@@ -50,8 +52,6 @@ def create_dataset_from_tf_record_files(input_file_pattern,
is_training
=
is_training
)
is_training
=
is_training
)
dataset
=
dataset
.
map
(
dataset
=
dataset
.
map
(
decode_fn
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
decode_fn
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
dataset
=
dataset
.
apply
(
tf
.
data
.
experimental
.
unbatch
())
dataset
=
dataset
.
batch
(
batch_size
,
drop_remainder
=
True
)
return
dataset
return
dataset
dataset
=
tf
.
data
.
Dataset
.
range
(
NUM_SHARDS
)
dataset
=
tf
.
data
.
Dataset
.
range
(
NUM_SHARDS
)
...
...
official/recommendation/ncf_keras_benchmark.py
View file @
a35e09d2
...
@@ -117,7 +117,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
...
@@ -117,7 +117,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
"""
"""
self
.
_run_and_report_benchmark
(
hr_at_10_min
=
0.61
)
self
.
_run_and_report_benchmark
(
hr_at_10_min
=
0.61
)
def
_run_and_report_benchmark
(
self
,
hr_at_10_min
=
0.630
,
hr_at_10_max
=
0.64
0
):
def
_run_and_report_benchmark
(
self
,
hr_at_10_min
=
0.630
,
hr_at_10_max
=
0.64
5
):
"""Run test and report results.
"""Run test and report results.
Note: Target is 0.635, but some runs are below that level. Until we have
Note: Target is 0.635, but some runs are below that level. Until we have
...
@@ -263,6 +263,15 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
...
@@ -263,6 +263,15 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS
.
train_epochs
=
7
FLAGS
.
train_epochs
=
7
self
.
_run_and_report_benchmark_mlperf_like
()
self
.
_run_and_report_benchmark_mlperf_like
()
def
benchmark_1_gpu_ctl_fp16_mlperf_like
(
self
):
"""1 GPU using CTL."""
self
.
_setup
()
FLAGS
.
keras_use_ctl
=
True
FLAGS
.
train_epochs
=
7
FLAGS
.
dtype
=
'fp16'
FLAGS
.
loss_scale
=
8192
self
.
_run_and_report_benchmark_mlperf_like
()
def
benchmark_1_gpu_ctl_run_eagerly_mlperf_like
(
self
):
def
benchmark_1_gpu_ctl_run_eagerly_mlperf_like
(
self
):
"""1 GPU using CTL with eager and distribution strategy."""
"""1 GPU using CTL with eager and distribution strategy."""
self
.
_setup
()
self
.
_setup
()
...
@@ -279,6 +288,16 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
...
@@ -279,6 +288,16 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS
.
train_epochs
=
7
FLAGS
.
train_epochs
=
7
self
.
_run_and_report_benchmark_mlperf_like
()
self
.
_run_and_report_benchmark_mlperf_like
()
def
benchmark_xla_1_gpu_ctl_fp16_mlperf_like
(
self
):
"""1 GPU using CTL with XLA."""
self
.
_setup
()
FLAGS
.
keras_use_ctl
=
True
FLAGS
.
enable_xla
=
True
FLAGS
.
train_epochs
=
7
FLAGS
.
dtype
=
'fp16'
FLAGS
.
loss_scale
=
8192
self
.
_run_and_report_benchmark_mlperf_like
()
def
benchmark_8_gpu_mlperf_like
(
self
):
def
benchmark_8_gpu_mlperf_like
(
self
):
"""8 GPU using keras fit/compile."""
"""8 GPU using keras fit/compile."""
self
.
_setup
()
self
.
_setup
()
...
...
official/recommendation/ncf_keras_main.py
View file @
a35e09d2
...
@@ -42,6 +42,7 @@ from official.utils.logs import mlperf_helper
...
@@ -42,6 +42,7 @@ from official.utils.logs import mlperf_helper
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
model_helpers
from
official.utils.misc
import
model_helpers
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
tpu_lib
from
official.utils.misc
import
tpu_lib
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
@@ -267,6 +268,12 @@ def run_ncf(_):
...
@@ -267,6 +268,12 @@ def run_ncf(_):
beta_1
=
params
[
"beta1"
],
beta_1
=
params
[
"beta1"
],
beta_2
=
params
[
"beta2"
],
beta_2
=
params
[
"beta2"
],
epsilon
=
params
[
"epsilon"
])
epsilon
=
params
[
"epsilon"
])
if
FLAGS
.
dtype
==
"fp16"
:
optimizer
=
\
tf
.
compat
.
v1
.
train
.
experimental
.
enable_mixed_precision_graph_rewrite
(
optimizer
,
loss_scale
=
flags_core
.
get_loss_scale
(
FLAGS
,
default_for_fp16
=
"dynamic"
))
if
params
[
"keras_use_ctl"
]:
if
params
[
"keras_use_ctl"
]:
train_loss
,
eval_results
=
run_ncf_custom_training
(
train_loss
,
eval_results
=
run_ncf_custom_training
(
...
@@ -371,8 +378,12 @@ def run_ncf_custom_training(params,
...
@@ -371,8 +378,12 @@ def run_ncf_custom_training(params,
softmax_logits
,
softmax_logits
,
sample_weight
=
features
[
rconst
.
VALID_POINT_MASK
])
sample_weight
=
features
[
rconst
.
VALID_POINT_MASK
])
loss
*=
(
1.0
/
params
[
"batch_size"
])
loss
*=
(
1.0
/
params
[
"batch_size"
])
if
FLAGS
.
dtype
==
"fp16"
:
loss
=
optimizer
.
get_scaled_loss
(
loss
)
grads
=
tape
.
gradient
(
loss
,
keras_model
.
trainable_variables
)
grads
=
tape
.
gradient
(
loss
,
keras_model
.
trainable_variables
)
if
FLAGS
.
dtype
==
"fp16"
:
grads
=
optimizer
.
get_unscaled_gradients
(
grads
)
# Converting gradients to dense form helps in perf on GPU for NCF
# Converting gradients to dense form helps in perf on GPU for NCF
grads
=
neumf_model
.
sparse_to_dense_grads
(
grads
=
neumf_model
.
sparse_to_dense_grads
(
list
(
zip
(
grads
,
keras_model
.
trainable_variables
)))
list
(
zip
(
grads
,
keras_model
.
trainable_variables
)))
...
...
official/resnet/ctl/ctl_common.py
View file @
a35e09d2
...
@@ -27,3 +27,6 @@ def define_ctl_flags():
...
@@ -27,3 +27,6 @@ def define_ctl_flags():
flags
.
DEFINE_boolean
(
name
=
'use_tf_function'
,
default
=
True
,
flags
.
DEFINE_boolean
(
name
=
'use_tf_function'
,
default
=
True
,
help
=
'Wrap the train and test step inside a '
help
=
'Wrap the train and test step inside a '
'tf.function.'
)
'tf.function.'
)
flags
.
DEFINE_boolean
(
name
=
'single_l2_loss_op'
,
default
=
False
,
help
=
'Calculate L2_loss on concatenated weights, '
'instead of using Keras per-layer L2 loss.'
)
official/resnet/ctl/ctl_imagenet_benchmark.py
View file @
a35e09d2
...
@@ -22,7 +22,7 @@ import time
...
@@ -22,7 +22,7 @@ import time
from
absl
import
flags
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.
resnet.keras
import
keras_
common
from
official.
vision.image_classification
import
common
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.resnet.ctl
import
ctl_common
from
official.resnet.ctl
import
ctl_common
from
official.utils.testing.perfzero_benchmark
import
PerfZeroBenchmark
from
official.utils.testing.perfzero_benchmark
import
PerfZeroBenchmark
...
@@ -118,7 +118,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
...
@@ -118,7 +118,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
flag_methods
=
[
flag_methods
=
[
ctl_common
.
define_ctl_flags
,
ctl_common
.
define_ctl_flags
,
keras_
common
.
define_keras_flags
common
.
define_keras_flags
]
]
self
.
data_dir
=
os
.
path
.
join
(
root_data_dir
,
'imagenet'
)
self
.
data_dir
=
os
.
path
.
join
(
root_data_dir
,
'imagenet'
)
...
@@ -162,7 +162,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
...
@@ -162,7 +162,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
def
__init__
(
self
,
output_dir
=
None
,
default_flags
=
None
):
def
__init__
(
self
,
output_dir
=
None
,
default_flags
=
None
):
flag_methods
=
[
flag_methods
=
[
ctl_common
.
define_ctl_flags
,
ctl_common
.
define_ctl_flags
,
keras_
common
.
define_keras_flags
common
.
define_keras_flags
]
]
super
(
Resnet50CtlBenchmarkBase
,
self
).
__init__
(
super
(
Resnet50CtlBenchmarkBase
,
self
).
__init__
(
...
@@ -215,6 +215,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
...
@@ -215,6 +215,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_1_gpu_eager'
)
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_1_gpu_eager'
)
FLAGS
.
batch_size
=
64
FLAGS
.
batch_size
=
64
FLAGS
.
use_tf_function
=
False
FLAGS
.
use_tf_function
=
False
FLAGS
.
single_l2_loss_op
=
True
self
.
_run_and_report_benchmark
()
self
.
_run_and_report_benchmark
()
def
benchmark_8_gpu
(
self
):
def
benchmark_8_gpu
(
self
):
...
...
official/resnet/ctl/ctl_imagenet_main.py
View file @
a35e09d2
...
@@ -24,10 +24,10 @@ from absl import logging
...
@@ -24,10 +24,10 @@ from absl import logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.resnet.ctl
import
ctl_common
from
official.resnet.ctl
import
ctl_common
from
official.
resnet.keras
import
imagenet_preprocessing
from
official.
vision.image_classification
import
imagenet_preprocessing
from
official.
resnet.keras
import
keras_
common
from
official.
vision.image_classification
import
common
from
official.
resnet.keras
import
keras
_imagenet_main
from
official.
vision.image_classification
import
resnet
_imagenet_main
from
official.
resnet.keras
import
resnet_model
from
official.
vision.image_classification
import
resnet_model
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.logs
import
logger
from
official.utils.logs
import
logger
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
distribution_utils
...
@@ -73,7 +73,7 @@ def get_input_dataset(flags_obj, strategy):
...
@@ -73,7 +73,7 @@ def get_input_dataset(flags_obj, strategy):
"""Returns the test and train input datasets."""
"""Returns the test and train input datasets."""
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
if
flags_obj
.
use_synthetic_data
:
if
flags_obj
.
use_synthetic_data
:
input_fn
=
keras_
common
.
get_synth_input_fn
(
input_fn
=
common
.
get_synth_input_fn
(
height
=
imagenet_preprocessing
.
DEFAULT_IMAGE_SIZE
,
height
=
imagenet_preprocessing
.
DEFAULT_IMAGE_SIZE
,
width
=
imagenet_preprocessing
.
DEFAULT_IMAGE_SIZE
,
width
=
imagenet_preprocessing
.
DEFAULT_IMAGE_SIZE
,
num_channels
=
imagenet_preprocessing
.
NUM_CHANNELS
,
num_channels
=
imagenet_preprocessing
.
NUM_CHANNELS
,
...
@@ -137,6 +137,10 @@ def run(flags_obj):
...
@@ -137,6 +137,10 @@ def run(flags_obj):
Returns:
Returns:
Dictionary of training and eval stats.
Dictionary of training and eval stats.
"""
"""
keras_utils
.
set_session_config
(
enable_eager
=
flags_obj
.
enable_eager
,
enable_xla
=
flags_obj
.
enable_xla
)
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
# TODO(anj-s): Set data_format without using Keras.
# TODO(anj-s): Set data_format without using Keras.
...
@@ -163,10 +167,11 @@ def run(flags_obj):
...
@@ -163,10 +167,11 @@ def run(flags_obj):
with
strategy_scope
:
with
strategy_scope
:
model
=
resnet_model
.
resnet50
(
model
=
resnet_model
.
resnet50
(
num_classes
=
imagenet_preprocessing
.
NUM_CLASSES
,
num_classes
=
imagenet_preprocessing
.
NUM_CLASSES
,
dtype
=
dtype
,
batch_size
=
flags_obj
.
batch_size
)
dtype
=
dtype
,
batch_size
=
flags_obj
.
batch_size
,
use_l2_regularizer
=
not
flags_obj
.
single_l2_loss_op
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
keras_
common
.
BASE_LEARNING_RATE
,
momentum
=
0.9
,
learning_rate
=
common
.
BASE_LEARNING_RATE
,
momentum
=
0.9
,
nesterov
=
True
)
nesterov
=
True
)
training_accuracy
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
training_accuracy
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
...
@@ -175,6 +180,8 @@ def run(flags_obj):
...
@@ -175,6 +180,8 @@ def run(flags_obj):
test_accuracy
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
test_accuracy
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
'test_accuracy'
,
dtype
=
tf
.
float32
)
'test_accuracy'
,
dtype
=
tf
.
float32
)
trainable_variables
=
model
.
trainable_variables
def
train_step
(
train_ds_inputs
):
def
train_step
(
train_ds_inputs
):
"""Training StepFn."""
"""Training StepFn."""
def
step_fn
(
inputs
):
def
step_fn
(
inputs
):
...
@@ -185,13 +192,22 @@ def run(flags_obj):
...
@@ -185,13 +192,22 @@ def run(flags_obj):
prediction_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
prediction_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
logits
)
labels
,
logits
)
loss1
=
tf
.
reduce_sum
(
prediction_loss
)
*
(
1.0
/
flags_obj
.
batch_size
)
loss
=
tf
.
reduce_sum
(
prediction_loss
)
*
(
1.0
/
flags_obj
.
batch_size
)
loss2
=
(
tf
.
reduce_sum
(
model
.
losses
)
/
num_replicas
=
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
)
loss
=
loss1
+
loss2
if
flags_obj
.
single_l2_loss_op
:
filtered_variables
=
[
grads
=
tape
.
gradient
(
loss
,
model
.
trainable_variables
)
tf
.
reshape
(
v
,
(
-
1
,))
optimizer
.
apply_gradients
(
zip
(
grads
,
model
.
trainable_variables
))
for
v
in
trainable_variables
if
'bn'
not
in
v
.
name
]
l2_loss
=
resnet_model
.
L2_WEIGHT_DECAY
*
2
*
tf
.
nn
.
l2_loss
(
tf
.
concat
(
filtered_variables
,
axis
=
0
))
loss
+=
(
l2_loss
/
num_replicas
)
else
:
loss
+=
(
tf
.
reduce_sum
(
model
.
losses
)
/
num_replicas
)
grads
=
tape
.
gradient
(
loss
,
trainable_variables
)
optimizer
.
apply_gradients
(
zip
(
grads
,
trainable_variables
))
training_accuracy
.
update_state
(
labels
,
logits
)
training_accuracy
.
update_state
(
labels
,
logits
)
return
loss
return
loss
...
@@ -232,7 +248,7 @@ def run(flags_obj):
...
@@ -232,7 +248,7 @@ def run(flags_obj):
training_accuracy
.
reset_states
()
training_accuracy
.
reset_states
()
for
step
in
range
(
train_steps
):
for
step
in
range
(
train_steps
):
optimizer
.
lr
=
keras
_imagenet_main
.
learning_rate_schedule
(
optimizer
.
lr
=
resnet
_imagenet_main
.
learning_rate_schedule
(
epoch
,
step
,
train_steps
,
flags_obj
.
batch_size
)
epoch
,
step
,
train_steps
,
flags_obj
.
batch_size
)
time_callback
.
on_batch_begin
(
step
+
epoch
*
train_steps
)
time_callback
.
on_batch_begin
(
step
+
epoch
*
train_steps
)
...
@@ -281,7 +297,7 @@ def main(_):
...
@@ -281,7 +297,7 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
logging
.
set_verbosity
(
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
keras_
common
.
define_keras_flags
()
common
.
define_keras_flags
()
ctl_common
.
define_ctl_flags
()
ctl_common
.
define_ctl_flags
()
flags
.
adopt_module_key_flags
(
keras_common
)
flags
.
adopt_module_key_flags
(
keras_common
)
flags
.
adopt_module_key_flags
(
ctl_common
)
flags
.
adopt_module_key_flags
(
ctl_common
)
...
...
official/resnet/ctl/ctl_imagenet_test.py
View file @
a35e09d2
...
@@ -25,8 +25,8 @@ from tensorflow.python.eager import context
...
@@ -25,8 +25,8 @@ from tensorflow.python.eager import context
from
tensorflow.python.platform
import
googletest
from
tensorflow.python.platform
import
googletest
from
official.resnet.ctl
import
ctl_common
from
official.resnet.ctl
import
ctl_common
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.
resnet.keras
import
imagenet_preprocessing
from
official.
vision.image_classification
import
imagenet_preprocessing
from
official.
resnet.keras
import
keras_
common
from
official.
vision.image_classification
import
common
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.utils.testing
import
integration
...
@@ -49,7 +49,7 @@ class CtlImagenetTest(googletest.TestCase):
...
@@ -49,7 +49,7 @@ class CtlImagenetTest(googletest.TestCase):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
super
(
CtlImagenetTest
,
cls
).
setUpClass
()
super
(
CtlImagenetTest
,
cls
).
setUpClass
()
keras_
common
.
define_keras_flags
()
common
.
define_keras_flags
()
ctl_common
.
define_ctl_flags
()
ctl_common
.
define_ctl_flags
()
def
setUp
(
self
):
def
setUp
(
self
):
...
...
official/resnet/keras/__init__.py
deleted
100644 → 0
View file @
d5722dcd
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared Keras ResNet modules into this module.
The TensorFlow official Keras models are moved under
official/vision/image_classification
In order to be backward compatible with models that directly import its modules,
we import the Keras ResNet modules under official.resnet.keras.
New TF models should not depend on modules directly under this path.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
official.vision.image_classification
import
cifar_preprocessing
from
official.vision.image_classification
import
common
as
keras_common
from
official.vision.image_classification
import
imagenet_preprocessing
from
official.vision.image_classification
import
resnet_cifar_main
as
keras_cifar_main
from
official.vision.image_classification
import
resnet_cifar_model
from
official.vision.image_classification
import
resnet_imagenet_main
as
keras_imagenet_main
from
official.vision.image_classification
import
resnet_model
del
absolute_import
del
division
del
print_function
official/staging/shakespeare/shakespeare_benchmark.py
View file @
a35e09d2
...
@@ -41,8 +41,8 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
...
@@ -41,8 +41,8 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
flag_methods
=
[
shakespeare_main
.
define_flags
])
flag_methods
=
[
shakespeare_main
.
define_flags
])
def
_run_and_report_benchmark
(
self
,
def
_run_and_report_benchmark
(
self
,
top_1_train_min
=
0.9
23
,
top_1_train_min
=
0.9
1
,
top_1_train_max
=
0.9
3
,
top_1_train_max
=
0.9
4
,
warmup
=
1
,
warmup
=
1
,
log_steps
=
100
):
log_steps
=
100
):
"""Report benchmark results by writing to local protobuf file.
"""Report benchmark results by writing to local protobuf file.
...
...
official/transformer/model/beam_search.py
View file @
a35e09d2
...
@@ -79,8 +79,41 @@ class _StateKeys(object):
...
@@ -79,8 +79,41 @@ class _StateKeys(object):
class
SequenceBeamSearch
(
object
):
class
SequenceBeamSearch
(
object
):
"""Implementation of beam search loop."""
"""Implementation of beam search loop."""
def
__init__
(
self
,
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
def
__init__
(
self
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
dtype
=
tf
.
float32
):
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
padded_decode
,
dtype
=
tf
.
float32
):
"""Initialize sequence beam search.
Args:
symbols_to_logits_fn: A function to provide logits, which is the
interface to the Transformer model. The passed in arguments are:
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and the updated cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
updated cache -> A nested dictionary with the same structure as the
input cache.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
batch_size: An integer, the decode batch size.
beam_size: An integer, number of beams for beam search.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum number of steps to decode
a sequence.
eos_id: An integer. ID of end of sentence token.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
"""
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
...
@@ -88,6 +121,7 @@ class SequenceBeamSearch(object):
...
@@ -88,6 +121,7 @@ class SequenceBeamSearch(object):
self
.
alpha
=
alpha
self
.
alpha
=
alpha
self
.
max_decode_length
=
max_decode_length
self
.
max_decode_length
=
max_decode_length
self
.
eos_id
=
eos_id
self
.
eos_id
=
eos_id
self
.
padded_decode
=
padded_decode
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
def
search
(
self
,
initial_ids
,
initial_cache
):
def
search
(
self
,
initial_ids
,
initial_cache
):
...
@@ -140,6 +174,8 @@ class SequenceBeamSearch(object):
...
@@ -140,6 +174,8 @@ class SequenceBeamSearch(object):
# Create alive sequence with shape [batch_size, beam_size, 1]
# Create alive sequence with shape [batch_size, beam_size, 1]
alive_seq
=
_expand_to_beam_size
(
initial_ids
,
self
.
beam_size
)
alive_seq
=
_expand_to_beam_size
(
initial_ids
,
self
.
beam_size
)
alive_seq
=
tf
.
expand_dims
(
alive_seq
,
axis
=
2
)
alive_seq
=
tf
.
expand_dims
(
alive_seq
,
axis
=
2
)
if
self
.
padded_decode
:
alive_seq
=
tf
.
tile
(
alive_seq
,
[
1
,
1
,
self
.
max_decode_length
+
1
])
# Create tensor for storing initial log probabilities.
# Create tensor for storing initial log probabilities.
# Assume initial_ids are prob 1.0
# Assume initial_ids are prob 1.0
...
@@ -178,16 +214,44 @@ class SequenceBeamSearch(object):
...
@@ -178,16 +214,44 @@ class SequenceBeamSearch(object):
# 1) the dimension's value is a tensor that remains the same but may
# 1) the dimension's value is a tensor that remains the same but may
# depend on the input sequence to the model (e.g. batch size).
# depend on the input sequence to the model (e.g. batch size).
# 2) the dimension may have different values on different iterations.
# 2) the dimension may have different values on different iterations.
state_shape_invariants
=
{
if
self
.
padded_decode
:
_StateKeys
.
CUR_INDEX
:
tf
.
TensorShape
([]),
state_shape_invariants
=
{
_StateKeys
.
ALIVE_SEQ
:
tf
.
TensorShape
([
None
,
self
.
beam_size
,
None
]),
_StateKeys
.
CUR_INDEX
:
_StateKeys
.
ALIVE_LOG_PROBS
:
tf
.
TensorShape
([
None
,
self
.
beam_size
]),
tf
.
TensorShape
([]),
_StateKeys
.
ALIVE_CACHE
:
nest
.
map_structure
(
_StateKeys
.
ALIVE_SEQ
:
_get_shape_keep_last_dim
,
alive_cache
),
tf
.
TensorShape
(
_StateKeys
.
FINISHED_SEQ
:
tf
.
TensorShape
([
None
,
self
.
beam_size
,
None
]),
[
self
.
batch_size
,
self
.
beam_size
,
_StateKeys
.
FINISHED_SCORES
:
tf
.
TensorShape
([
None
,
self
.
beam_size
]),
self
.
max_decode_length
+
1
]),
_StateKeys
.
FINISHED_FLAGS
:
tf
.
TensorShape
([
None
,
self
.
beam_size
])
_StateKeys
.
ALIVE_LOG_PROBS
:
}
tf
.
TensorShape
([
self
.
batch_size
,
self
.
beam_size
]),
_StateKeys
.
ALIVE_CACHE
:
nest
.
map_structure
(
_get_shape
,
alive_cache
),
_StateKeys
.
FINISHED_SEQ
:
tf
.
TensorShape
(
[
self
.
batch_size
,
self
.
beam_size
,
self
.
max_decode_length
+
1
]),
_StateKeys
.
FINISHED_SCORES
:
tf
.
TensorShape
([
self
.
batch_size
,
self
.
beam_size
]),
_StateKeys
.
FINISHED_FLAGS
:
tf
.
TensorShape
([
self
.
batch_size
,
self
.
beam_size
])
}
else
:
state_shape_invariants
=
{
_StateKeys
.
CUR_INDEX
:
tf
.
TensorShape
([]),
_StateKeys
.
ALIVE_SEQ
:
tf
.
TensorShape
([
None
,
self
.
beam_size
,
None
]),
_StateKeys
.
ALIVE_LOG_PROBS
:
tf
.
TensorShape
([
None
,
self
.
beam_size
]),
_StateKeys
.
ALIVE_CACHE
:
nest
.
map_structure
(
_get_shape_keep_last_dim
,
alive_cache
),
_StateKeys
.
FINISHED_SEQ
:
tf
.
TensorShape
([
None
,
self
.
beam_size
,
None
]),
_StateKeys
.
FINISHED_SCORES
:
tf
.
TensorShape
([
None
,
self
.
beam_size
]),
_StateKeys
.
FINISHED_FLAGS
:
tf
.
TensorShape
([
None
,
self
.
beam_size
])
}
return
state
,
state_shape_invariants
return
state
,
state_shape_invariants
...
@@ -297,7 +361,12 @@ class SequenceBeamSearch(object):
...
@@ -297,7 +361,12 @@ class SequenceBeamSearch(object):
# Get logits for the next candidate IDs for the alive sequences. Get the new
# Get logits for the next candidate IDs for the alive sequences. Get the new
# cache values at the same time.
# cache values at the same time.
flat_ids
=
_flatten_beam_dim
(
alive_seq
)
# [batch_size * beam_size]
if
self
.
padded_decode
:
flat_ids
=
tf
.
reshape
(
tf
.
slice
(
alive_seq
,
[
0
,
0
,
i
],
[
self
.
batch_size
,
self
.
beam_size
,
1
]),
[
self
.
batch_size
*
self
.
beam_size
,
-
1
])
else
:
flat_ids
=
_flatten_beam_dim
(
alive_seq
)
# [batch_size * beam_size]
flat_cache
=
nest
.
map_structure
(
_flatten_beam_dim
,
alive_cache
)
flat_cache
=
nest
.
map_structure
(
_flatten_beam_dim
,
alive_cache
)
flat_logits
,
flat_cache
=
self
.
symbols_to_logits_fn
(
flat_ids
,
i
,
flat_cache
)
flat_logits
,
flat_cache
=
self
.
symbols_to_logits_fn
(
flat_ids
,
i
,
flat_cache
)
...
@@ -331,8 +400,13 @@ class SequenceBeamSearch(object):
...
@@ -331,8 +400,13 @@ class SequenceBeamSearch(object):
# Append the most probable IDs to the topk sequences
# Append the most probable IDs to the topk sequences
topk_ids
=
topk_indices
%
self
.
vocab_size
topk_ids
=
topk_indices
%
self
.
vocab_size
topk_ids
=
tf
.
expand_dims
(
topk_ids
,
axis
=
2
)
if
self
.
padded_decode
:
topk_seq
=
tf
.
concat
([
topk_seq
,
topk_ids
],
axis
=
2
)
topk_seq
=
tf
.
transpose
(
topk_seq
,
perm
=
[
2
,
0
,
1
])
topk_seq
=
tf
.
tensor_scatter_update
(
topk_seq
,
[
i
+
1
],
topk_ids
)
topk_seq
=
tf
.
transpose
(
topk_seq
,
perm
=
[
1
,
2
,
0
])
else
:
topk_ids
=
tf
.
expand_dims
(
topk_ids
,
axis
=
2
)
topk_seq
=
tf
.
concat
([
topk_seq
,
topk_ids
],
axis
=
2
)
return
topk_seq
,
topk_log_probs
,
new_cache
return
topk_seq
,
topk_log_probs
,
new_cache
def
_get_new_alive_state
(
self
,
new_seq
,
new_log_probs
,
new_cache
):
def
_get_new_alive_state
(
self
,
new_seq
,
new_log_probs
,
new_cache
):
...
@@ -388,9 +462,12 @@ class SequenceBeamSearch(object):
...
@@ -388,9 +462,12 @@ class SequenceBeamSearch(object):
# First append a column of 0-ids to finished_seq to increment the length.
# First append a column of 0-ids to finished_seq to increment the length.
# New shape of finished_seq: [batch_size, beam_size, i + 1]
# New shape of finished_seq: [batch_size, beam_size, i + 1]
finished_seq
=
tf
.
concat
(
if
not
self
.
padded_decode
:
[
finished_seq
,
finished_seq
=
tf
.
concat
([
tf
.
zeros
([
self
.
batch_size
,
self
.
beam_size
,
1
],
tf
.
int32
)],
axis
=
2
)
finished_seq
,
tf
.
zeros
([
self
.
batch_size
,
self
.
beam_size
,
1
],
tf
.
int32
)
],
axis
=
2
)
# Calculate new seq scores from log probabilities.
# Calculate new seq scores from log probabilities.
length_norm
=
_length_normalization
(
self
.
alpha
,
i
+
1
,
dtype
=
self
.
dtype
)
length_norm
=
_length_normalization
(
self
.
alpha
,
i
+
1
,
dtype
=
self
.
dtype
)
...
@@ -420,34 +497,43 @@ class SequenceBeamSearch(object):
...
@@ -420,34 +497,43 @@ class SequenceBeamSearch(object):
def
sequence_beam_search
(
def
sequence_beam_search
(
symbols_to_logits_fn
,
initial_ids
,
initial_cache
,
vocab_size
,
beam_size
,
symbols_to_logits_fn
,
initial_ids
,
initial_cache
,
vocab_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
):
alpha
,
max_decode_length
,
eos_id
,
padded_decode
=
False
):
"""Search for sequence of subtoken ids with the largest probability.
"""Search for sequence of subtoken ids with the largest probability.
Args:
Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as
symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape:
arguments. The passed in arguments will have shape:
ids -> [batch_size * beam_size, index]
ids -> A tensor with shape [batch_size * beam_size, index].
index -> [] (scalar)
index -> A scalar.
cache -> nested dictionary of tensors [batch_size * beam_size, ...]
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return logits and new cache.
The function must return a tuple of logits and new cache:
logits -> [batch * beam_size, vocab_size]
logits -> A tensor with shape [batch * beam_size, vocab_size].
new cache -> same shape/structure as inputted cache
new cache -> A nested dictionary with the same shape/structure as the
initial_ids: Starting ids for each batch item.
inputted cache.
int32 tensor with shape [batch_size]
initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
initial_cache: dict containing starting decoder variables information
each batch item.
vocab_size: int size of tokens
initial_cache: A dictionary, containing starting decoder variables
beam_size: int number of beams
information.
alpha: float defining the strength of length normalization
vocab_size: An integer, the size of the vocabulary, used for topk
max_decode_length: maximum length to decoded sequence
computation.
eos_id: int id of eos token, used to determine when a sequence has finished
beam_size: An integer, the number of beams.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum length of a decoded sequence.
eos_id: An integer, ID of eos token, used to determine when a sequence has
finished.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
Returns:
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
sequence scores [batch_size, beam_size]
"""
"""
batch_size
=
tf
.
shape
(
initial_ids
)[
0
]
batch_size
=
(
initial_ids
.
shape
.
as_list
()[
0
]
if
padded_decode
else
tf
.
shape
(
initial_ids
)[
0
])
sbs
=
SequenceBeamSearch
(
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
sbs
=
SequenceBeamSearch
(
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
)
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
padded_decode
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
...
@@ -502,6 +588,11 @@ def _get_shape_keep_last_dim(tensor):
...
@@ -502,6 +588,11 @@ def _get_shape_keep_last_dim(tensor):
return
tf
.
TensorShape
(
shape_list
)
return
tf
.
TensorShape
(
shape_list
)
def _get_shape(tensor):
  """Return the shape of `tensor` as a `tf.TensorShape`.

  Args:
    tensor: The tensor whose shape is wanted.

  Returns:
    A `tf.TensorShape` built from the result of `_shape_list(tensor)`.
  """
  dims = _shape_list(tensor)
  return tf.TensorShape(dims)
def
_flatten_beam_dim
(
tensor
):
def
_flatten_beam_dim
(
tensor
):
"""Reshapes first two dimensions in to single dimension.
"""Reshapes first two dimensions in to single dimension.
...
...
official/transformer/transformer_main.py
View file @
a35e09d2
...
@@ -32,6 +32,7 @@ from absl import flags
...
@@ -32,6 +32,7 @@ from absl import flags
import
tensorflow
as
tf
import
tensorflow
as
tf
# pylint: enable=g-bad-import-order
# pylint: enable=g-bad-import-order
from
official.r1.utils
import
export
from
official.transformer
import
compute_bleu
from
official.transformer
import
compute_bleu
from
official.transformer
import
translate
from
official.transformer
import
translate
from
official.transformer.model
import
model_params
from
official.transformer.model
import
model_params
...
@@ -41,7 +42,6 @@ from official.transformer.utils import metrics
...
@@ -41,7 +42,6 @@ from official.transformer.utils import metrics
from
official.transformer.utils
import
schedule
from
official.transformer.utils
import
schedule
from
official.transformer.utils
import
tokenizer
from
official.transformer.utils
import
tokenizer
from
official.utils.accelerator
import
tpu
as
tpu_util
from
official.utils.accelerator
import
tpu
as
tpu_util
from
official.utils.export
import
export
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.logs
import
hooks_helper
from
official.utils.logs
import
hooks_helper
from
official.utils.logs
import
logger
from
official.utils.logs
import
logger
...
@@ -56,7 +56,7 @@ PARAMS_MAP = {
...
@@ -56,7 +56,7 @@ PARAMS_MAP = {
DEFAULT_TRAIN_EPOCHS
=
10
DEFAULT_TRAIN_EPOCHS
=
10
INF
=
int
(
1e9
)
INF
=
1000000000
#
1e9
BLEU_DIR
=
"bleu"
BLEU_DIR
=
"bleu"
# Dictionary containing tensors that are logged by the logging hooks. Each item
# Dictionary containing tensors that are logged by the logging hooks. Each item
...
...
official/transformer/v2/attention_layer.py
View file @
a35e09d2
...
@@ -102,51 +102,67 @@ class Attention(tf.keras.layers.Layer):
...
@@ -102,51 +102,67 @@ class Attention(tf.keras.layers.Layer):
x
=
tf
.
transpose
(
x
,
[
0
,
2
,
1
,
3
])
# --> [batch, length, num_heads, depth]
x
=
tf
.
transpose
(
x
,
[
0
,
2
,
1
,
3
])
# --> [batch, length, num_heads, depth]
return
tf
.
reshape
(
x
,
[
batch_size
,
length
,
self
.
hidden_size
])
return
tf
.
reshape
(
x
,
[
batch_size
,
length
,
self
.
hidden_size
])
def
call
(
self
,
x
,
y
,
bias
,
training
,
cache
=
None
):
def
call
(
self
,
x
,
y
,
bias
,
training
,
cache
=
None
,
decode_loop_step
=
None
):
"""Apply attention mechanism to x and y.
"""Apply attention mechanism to x and y.
Args:
Args:
x: a tensor with shape [batch_size, length_x, hidden_size]
x: A tensor with shape [batch_size, length_x, hidden_size].
y: a tensor with shape [batch_size, length_y, hidden_size]
y: A tensor with shape [batch_size, length_y, hidden_size].
bias: attention bias that will be added to the result of the dot product.
bias: A tensor, the attention bias that will be added to the result of the
training: boolean, whether in training mode or not.
dot product.
cache: (Used during prediction) dictionary with tensors containing results
training: A bool, whether in training mode or not.
of previous attentions. The dictionary must have the items:
cache: (Used during prediction) A dictionary with tensors containing
results of previous attentions. The dictionary must have the items:
{"k": tensor with shape [batch_size, i, key_channels],
{"k": tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]}
"v": tensor with shape [batch_size, i, value_channels]}
where i is the current decoded length.
where i is the current decoded length.
decode_loop_step: An integer, step number of the decoding loop. Used only
for autoregressive inference on TPU.
Returns:
Returns:
Attention layer output with shape [batch_size, length_x, hidden_size]
Attention layer output with shape [batch_size, length_x, hidden_size]
"""
"""
# Linearly project the query
(q)
, key
(k)
and value
(v)
using different
# Linearly project the query, key and value using different
learned
#
learned
projections. This is in preparation of splitting them into
# projections. This is in preparation of splitting them into
multiple
#
multiple
heads. Multi-head attention uses multiple queries, keys, and
# heads. Multi-head attention uses multiple queries, keys, and
values
#
values
rather than regular attention (which uses a single q, k, v).
# rather than regular attention (which uses a single query, key, value).
q
=
self
.
q_dense_layer
(
x
)
q
uery
=
self
.
q_dense_layer
(
x
)
k
=
self
.
k_dense_layer
(
y
)
k
ey
=
self
.
k_dense_layer
(
y
)
v
=
self
.
v_dense_layer
(
y
)
v
alue
=
self
.
v_dense_layer
(
y
)
if
cache
is
not
None
:
if
cache
is
not
None
:
# Combine cached keys and values with new keys and values.
# Combine cached keys and values with new keys and values.
k
=
tf
.
concat
([
tf
.
cast
(
cache
[
"k"
],
k
.
dtype
),
k
],
axis
=
1
)
if
decode_loop_step
is
not
None
:
v
=
tf
.
concat
([
tf
.
cast
(
cache
[
"v"
],
k
.
dtype
),
v
],
axis
=
1
)
cache_k_shape
=
cache
[
"k"
].
shape
.
as_list
()
indices
=
tf
.
reshape
(
tf
.
one_hot
(
decode_loop_step
,
cache_k_shape
[
1
],
dtype
=
key
.
dtype
),
[
1
,
cache_k_shape
[
1
],
1
])
key
=
cache
[
"k"
]
+
key
*
indices
cache_v_shape
=
cache
[
"v"
].
shape
.
as_list
()
indices
=
tf
.
reshape
(
tf
.
one_hot
(
decode_loop_step
,
cache_v_shape
[
1
],
dtype
=
value
.
dtype
),
[
1
,
cache_v_shape
[
1
],
1
])
value
=
cache
[
"v"
]
+
value
*
indices
else
:
key
=
tf
.
concat
([
tf
.
cast
(
cache
[
"k"
],
key
.
dtype
),
key
],
axis
=
1
)
value
=
tf
.
concat
([
tf
.
cast
(
cache
[
"v"
],
value
.
dtype
),
value
],
axis
=
1
)
# Update cache
# Update cache
cache
[
"k"
]
=
k
cache
[
"k"
]
=
k
ey
cache
[
"v"
]
=
v
cache
[
"v"
]
=
v
alue
# Split q, k, v into heads.
# Split query, key, value into heads.
q
=
self
.
split_heads
(
q
)
q
uery
=
self
.
split_heads
(
q
uery
)
k
=
self
.
split_heads
(
k
)
k
ey
=
self
.
split_heads
(
k
ey
)
v
=
self
.
split_heads
(
v
)
v
alue
=
self
.
split_heads
(
v
alue
)
# Scale q to prevent the dot product between q and k from growing too large.
# Scale query to prevent the dot product between query and key from growing
# too large.
depth
=
(
self
.
hidden_size
//
self
.
num_heads
)
depth
=
(
self
.
hidden_size
//
self
.
num_heads
)
q
*=
depth
**
-
0.5
q
uery
*=
depth
**
-
0.5
# Calculate dot product attention
# Calculate dot product attention
logits
=
tf
.
matmul
(
q
,
k
,
transpose_b
=
True
)
logits
=
tf
.
matmul
(
q
uery
,
k
ey
,
transpose_b
=
True
)
logits
+=
bias
logits
+=
bias
# Note that softmax internally performs math operations using float32
# Note that softmax internally performs math operations using float32
# for numeric stability. When training with float16, we keep the input
# for numeric stability. When training with float16, we keep the input
...
@@ -154,7 +170,7 @@ class Attention(tf.keras.layers.Layer):
...
@@ -154,7 +170,7 @@ class Attention(tf.keras.layers.Layer):
weights
=
tf
.
nn
.
softmax
(
logits
,
name
=
"attention_weights"
)
weights
=
tf
.
nn
.
softmax
(
logits
,
name
=
"attention_weights"
)
if
training
:
if
training
:
weights
=
tf
.
nn
.
dropout
(
weights
,
rate
=
self
.
attention_dropout
)
weights
=
tf
.
nn
.
dropout
(
weights
,
rate
=
self
.
attention_dropout
)
attention_output
=
tf
.
matmul
(
weights
,
v
)
attention_output
=
tf
.
matmul
(
weights
,
v
alue
)
# Recombine heads --> [batch_size, length, hidden_size]
# Recombine heads --> [batch_size, length, hidden_size]
attention_output
=
self
.
combine_heads
(
attention_output
)
attention_output
=
self
.
combine_heads
(
attention_output
)
...
@@ -167,5 +183,6 @@ class Attention(tf.keras.layers.Layer):
...
@@ -167,5 +183,6 @@ class Attention(tf.keras.layers.Layer):
class
SelfAttention
(
Attention
):
class
SelfAttention
(
Attention
):
"""Multiheaded self-attention layer."""
"""Multiheaded self-attention layer."""
def
call
(
self
,
x
,
bias
,
training
,
cache
=
None
):
def
call
(
self
,
x
,
bias
,
training
,
cache
=
None
,
decode_loop_step
=
None
):
return
super
(
SelfAttention
,
self
).
call
(
x
,
x
,
bias
,
training
,
cache
)
return
super
(
SelfAttention
,
self
).
call
(
x
,
x
,
bias
,
training
,
cache
,
decode_loop_step
)
official/transformer/v2/beam_search.py
View file @
a35e09d2
...
@@ -55,43 +55,58 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
...
@@ -55,43 +55,58 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
return
finished_seq
,
finished_scores
return
finished_seq
,
finished_scores
def
sequence_beam_search
(
def
sequence_beam_search
(
symbols_to_logits_fn
,
symbols_to_logits_fn
,
initial_ids
,
initial_cache
,
vocab_size
,
beam_size
,
initial_ids
,
alpha
,
max_decode_length
,
eos_id
,
dtype
=
"float32"
):
initial_cache
,
vocab_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
padded_decode
=
False
,
dtype
=
"float32"
):
"""Search for sequence of subtoken ids with the largest probability.
"""Search for sequence of subtoken ids with the largest probability.
Args:
Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as
symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape:
arguments. The passed in arguments will have shape:
ids -> [batch_size * beam_size, index]
ids -> A tensor with shape [batch_size * beam_size, index].
index -> [] (scalar)
index -> A scalar.
cache -> nested dictionary of tensors [batch_size * beam_size, ...]
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return logits and new cache.
The function must return a tuple of logits and new cache:
logits -> [batch * beam_size, vocab_size]
logits -> A tensor with shape [batch * beam_size, vocab_size].
new cache -> same shape/structure as inputted cache
new cache -> A nested dictionary with the same shape/structure as the
initial_ids: Starting ids for each batch item.
inputted cache.
int32 tensor with shape [batch_size]
initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
initial_cache: dict containing starting decoder variables information
each batch item.
vocab_size: int size of tokens
initial_cache: A dictionary, containing starting decoder variables
beam_size: int number of beams
information.
alpha: float defining the strength of length normalization
vocab_size: An integer, the size of tokens.
max_decode_length: maximum length to decoded sequence
beam_size: An integer, the number of beams.
eos_id: int id of eos token, used to determine when a sequence has finished,
alpha: A float, defining the strength of length normalization.
dtype: The dtype to use.
max_decode_length: An integer, the maximum length of a decoded sequence.
eos_id: An integer, ID of eos token, used to determine when a sequence has
finished.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
Returns:
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
sequence scores [batch_size, beam_size]
"""
"""
batch_size
=
tf
.
shape
(
initial_ids
)[
0
]
batch_size
=
(
initial_ids
.
shape
.
as_list
()[
0
]
if
padded_decode
else
tf
.
shape
(
initial_ids
)[
0
])
if
misc
.
is_v2
():
if
misc
.
is_v2
():
sbs
=
SequenceBeamSearchV2
(
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
sbs
=
SequenceBeamSearchV2
(
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
dtype
)
padded_decode
,
dtype
)
else
:
else
:
sbs
=
v1
.
SequenceBeamSearch
(
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
sbs
=
v1
.
SequenceBeamSearch
(
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
dtype
)
padded_decode
,
dtype
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
...
...
official/transformer/v2/embedding_layer.py
View file @
a35e09d2
...
@@ -24,24 +24,14 @@ import tensorflow as tf
...
@@ -24,24 +24,14 @@ import tensorflow as tf
class
EmbeddingSharedWeights
(
tf
.
keras
.
layers
.
Layer
):
class
EmbeddingSharedWeights
(
tf
.
keras
.
layers
.
Layer
):
"""Calculates input embeddings and pre-softmax linear with shared weights."""
"""Calculates input embeddings and pre-softmax linear with shared weights."""
def
__init__
(
self
,
vocab_size
,
hidden_size
,
dtype
=
None
):
def
__init__
(
self
,
vocab_size
,
hidden_size
):
"""Specify characteristic parameters of embedding layer.
"""Specify characteristic parameters of embedding layer.
Args:
Args:
vocab_size: Number of tokens in the embedding. (Typically ~32,000)
vocab_size: Number of tokens in the embedding. (Typically ~32,000)
hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
dtype: The dtype of the layer: float16 or float32.
"""
"""
if
dtype
==
tf
.
float16
:
super
(
EmbeddingSharedWeights
,
self
).
__init__
()
# We cannot rely on the global policy of "infer_with_float32_vars", as
# this layer is called on both int64 inputs and floating-point inputs.
# If "infer_with_float32_vars" is used, the dtype will be inferred to be
# int64, which means floating-point inputs would not be casted.
# TODO(b/138859351): Remove this logic once we stop using the deprecated
# "infer_with_float32_vars" policy
dtype
=
tf
.
keras
.
mixed_precision
.
experimental
.
Policy
(
"float16_with_float32_vars"
)
super
(
EmbeddingSharedWeights
,
self
).
__init__
(
dtype
=
dtype
)
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -53,7 +43,6 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
...
@@ -53,7 +43,6 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
self
.
shared_weights
=
self
.
add_weight
(
self
.
shared_weights
=
self
.
add_weight
(
"weights"
,
"weights"
,
shape
=
[
self
.
vocab_size
,
self
.
hidden_size
],
shape
=
[
self
.
vocab_size
,
self
.
hidden_size
],
dtype
=
"float32"
,
initializer
=
tf
.
random_normal_initializer
(
initializer
=
tf
.
random_normal_initializer
(
mean
=
0.
,
stddev
=
self
.
hidden_size
**-
0.5
))
mean
=
0.
,
stddev
=
self
.
hidden_size
**-
0.5
))
super
(
EmbeddingSharedWeights
,
self
).
build
(
input_shape
)
super
(
EmbeddingSharedWeights
,
self
).
build
(
input_shape
)
...
...
official/transformer/v2/misc.py
View file @
a35e09d2
...
@@ -192,6 +192,29 @@ def define_transformer_flags():
...
@@ -192,6 +192,29 @@ def define_transformer_flags():
help
=
flags_core
.
help_wrap
(
help
=
flags_core
.
help_wrap
(
'Whether the model runs in 2VM mode, Headless server and unit test '
'Whether the model runs in 2VM mode, Headless server and unit test '
'all use 1VM config.'
))
'all use 1VM config.'
))
flags
.
DEFINE_integer
(
name
=
'decode_batch_size'
,
default
=
32
,
help
=
flags_core
.
help_wrap
(
'Global batch size used for Transformer autoregressive decoding on '
'TPU.'
))
flags
.
DEFINE_integer
(
name
=
'decode_max_length'
,
default
=
97
,
help
=
flags_core
.
help_wrap
(
'Max sequence length of the decode/eval data. This is used by '
'Transformer autoregressive decoding on TPU to have minimum '
'paddings.'
))
flags
.
DEFINE_bool
(
name
=
'padded_decode'
,
default
=
False
,
help
=
flags_core
.
help_wrap
(
'Whether the autoregressive decoding runs with input data padded to '
'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
'set due the static shape requirement. Although CPU/GPU could also '
'use padded_decode, it has not been tested. In addition, this method '
'will introduce unnecessary overheads which grow quadratically with '
'the max sequence length.'
))
flags_core
.
set_defaults
(
data_dir
=
'/tmp/translate_ende'
,
flags_core
.
set_defaults
(
data_dir
=
'/tmp/translate_ende'
,
model_dir
=
'/tmp/transformer_model'
,
model_dir
=
'/tmp/transformer_model'
,
...
...
official/transformer/v2/transformer.py
View file @
a35e09d2
...
@@ -49,8 +49,10 @@ def create_model(params, is_train):
...
@@ -49,8 +49,10 @@ def create_model(params, is_train):
label_smoothing
=
params
[
"label_smoothing"
]
label_smoothing
=
params
[
"label_smoothing"
]
if
params
[
"enable_metrics_in_training"
]:
if
params
[
"enable_metrics_in_training"
]:
logits
=
metrics
.
MetricLayer
(
vocab_size
)([
logits
,
targets
])
logits
=
metrics
.
MetricLayer
(
vocab_size
)([
logits
,
targets
])
logits
=
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
x
,
name
=
"logits"
)(
logits
)
logits
=
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
x
,
name
=
"logits"
,
dtype
=
tf
.
float32
)(
logits
)
model
=
tf
.
keras
.
Model
([
inputs
,
targets
],
logits
)
model
=
tf
.
keras
.
Model
([
inputs
,
targets
],
logits
)
# TODO(reedwm): Can we do this loss in float16 instead of float32?
loss
=
metrics
.
transformer_loss
(
loss
=
metrics
.
transformer_loss
(
logits
,
targets
,
label_smoothing
,
vocab_size
)
logits
,
targets
,
label_smoothing
,
vocab_size
)
model
.
add_loss
(
loss
)
model
.
add_loss
(
loss
)
...
@@ -85,7 +87,7 @@ class Transformer(tf.keras.Model):
...
@@ -85,7 +87,7 @@ class Transformer(tf.keras.Model):
super
(
Transformer
,
self
).
__init__
(
name
=
name
)
super
(
Transformer
,
self
).
__init__
(
name
=
name
)
self
.
params
=
params
self
.
params
=
params
self
.
embedding_softmax_layer
=
embedding_layer
.
EmbeddingSharedWeights
(
self
.
embedding_softmax_layer
=
embedding_layer
.
EmbeddingSharedWeights
(
params
[
"vocab_size"
],
params
[
"hidden_size"
]
,
dtype
=
params
[
"dtype"
]
)
params
[
"vocab_size"
],
params
[
"hidden_size"
])
self
.
encoder_stack
=
EncoderStack
(
params
)
self
.
encoder_stack
=
EncoderStack
(
params
)
self
.
decoder_stack
=
DecoderStack
(
params
)
self
.
decoder_stack
=
DecoderStack
(
params
)
...
@@ -112,11 +114,22 @@ class Transformer(tf.keras.Model):
...
@@ -112,11 +114,22 @@ class Transformer(tf.keras.Model):
outputs: [batch_size, decoded length]
outputs: [batch_size, decoded length]
scores: [batch_size, float]}
scores: [batch_size, float]}
Even when float16 is used, the output tensor(s) are always float32.
Even when float16 is used, the output tensor(s) are always float32.
Raises:
NotImplementedError: If padded decoding is requested on CPU/GPUs.
"""
"""
if
len
(
inputs
)
==
2
:
if
len
(
inputs
)
==
2
:
inputs
,
targets
=
inputs
[
0
],
inputs
[
1
]
inputs
,
targets
=
inputs
[
0
],
inputs
[
1
]
else
:
else
:
inputs
,
targets
=
inputs
[
0
],
None
inputs
,
targets
=
inputs
[
0
],
None
if
self
.
params
[
"padded_decode"
]:
if
not
self
.
params
[
"num_replicas"
]:
raise
NotImplementedError
(
"Padded decoding on CPU/GPUs is not supported."
)
decode_batch_size
=
int
(
self
.
params
[
"decode_batch_size"
]
/
self
.
params
[
"num_replicas"
])
inputs
=
tf
.
reshape
(
inputs
,
[
decode_batch_size
,
self
.
params
[
"decode_max_length"
]])
# Variance scaling is used here because it seems to work in many problems.
# Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well.
# Other reasonable initializers may also work just as well.
...
@@ -225,13 +238,14 @@ class Transformer(tf.keras.Model):
...
@@ -225,13 +238,14 @@ class Transformer(tf.keras.Model):
decoder_self_attention_bias
=
model_utils
.
get_decoder_self_attention_bias
(
decoder_self_attention_bias
=
model_utils
.
get_decoder_self_attention_bias
(
max_decode_length
,
dtype
=
self
.
params
[
"dtype"
])
max_decode_length
,
dtype
=
self
.
params
[
"dtype"
])
# TODO(b/139770046): Refactor code with better naming of i.
def
symbols_to_logits_fn
(
ids
,
i
,
cache
):
def
symbols_to_logits_fn
(
ids
,
i
,
cache
):
"""Generate logits for next potential IDs.
"""Generate logits for next potential IDs.
Args:
Args:
ids: Current decoded sequences. int tensor with shape [batch_size *
ids: Current decoded sequences. int tensor with shape [batch_size *
beam_size, i + 1]
beam_size, i + 1]
.
i: Loop index
i: Loop index
.
cache: dictionary of values storing the encoder output, encoder-decoder
cache: dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
attention bias, and previous decoder attention values.
...
@@ -245,16 +259,29 @@ class Transformer(tf.keras.Model):
...
@@ -245,16 +259,29 @@ class Transformer(tf.keras.Model):
# Preprocess decoder input by getting embeddings and adding timing signal.
# Preprocess decoder input by getting embeddings and adding timing signal.
decoder_input
=
self
.
embedding_softmax_layer
(
decoder_input
)
decoder_input
=
self
.
embedding_softmax_layer
(
decoder_input
)
decoder_input
+=
timing_signal
[
i
:
i
+
1
]
self_attention_bias
=
decoder_self_attention_bias
[:,
:,
i
:
i
+
1
,
:
i
+
1
]
if
self
.
params
[
"padded_decode"
]:
timing_signal_shape
=
timing_signal
.
shape
.
as_list
()
decoder_input
+=
tf
.
slice
(
timing_signal
,
[
i
,
0
],
[
1
,
timing_signal_shape
[
1
]])
bias_shape
=
decoder_self_attention_bias
.
shape
.
as_list
()
self_attention_bias
=
tf
.
slice
(
decoder_self_attention_bias
,
[
0
,
0
,
i
,
0
],
[
bias_shape
[
0
],
bias_shape
[
1
],
1
,
bias_shape
[
3
]])
else
:
decoder_input
+=
timing_signal
[
i
:
i
+
1
]
self_attention_bias
=
decoder_self_attention_bias
[:,
:,
i
:
i
+
1
,
:
i
+
1
]
decoder_outputs
=
self
.
decoder_stack
(
decoder_outputs
=
self
.
decoder_stack
(
decoder_input
,
decoder_input
,
cache
.
get
(
"encoder_outputs"
),
cache
.
get
(
"encoder_outputs"
),
self_attention_bias
,
self_attention_bias
,
cache
.
get
(
"encoder_decoder_attention_bias"
),
cache
.
get
(
"encoder_decoder_attention_bias"
),
training
=
training
,
training
=
training
,
cache
=
cache
)
cache
=
cache
,
decode_loop_step
=
i
if
self
.
params
[
"padded_decode"
]
else
None
)
logits
=
self
.
embedding_softmax_layer
(
decoder_outputs
,
mode
=
"linear"
)
logits
=
self
.
embedding_softmax_layer
(
decoder_outputs
,
mode
=
"linear"
)
logits
=
tf
.
squeeze
(
logits
,
axis
=
[
1
])
logits
=
tf
.
squeeze
(
logits
,
axis
=
[
1
])
return
logits
,
cache
return
logits
,
cache
...
@@ -263,8 +290,12 @@ class Transformer(tf.keras.Model):
...
@@ -263,8 +290,12 @@ class Transformer(tf.keras.Model):
def
predict
(
self
,
encoder_outputs
,
encoder_decoder_attention_bias
,
training
):
def
predict
(
self
,
encoder_outputs
,
encoder_decoder_attention_bias
,
training
):
"""Return predicted sequence."""
"""Return predicted sequence."""
batch_size
=
tf
.
shape
(
encoder_outputs
)[
0
]
if
self
.
params
[
"padded_decode"
]:
input_length
=
tf
.
shape
(
encoder_outputs
)[
1
]
batch_size
=
encoder_outputs
.
shape
.
as_list
()[
0
]
input_length
=
encoder_outputs
.
shape
.
as_list
()[
1
]
else
:
batch_size
=
tf
.
shape
(
encoder_outputs
)[
0
]
input_length
=
tf
.
shape
(
encoder_outputs
)[
1
]
max_decode_length
=
input_length
+
self
.
params
[
"extra_decode_length"
]
max_decode_length
=
input_length
+
self
.
params
[
"extra_decode_length"
]
encoder_decoder_attention_bias
=
tf
.
cast
(
encoder_decoder_attention_bias
,
encoder_decoder_attention_bias
=
tf
.
cast
(
encoder_decoder_attention_bias
,
self
.
params
[
"dtype"
])
self
.
params
[
"dtype"
])
...
@@ -277,12 +308,20 @@ class Transformer(tf.keras.Model):
...
@@ -277,12 +308,20 @@ class Transformer(tf.keras.Model):
# Create cache storing decoder attention values for each layer.
# Create cache storing decoder attention values for each layer.
# pylint: disable=g-complex-comprehension
# pylint: disable=g-complex-comprehension
init_decode_length
=
(
max_decode_length
if
self
.
params
[
"padded_decode"
]
else
0
)
cache
=
{
cache
=
{
"layer_%d"
%
layer
:
{
"layer_%d"
%
layer
:
{
"k"
:
tf
.
zeros
([
batch_size
,
0
,
self
.
params
[
"hidden_size"
]],
"k"
:
dtype
=
self
.
params
[
"dtype"
]),
tf
.
zeros
([
"v"
:
tf
.
zeros
([
batch_size
,
0
,
self
.
params
[
"hidden_size"
]],
batch_size
,
init_decode_length
,
self
.
params
[
"hidden_size"
]
dtype
=
self
.
params
[
"dtype"
])
],
dtype
=
self
.
params
[
"dtype"
]),
"v"
:
tf
.
zeros
([
batch_size
,
init_decode_length
,
self
.
params
[
"hidden_size"
]
],
dtype
=
self
.
params
[
"dtype"
])
}
for
layer
in
range
(
self
.
params
[
"num_hidden_layers"
])
}
for
layer
in
range
(
self
.
params
[
"num_hidden_layers"
])
}
}
# pylint: enable=g-complex-comprehension
# pylint: enable=g-complex-comprehension
...
@@ -301,6 +340,7 @@ class Transformer(tf.keras.Model):
...
@@ -301,6 +340,7 @@ class Transformer(tf.keras.Model):
alpha
=
self
.
params
[
"alpha"
],
alpha
=
self
.
params
[
"alpha"
],
max_decode_length
=
max_decode_length
,
max_decode_length
=
max_decode_length
,
eos_id
=
EOS_ID
,
eos_id
=
EOS_ID
,
padded_decode
=
self
.
params
[
"padded_decode"
],
dtype
=
self
.
params
[
"dtype"
])
dtype
=
self
.
params
[
"dtype"
])
# Get the top sequence for each batch element
# Get the top sequence for each batch element
...
@@ -505,22 +545,28 @@ class DecoderStack(tf.keras.layers.Layer):
...
@@ -505,22 +545,28 @@ class DecoderStack(tf.keras.layers.Layer):
decoder_self_attention_bias
,
decoder_self_attention_bias
,
attention_bias
,
attention_bias
,
training
,
training
,
cache
=
None
):
cache
=
None
,
decode_loop_step
=
None
):
"""Return the output of the decoder layer stacks.
"""Return the output of the decoder layer stacks.
Args:
Args:
decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
decoder_inputs: A tensor with shape
encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
[batch_size, target_length, hidden_size].
decoder_self_attention_bias: bias for decoder self-attention layer. [1, 1,
encoder_outputs: A tensor with shape
target_len, target_length]
[batch_size, input_length, hidden_size]
attention_bias: bias for encoder-decoder attention layer. [batch_size, 1,
decoder_self_attention_bias: A tensor with shape
1, input_length]
[1, 1, target_len, target_length], the bias for decoder self-attention
training: boolean, whether in training mode or not.
layer.
attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
the bias for encoder-decoder attention layer.
training: A bool, whether in training mode or not.
cache: (Used for fast decoding) A nested dictionary storing previous
cache: (Used for fast decoding) A nested dictionary storing previous
decoder self-attention values. The items are:
decoder self-attention values. The items are:
{layer_n: {"k": tensor with shape [batch_size, i, key_channels],
{layer_n: {"k":
A
tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]},
"v":
A
tensor with shape [batch_size, i, value_channels]},
...}
...}
decode_loop_step: An integer, the step number of the decoding loop. Used
only for autoregressive inference on TPU.
Returns:
Returns:
Output of decoder layer stack.
Output of decoder layer stack.
...
@@ -540,7 +586,8 @@ class DecoderStack(tf.keras.layers.Layer):
...
@@ -540,7 +586,8 @@ class DecoderStack(tf.keras.layers.Layer):
decoder_inputs
,
decoder_inputs
,
decoder_self_attention_bias
,
decoder_self_attention_bias
,
training
=
training
,
training
=
training
,
cache
=
layer_cache
)
cache
=
layer_cache
,
decode_loop_step
=
decode_loop_step
)
with
tf
.
name_scope
(
"encdec_attention"
):
with
tf
.
name_scope
(
"encdec_attention"
):
decoder_inputs
=
enc_dec_attention_layer
(
decoder_inputs
=
enc_dec_attention_layer
(
decoder_inputs
,
decoder_inputs
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment