ModelZoo / ResNet50_tensorflow / Commits / a35e09d2

Commit a35e09d2 (unverified)
Merge branch 'master' into amp_resnet50
Authored Aug 28, 2019 by Vinh Nguyen; committed by GitHub on Aug 28, 2019.
Parents: d5722dcd, 1f5a5e9d
Changes: 55

Showing 20 changed files with 390 additions and 196 deletions (+390, -196).
official/r1/utils/export_test.py                          +1    -1
official/r1/wide_deep/movielens_dataset.py                +1    -1
official/recommendation/create_ncf_data.py                +2    -2
official/recommendation/ncf_common.py                     +3    -1
official/recommendation/ncf_input_pipeline.py             +3    -3
official/recommendation/ncf_keras_benchmark.py            +20   -1
official/recommendation/ncf_keras_main.py                 +11   -0
official/resnet/ctl/ctl_common.py                         +3    -0
official/resnet/ctl/ctl_imagenet_benchmark.py             +4    -3
official/resnet/ctl/ctl_imagenet_main.py                  +32   -16
official/resnet/ctl/ctl_imagenet_test.py                  +3    -3
official/resnet/keras/__init__.py                         +0    -40
official/staging/shakespeare/shakespeare_benchmark.py     +2    -2
official/transformer/model/beam_search.py                 +126  -35
official/transformer/transformer_main.py                  +2    -2
official/transformer/v2/attention_layer.py                +45   -28
official/transformer/v2/beam_search.py                    +36   -21
official/transformer/v2/embedding_layer.py                +2    -13
official/transformer/v2/misc.py                           +23   -0
official/transformer/v2/transformer.py                    +71   -24
official/utils/export/export_test.py → official/r1/utils/export_test.py

@@ -20,7 +20,7 @@ from __future__ import print_function

 import tensorflow as tf  # pylint: disable=g-bad-import-order

-from official.utils.export import export
+from official.r1.utils import export


 class ExportUtilsTest(tf.test.TestCase):
official/r1/wide_deep/movielens_dataset.py

@@ -29,7 +29,7 @@ import tensorflow as tf
 # pylint: enable=wrong-import-order

 from official.datasets import movielens
-from official.utils.data import file_io
+from official.r1.utils.data import file_io
 from official.utils.flags import core as flags_core
official/recommendation/create_ncf_data.py

@@ -65,8 +65,8 @@ def prepare_raw_data(flag_obj):
   data_processing_params = {
       "train_epochs": flag_obj.num_train_epochs,
-      "batch_size": flag_obj.prebatch_size,
-      "eval_batch_size": flag_obj.prebatch_size,
+      "batch_size": flag_obj.train_prebatch_size,
+      "eval_batch_size": flag_obj.eval_prebatch_size,
       "batches_per_step": 1,
       "stream_files": True,
       "num_neg": flag_obj.num_negative_samples,
official/recommendation/ncf_common.py

@@ -154,8 +154,10 @@ def define_ncf_flags():
       intra_op=False,
       synthetic_data=True,
       max_train_steps=False,
-      dtype=False,
+      dtype=True,
       all_reduce_alg=False,
+      loss_scale=True,
+      dynamic_loss_scale=True,
       enable_xla=True,
       force_v2_in_keras_compile=True)
official/recommendation/ncf_input_pipeline.py

@@ -21,7 +21,6 @@ from __future__ import print_function
 import functools

 # pylint: disable=g-bad-import-order
-import numpy as np
 import tensorflow.compat.v2 as tf
 # pylint: enable=g-bad-import-order

@@ -42,6 +41,9 @@ def create_dataset_from_tf_record_files(input_file_pattern,

   def make_dataset(files_dataset, shard_index):
     """Returns dataset for sharded tf record files."""
+    if pre_batch_size != batch_size:
+      raise ValueError("Pre-batch ({}) size is not equal to batch "
+                       "size ({})".format(pre_batch_size, batch_size))
     files_dataset = files_dataset.shard(NUM_SHARDS, shard_index)
     dataset = files_dataset.interleave(tf.data.TFRecordDataset)
     decode_fn = functools.partial(

@@ -50,8 +52,6 @@ def create_dataset_from_tf_record_files(input_file_pattern,
         is_training=is_training)
     dataset = dataset.map(
         decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-    dataset = dataset.apply(tf.data.experimental.unbatch())
-    dataset = dataset.batch(batch_size, drop_remainder=True)
     return dataset

   dataset = tf.data.Dataset.range(NUM_SHARDS)
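For readers less familiar with the tf.data idiom used in make_dataset() above, here is a minimal, self-contained sketch of the shard → interleave → decode pattern that the revised function keeps (batching now happens outside the per-shard function). The file names, fake records, NUM_SHARDS and BATCH_SIZE below are illustrative stand-ins, not the module's real values.

import tensorflow as tf

NUM_SHARDS = 4
BATCH_SIZE = 2

# Stand-in for a dataset of TFRecord file names.
files = tf.data.Dataset.from_tensor_slices(["file-%d" % i for i in range(8)])

def records_from_file(filename):
  # Stand-in for tf.data.TFRecordDataset(filename): emit two fake records.
  return tf.data.Dataset.from_tensor_slices(
      [filename + "/rec-0", filename + "/rec-1"])

def make_dataset(files_dataset, shard_index):
  """One shard's records: every NUM_SHARDS-th file, interleaved and decoded."""
  shard = files_dataset.shard(NUM_SHARDS, shard_index)
  dataset = shard.interleave(records_from_file)
  # The real code maps a functools.partial decode_fn here with
  # num_parallel_calls=tf.data.experimental.AUTOTUNE.
  return dataset.map(lambda r: tf.strings.length(r))

# Batching is applied outside make_dataset, as in the revised pipeline.
for batch in make_dataset(files, shard_index=0).batch(BATCH_SIZE).take(1):
  print(batch.numpy())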
official/recommendation/ncf_keras_benchmark.py

@@ -117,7 +117,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     """
     self._run_and_report_benchmark(hr_at_10_min=0.61)

-  def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.640):
+  def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.645):
     """Run test and report results.

     Note: Target is 0.635, but some runs are below that level. Until we have

@@ -263,6 +263,15 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.train_epochs = 7
     self._run_and_report_benchmark_mlperf_like()

+  def benchmark_1_gpu_ctl_fp16_mlperf_like(self):
+    """1 GPU using CTL."""
+    self._setup()
+    FLAGS.keras_use_ctl = True
+    FLAGS.train_epochs = 7
+    FLAGS.dtype = 'fp16'
+    FLAGS.loss_scale = 8192
+    self._run_and_report_benchmark_mlperf_like()
+
   def benchmark_1_gpu_ctl_run_eagerly_mlperf_like(self):
     """1 GPU using CTL with eager and distribution strategy."""
     self._setup()

@@ -279,6 +288,16 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.train_epochs = 7
     self._run_and_report_benchmark_mlperf_like()

+  def benchmark_xla_1_gpu_ctl_fp16_mlperf_like(self):
+    """1 GPU using CTL with XLA."""
+    self._setup()
+    FLAGS.keras_use_ctl = True
+    FLAGS.enable_xla = True
+    FLAGS.train_epochs = 7
+    FLAGS.dtype = 'fp16'
+    FLAGS.loss_scale = 8192
+    self._run_and_report_benchmark_mlperf_like()
+
   def benchmark_8_gpu_mlperf_like(self):
     """8 GPU using keras fit/compile."""
     self._setup()
official/recommendation/ncf_keras_main.py

@@ -42,6 +42,7 @@ from official.utils.logs import mlperf_helper
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
+from official.utils.flags import core as flags_core
 from official.utils.misc import tpu_lib

 FLAGS = flags.FLAGS

@@ -267,6 +268,12 @@ def run_ncf(_):
         beta_1=params["beta1"],
         beta_2=params["beta2"],
         epsilon=params["epsilon"])
+    if FLAGS.dtype == "fp16":
+      optimizer = \
+          tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
+              optimizer,
+              loss_scale=flags_core.get_loss_scale(FLAGS,
+                                                   default_for_fp16="dynamic"))

   if params["keras_use_ctl"]:
     train_loss, eval_results = run_ncf_custom_training(

@@ -371,8 +378,12 @@ def run_ncf_custom_training(params,
             softmax_logits,
             sample_weight=features[rconst.VALID_POINT_MASK])
         loss *= (1.0 / params["batch_size"])
+        if FLAGS.dtype == "fp16":
+          loss = optimizer.get_scaled_loss(loss)

       grads = tape.gradient(loss, keras_model.trainable_variables)
+      if FLAGS.dtype == "fp16":
+        grads = optimizer.get_unscaled_gradients(grads)
       # Converting gradients to dense form helps in perf on GPU for NCF
       grads = neumf_model.sparse_to_dense_grads(
           list(zip(grads, keras_model.trainable_variables)))
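The custom training loop above relies on loss scaling when --dtype=fp16: the loss is multiplied by a (possibly dynamic) scale before gradients are taken, and the gradients are divided by the same scale before they are applied, so small fp16 gradients do not underflow. Below is a minimal sketch of that pattern using a Keras LossScaleOptimizer rather than the graph-rewrite wrapper the diff installs; the "experimental" namespace matches TF releases of this era (later releases drop it), and the tiny model and data are illustrative only.

import tensorflow as tf

opt = tf.keras.optimizers.Adam()
# Wraps the optimizer with dynamic loss scaling (TF 2.0-2.3 era API).
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    opt, loss_scale="dynamic")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])

with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(model(x) - y))
  scaled_loss = opt.get_scaled_loss(loss)           # multiply by the loss scale

scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
grads = opt.get_unscaled_gradients(scaled_grads)    # divide the scale back out
opt.apply_gradients(zip(grads, model.trainable_variables))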
official/resnet/ctl/ctl_common.py

@@ -27,3 +27,6 @@ def define_ctl_flags():
   flags.DEFINE_boolean(name='use_tf_function', default=True,
                        help='Wrap the train and test step inside a '
                        'tf.function.')
+  flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
+                       help='Calculate L2_loss on concatenated weights, '
+                       'instead of using Keras per-layer L2 loss.')
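For context, the new single_l2_loss_op switch is an ordinary absl flag. A small, self-contained sketch of how such a flag is defined and then read back at runtime (the flag name is reused for illustration; the rest is not this module's code):

from absl import app, flags

flags.DEFINE_boolean(
    name='single_l2_loss_op', default=False,
    help='Calculate L2 loss on concatenated weights instead of per-layer.')
FLAGS = flags.FLAGS

def main(_):
  # FLAGS values are available only after app.run() has parsed argv.
  print('single_l2_loss_op =', FLAGS.single_l2_loss_op)

if __name__ == '__main__':
  app.run(main)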
official/resnet/ctl/ctl_imagenet_benchmark.py

@@ -22,7 +22,7 @@ import time
 from absl import flags
 import tensorflow as tf

-from official.resnet.keras import keras_common
+from official.vision.image_classification import common
 from official.resnet.ctl import ctl_imagenet_main
 from official.resnet.ctl import ctl_common
 from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark

@@ -118,7 +118,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     flag_methods = [ctl_common.define_ctl_flags,
-                    keras_common.define_keras_flags]
+                    common.define_keras_flags]

     self.data_dir = os.path.join(root_data_dir, 'imagenet')

@@ -162,7 +162,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
   def __init__(self, output_dir=None, default_flags=None):
     flag_methods = [ctl_common.define_ctl_flags,
-                    keras_common.define_keras_flags]
+                    common.define_keras_flags]

     super(Resnet50CtlBenchmarkBase, self).__init__(

@@ -215,6 +215,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
     FLAGS.batch_size = 64
     FLAGS.use_tf_function = False
+    FLAGS.single_l2_loss_op = True
     self._run_and_report_benchmark()

   def benchmark_8_gpu(self):
official/resnet/ctl/ctl_imagenet_main.py

@@ -24,10 +24,10 @@ from absl import logging
 import tensorflow as tf

 from official.resnet.ctl import ctl_common
-from official.resnet.keras import imagenet_preprocessing
-from official.resnet.keras import keras_common
-from official.resnet.keras import keras_imagenet_main
-from official.resnet.keras import resnet_model
+from official.vision.image_classification import imagenet_preprocessing
+from official.vision.image_classification import common
+from official.vision.image_classification import resnet_imagenet_main
+from official.vision.image_classification import resnet_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils

@@ -73,7 +73,7 @@ def get_input_dataset(flags_obj, strategy):
   """Returns the test and train input datasets."""
   dtype = flags_core.get_tf_dtype(flags_obj)
   if flags_obj.use_synthetic_data:
-    input_fn = keras_common.get_synth_input_fn(
+    input_fn = common.get_synth_input_fn(
         height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
         width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
         num_channels=imagenet_preprocessing.NUM_CHANNELS,

@@ -137,6 +137,10 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
+  keras_utils.set_session_config(
+      enable_eager=flags_obj.enable_eager,
+      enable_xla=flags_obj.enable_xla)
+
   dtype = flags_core.get_tf_dtype(flags_obj)

   # TODO(anj-s): Set data_format without using Keras.

@@ -163,10 +167,11 @@ def run(flags_obj):
   with strategy_scope:
     model = resnet_model.resnet50(
         num_classes=imagenet_preprocessing.NUM_CLASSES,
-        dtype=dtype,
-        batch_size=flags_obj.batch_size)
+        dtype=dtype,
+        batch_size=flags_obj.batch_size,
+        use_l2_regularizer=not flags_obj.single_l2_loss_op)

     optimizer = tf.keras.optimizers.SGD(
-        learning_rate=keras_common.BASE_LEARNING_RATE, momentum=0.9,
+        learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
         nesterov=True)

     training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(

@@ -175,6 +180,8 @@ def run(flags_obj):
     test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
         'test_accuracy', dtype=tf.float32)

+    trainable_variables = model.trainable_variables
+
     def train_step(train_ds_inputs):
       """Training StepFn."""
       def step_fn(inputs):

@@ -185,13 +192,22 @@ def run(flags_obj):
           prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
               labels, logits)
-          loss1 = tf.reduce_sum(prediction_loss) * (1.0 / flags_obj.batch_size)
-          loss2 = (tf.reduce_sum(model.losses) /
-                   tf.distribute.get_strategy().num_replicas_in_sync)
-          loss = loss1 + loss2
-
-        grads = tape.gradient(loss, model.trainable_variables)
-        optimizer.apply_gradients(zip(grads, model.trainable_variables))
+          loss = tf.reduce_sum(prediction_loss) * (1.0 / flags_obj.batch_size)
+          num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+
+          if flags_obj.single_l2_loss_op:
+            filtered_variables = [
+                tf.reshape(v, (-1,))
+                for v in trainable_variables
+                if 'bn' not in v.name
+            ]
+            l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
+                tf.concat(filtered_variables, axis=0))
+            loss += (l2_loss / num_replicas)
+          else:
+            loss += (tf.reduce_sum(model.losses) / num_replicas)
+
+        grads = tape.gradient(loss, trainable_variables)
+        optimizer.apply_gradients(zip(grads, trainable_variables))

         training_accuracy.update_state(labels, logits)
         return loss

@@ -232,7 +248,7 @@ def run(flags_obj):
       training_accuracy.reset_states()

       for step in range(train_steps):
-        optimizer.lr = keras_imagenet_main.learning_rate_schedule(
+        optimizer.lr = resnet_imagenet_main.learning_rate_schedule(
            epoch, step, train_steps, flags_obj.batch_size)

         time_callback.on_batch_begin(step + epoch * train_steps)

@@ -281,7 +297,7 @@ def main(_):
 if __name__ == '__main__':
   logging.set_verbosity(logging.INFO)
-  keras_common.define_keras_flags()
+  common.define_keras_flags()
   ctl_common.define_ctl_flags()
   flags.adopt_module_key_flags(keras_common)
   flags.adopt_module_key_flags(ctl_common)
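The single_l2_loss_op path above replaces Keras per-layer kernel regularizers with one L2 term computed over all non-BatchNorm weights at once, which tends to be cheaper than summing many small per-layer losses. A rough, self-contained sketch of that computation; WEIGHT_DECAY stands in for resnet_model.L2_WEIGHT_DECAY and the tiny model exists only to make the snippet runnable.

import tensorflow as tf

WEIGHT_DECAY = 1e-4

def single_l2_loss(trainable_variables, num_replicas=1):
  # Flatten every variable except BatchNorm gammas/betas and concatenate them,
  # then take a single tf.nn.l2_loss over the combined vector.
  filtered = [
      tf.reshape(v, (-1,))
      for v in trainable_variables
      if 'bn' not in v.name
  ]
  l2 = WEIGHT_DECAY * 2 * tf.nn.l2_loss(tf.concat(filtered, axis=0))
  # Each replica contributes its share when losses are summed across replicas.
  return l2 / num_replicas

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,), name='dense'),
    tf.keras.layers.BatchNormalization(name='bn'),
])
print(single_l2_loss(model.trainable_variables).numpy())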
official/resnet/ctl/ctl_imagenet_test.py

@@ -25,8 +25,8 @@ from tensorflow.python.eager import context
 from tensorflow.python.platform import googletest
 from official.resnet.ctl import ctl_common
 from official.resnet.ctl import ctl_imagenet_main
-from official.resnet.keras import imagenet_preprocessing
-from official.resnet.keras import keras_common
+from official.vision.image_classification import imagenet_preprocessing
+from official.vision.image_classification import common
 from official.utils.misc import keras_utils
 from official.utils.testing import integration

@@ -49,7 +49,7 @@ class CtlImagenetTest(googletest.TestCase):
   @classmethod
   def setUpClass(cls):  # pylint: disable=invalid-name
     super(CtlImagenetTest, cls).setUpClass()
-    keras_common.define_keras_flags()
+    common.define_keras_flags()
     ctl_common.define_ctl_flags()

   def setUp(self):
official/resnet/keras/__init__.py  (deleted, 100644 → 0; view file at d5722dcd)

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Bring in the shared Keras ResNet modules into this module.
-
-The TensorFlow official Keras models are moved under
-official/vision/image_classification. In order to be backward compatible with
-models that directly import its modules, we import the Keras ResNet modules
-under official.resnet.keras. New TF models should not depend on modules
-directly under this path.
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from official.vision.image_classification import cifar_preprocessing
-from official.vision.image_classification import common as keras_common
-from official.vision.image_classification import imagenet_preprocessing
-from official.vision.image_classification import resnet_cifar_main as keras_cifar_main
-from official.vision.image_classification import resnet_cifar_model
-from official.vision.image_classification import resnet_imagenet_main as keras_imagenet_main
-from official.vision.image_classification import resnet_model
-
-del absolute_import
-del division
-del print_function
official/staging/shakespeare/shakespeare_benchmark.py

@@ -41,8 +41,8 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
         flag_methods=[shakespeare_main.define_flags])

   def _run_and_report_benchmark(self,
-                                top_1_train_min=0.923,
-                                top_1_train_max=0.93,
+                                top_1_train_min=0.91,
+                                top_1_train_max=0.94,
                                 warmup=1,
                                 log_steps=100):
     """Report benchmark results by writing to local protobuf file.
official/transformer/model/beam_search.py

@@ -79,8 +79,41 @@ class _StateKeys(object):
 class SequenceBeamSearch(object):
   """Implementation of beam search loop."""

-  def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
-               beam_size, alpha, max_decode_length, eos_id, dtype=tf.float32):
+  def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
+               beam_size, alpha, max_decode_length, eos_id, padded_decode,
+               dtype=tf.float32):
+    """Initialize sequence beam search.
+
+    Args:
+      symbols_to_logits_fn: A function to provide logits, which is the
+        interface to the Transformer model. The passed in arguments are:
+          ids -> A tensor with shape [batch_size * beam_size, index].
+          index -> A scalar.
+          cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+        The function must return a tuple of logits and the updated cache:
+          logits -> A tensor with shape [batch * beam_size, vocab_size].
+          updated cache -> A nested dictionary with the same structure as the
+            input cache.
+      vocab_size: An integer, the size of the vocabulary, used for topk
+        computation.
+      batch_size: An integer, the decode batch size.
+      beam_size: An integer, number of beams for beam search.
+      alpha: A float, defining the strength of length normalization.
+      max_decode_length: An integer, the maximum number of steps to decode
+        a sequence.
+      eos_id: An integer. ID of end of sentence token.
+      padded_decode: A bool, indicating if max_sequence_length padding is used
+        for beam search.
+      dtype: A tensorflow data type used for score computation. The default is
+        tf.float32.
+    """
     self.symbols_to_logits_fn = symbols_to_logits_fn
     self.vocab_size = vocab_size
     self.batch_size = batch_size

@@ -88,6 +121,7 @@ class SequenceBeamSearch(object):
     self.alpha = alpha
     self.max_decode_length = max_decode_length
     self.eos_id = eos_id
+    self.padded_decode = padded_decode
     self.dtype = tf.as_dtype(dtype)

   def search(self, initial_ids, initial_cache):

@@ -140,6 +174,8 @@ class SequenceBeamSearch(object):
     # Create alive sequence with shape [batch_size, beam_size, 1]
     alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
     alive_seq = tf.expand_dims(alive_seq, axis=2)
+    if self.padded_decode:
+      alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])

     # Create tensor for storing initial log probabilities.
     # Assume initial_ids are prob 1.0

@@ -178,16 +214,44 @@ class SequenceBeamSearch(object):
     # 1) the dimension's value is a tensor that remains the same but may
     #    depend on the input sequence to the model (e.g. batch size).
     # 2) the dimension may have different values on different iterations.
-    state_shape_invariants = {
-        _StateKeys.CUR_INDEX: tf.TensorShape([]),
-        _StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]),
-        _StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]),
-        _StateKeys.ALIVE_CACHE: nest.map_structure(
-            _get_shape_keep_last_dim, alive_cache),
-        _StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]),
-        _StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]),
-        _StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size])
-    }
+    if self.padded_decode:
+      state_shape_invariants = {
+          _StateKeys.CUR_INDEX: tf.TensorShape([]),
+          _StateKeys.ALIVE_SEQ: tf.TensorShape(
+              [self.batch_size, self.beam_size, self.max_decode_length + 1]),
+          _StateKeys.ALIVE_LOG_PROBS: tf.TensorShape(
+              [self.batch_size, self.beam_size]),
+          _StateKeys.ALIVE_CACHE: nest.map_structure(_get_shape, alive_cache),
+          _StateKeys.FINISHED_SEQ: tf.TensorShape(
+              [self.batch_size, self.beam_size, self.max_decode_length + 1]),
+          _StateKeys.FINISHED_SCORES: tf.TensorShape(
+              [self.batch_size, self.beam_size]),
+          _StateKeys.FINISHED_FLAGS: tf.TensorShape(
+              [self.batch_size, self.beam_size])
+      }
+    else:
+      state_shape_invariants = {
+          _StateKeys.CUR_INDEX: tf.TensorShape([]),
+          _StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]),
+          _StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]),
+          _StateKeys.ALIVE_CACHE: nest.map_structure(
+              _get_shape_keep_last_dim, alive_cache),
+          _StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]),
+          _StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]),
+          _StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size])
+      }

     return state, state_shape_invariants

@@ -297,7 +361,12 @@ class SequenceBeamSearch(object):
     # Get logits for the next candidate IDs for the alive sequences. Get the
     # new cache values at the same time.
-    flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
+    if self.padded_decode:
+      flat_ids = tf.reshape(
+          tf.slice(alive_seq, [0, 0, i],
+                   [self.batch_size, self.beam_size, 1]),
+          [self.batch_size * self.beam_size, -1])
+    else:
+      flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
     flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)

     flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)

@@ -331,8 +400,13 @@ class SequenceBeamSearch(object):
     # Append the most probable IDs to the topk sequences
     topk_ids = topk_indices % self.vocab_size
-    topk_ids = tf.expand_dims(topk_ids, axis=2)
-    topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
+    if self.padded_decode:
+      topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
+      topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
+      topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
+    else:
+      topk_ids = tf.expand_dims(topk_ids, axis=2)
+      topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
     return topk_seq, topk_log_probs, new_cache

   def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):

@@ -388,9 +462,12 @@ class SequenceBeamSearch(object):
     # First append a column of 0-ids to finished_seq to increment the length.
     # New shape of finished_seq: [batch_size, beam_size, i + 1]
-    finished_seq = tf.concat(
-        [finished_seq,
-         tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2)
+    if not self.padded_decode:
+      finished_seq = tf.concat([
+          finished_seq,
+          tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)
+      ], axis=2)

     # Calculate new seq scores from log probabilities.
     length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)

@@ -420,34 +497,43 @@ class SequenceBeamSearch(object):
 def sequence_beam_search(
     symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id):
+    alpha, max_decode_length, eos_id, padded_decode=False):
   """Search for sequence of subtoken ids with the largest probability.

   Args:
     symbols_to_logits_fn: A function that takes in ids, index, and cache as
       arguments. The passed in arguments will have shape:
-        ids -> [batch_size * beam_size, index]
-        index -> [] (scalar)
-        cache -> nested dictionary of tensors [batch_size * beam_size, ...]
-      The function must return logits and new cache.
-        logits -> [batch * beam_size, vocab_size]
-        new cache -> same shape/structure as inputted cache
-    initial_ids: Starting ids for each batch item.
-      int32 tensor with shape [batch_size]
-    initial_cache: dict containing starting decoder variables information
-    vocab_size: int size of tokens
-    beam_size: int number of beams
-    alpha: float defining the strength of length normalization
-    max_decode_length: maximum length to decoded sequence
-    eos_id: int id of eos token, used to determine when a sequence has finished
+        ids -> A tensor with shape [batch_size * beam_size, index].
+        index -> A scalar.
+        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+      The function must return a tuple of logits and new cache:
+        logits -> A tensor with shape [batch * beam_size, vocab_size].
+        new cache -> A nested dictionary with the same shape/structure as the
+          inputted cache.
+    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
+      each batch item.
+    initial_cache: A dictionary, containing starting decoder variables
+      information.
+    vocab_size: An integer, the size of the vocabulary, used for topk
+      computation.
+    beam_size: An integer, the number of beams.
+    alpha: A float, defining the strength of length normalization.
+    max_decode_length: An integer, the maximum length to decoded a sequence.
+    eos_id: An integer, ID of eos token, used to determine when a sequence has
+      finished.
+    padded_decode: A bool, indicating if max_sequence_length padding is used
+      for beam search.

   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]
     sequence scores [batch_size, beam_size]
   """
-  batch_size = tf.shape(initial_ids)[0]
+  batch_size = (
+      initial_ids.shape.as_list()[0] if padded_decode
+      else tf.shape(initial_ids)[0])
   sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
-                           beam_size, alpha, max_decode_length, eos_id)
+                           beam_size, alpha, max_decode_length, eos_id,
+                           padded_decode)
   return sbs.search(initial_ids, initial_cache)

@@ -502,6 +588,11 @@ def _get_shape_keep_last_dim(tensor):
   return tf.TensorShape(shape_list)


+def _get_shape(tensor):
+  """Return the shape of the input tensor."""
+  return tf.TensorShape(_shape_list(tensor))
+
+
 def _flatten_beam_dim(tensor):
   """Reshapes first two dimensions in to single dimension.
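The padded_decode branches above exist because TPU/XLA compilation requires every tf.while_loop variable to keep a static shape, so sequences cannot grow with tf.concat each step; instead they are preallocated to max_decode_length + 1 and written in place at step i. A small sketch of that in-place write with made-up sizes; tf.tensor_scatter_nd_update is the current name of the op the diff calls tf.tensor_scatter_update.

import tensorflow as tf

beam_size, max_len = 3, 5
seq = tf.zeros([beam_size, max_len + 1], tf.int32)   # preallocated, static shape

def write_step(seq, step, ids):
  # Write column `step` of the [beam, max_len + 1] buffer without changing
  # its shape: transpose so the time axis is first, scatter one row, undo.
  seq_t = tf.transpose(seq)                            # [max_len + 1, beam]
  seq_t = tf.tensor_scatter_nd_update(seq_t, [[step]], [ids])
  return tf.transpose(seq_t)

seq = write_step(seq, 1, tf.constant([7, 8, 9]))
print(seq.numpy())   # column 1 now holds the step-1 ids; the shape never changes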
official/transformer/transformer_main.py

@@ -32,6 +32,7 @@ from absl import flags
 import tensorflow as tf
 # pylint: enable=g-bad-import-order

+from official.r1.utils import export
 from official.transformer import compute_bleu
 from official.transformer import translate
 from official.transformer.model import model_params

@@ -41,7 +42,6 @@ from official.transformer.utils import metrics
 from official.transformer.utils import schedule
 from official.transformer.utils import tokenizer
 from official.utils.accelerator import tpu as tpu_util
-from official.utils.export import export
 from official.utils.flags import core as flags_core
 from official.utils.logs import hooks_helper
 from official.utils.logs import logger

@@ -56,7 +56,7 @@ PARAMS_MAP = {
 DEFAULT_TRAIN_EPOCHS = 10
-INF = int(1e9)
+INF = 1000000000  # 1e9
 BLEU_DIR = "bleu"

 # Dictionary containing tensors that are logged by the logging hooks. Each item
official/transformer/v2/attention_layer.py

@@ -102,51 +102,67 @@ class Attention(tf.keras.layers.Layer):
     x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
     return tf.reshape(x, [batch_size, length, self.hidden_size])

-  def call(self, x, y, bias, training, cache=None):
+  def call(self, x, y, bias, training, cache=None, decode_loop_step=None):
     """Apply attention mechanism to x and y.

     Args:
-      x: a tensor with shape [batch_size, length_x, hidden_size]
-      y: a tensor with shape [batch_size, length_y, hidden_size]
-      bias: attention bias that will be added to the result of the dot product.
-      training: boolean, whether in training mode or not.
-      cache: (Used during prediction) dictionary with tensors containing results
-        of previous attentions. The dictionary must have the items:
+      x: A tensor with shape [batch_size, length_x, hidden_size].
+      y: A tensor with shape [batch_size, length_y, hidden_size].
+      bias: A bool, the attention bias that will be added to the result of the
+        dot product.
+      training: A bool, whether in training mode or not.
+      cache: (Used during prediction) A dictionary with tensors containing
+        results of previous attentions. The dictionary must have the items:
             {"k": tensor with shape [batch_size, i, key_channels],
              "v": tensor with shape [batch_size, i, value_channels]}
         where i is the current decoded length.
+      decode_loop_step: An integer, step number of the decoding loop. Used only
+        for autoregressive inference on TPU.

     Returns:
       Attention layer output with shape [batch_size, length_x, hidden_size]
     """
-    # Linearly project the query (q), key (k) and value (v) using different
-    # learned projections. This is in preparation of splitting them into
-    # multiple heads. Multi-head attention uses multiple queries, keys, and
-    # values rather than regular attention (which uses a single q, k, v).
-    q = self.q_dense_layer(x)
-    k = self.k_dense_layer(y)
-    v = self.v_dense_layer(y)
+    # Linearly project the query, key and value using different learned
+    # projections. This is in preparation of splitting them into multiple
+    # heads. Multi-head attention uses multiple queries, keys, and values
+    # rather than regular attention (which uses a single query, key, value).
+    query = self.q_dense_layer(x)
+    key = self.k_dense_layer(y)
+    value = self.v_dense_layer(y)

     if cache is not None:
-      # Combine cached keys and values with new keys and values.
-      k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
-      v = tf.concat([tf.cast(cache["v"], k.dtype), v], axis=1)
+      # Combine cached keys and values with new keys and values.
+      if decode_loop_step is not None:
+        cache_k_shape = cache["k"].shape.as_list()
+        indices = tf.reshape(
+            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
+            [1, cache_k_shape[1], 1])
+        key = cache["k"] + key * indices
+        cache_v_shape = cache["v"].shape.as_list()
+        indices = tf.reshape(
+            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
+            [1, cache_v_shape[1], 1])
+        value = cache["v"] + value * indices
+      else:
+        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
+        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

       # Update cache
-      cache["k"] = k
-      cache["v"] = v
+      cache["k"] = key
+      cache["v"] = value

-    # Split q, k, v into heads.
-    q = self.split_heads(q)
-    k = self.split_heads(k)
-    v = self.split_heads(v)
+    # Split query, key, value into heads.
+    query = self.split_heads(query)
+    key = self.split_heads(key)
+    value = self.split_heads(value)

-    # Scale q to prevent the dot product between q and k from growing too large.
+    # Scale query to prevent the dot product between query and key from growing
+    # too large.
     depth = (self.hidden_size // self.num_heads)
-    q *= depth ** -0.5
+    query *= depth ** -0.5

     # Calculate dot product attention
-    logits = tf.matmul(q, k, transpose_b=True)
+    logits = tf.matmul(query, key, transpose_b=True)
     logits += bias
     # Note that softmax internally performs math operations using float32
     # for numeric stability. When training with float16, we keep the input

@@ -154,7 +170,7 @@ class Attention(tf.keras.layers.Layer):
     weights = tf.nn.softmax(logits, name="attention_weights")
     if training:
       weights = tf.nn.dropout(weights, rate=self.attention_dropout)
-    attention_output = tf.matmul(weights, v)
+    attention_output = tf.matmul(weights, value)

     # Recombine heads --> [batch_size, length, hidden_size]
     attention_output = self.combine_heads(attention_output)

@@ -167,5 +183,6 @@ class Attention(tf.keras.layers.Layer):
 class SelfAttention(Attention):
   """Multiheaded self-attention layer."""

-  def call(self, x, bias, training, cache=None):
-    return super(SelfAttention, self).call(x, x, bias, training, cache)
+  def call(self, x, bias, training, cache=None, decode_loop_step=None):
+    return super(SelfAttention, self).call(
+        x, x, bias, training, cache, decode_loop_step)
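The decode_loop_step branch above writes this step's key/value into a preallocated cache by multiplying with a one-hot mask over the time axis and adding, instead of concatenating, again so the cache keeps a static shape for TPU/XLA. A minimal sketch of that update with made-up dimensions (not the layer's actual code):

import tensorflow as tf

batch, max_len, hidden = 2, 4, 3
cache_k = tf.zeros([batch, max_len, hidden])          # preallocated cache
new_key = tf.ones([batch, 1, hidden])                 # this step's key
decode_loop_step = 2

# One-hot mask over the time axis, broadcast against the [batch, 1, hidden] key.
indices = tf.reshape(
    tf.one_hot(decode_loop_step, max_len, dtype=new_key.dtype),
    [1, max_len, 1])
cache_k = cache_k + new_key * indices                 # write row `decode_loop_step`
print(cache_k[0, :, 0].numpy())                       # -> [0. 0. 1. 0.]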
official/transformer/v2/beam_search.py

@@ -55,43 +55,58 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
     return finished_seq, finished_scores


-def sequence_beam_search(
-    symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id, dtype="float32"):
+def sequence_beam_search(
+    symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
+    alpha, max_decode_length, eos_id, padded_decode=False, dtype="float32"):
   """Search for sequence of subtoken ids with the largest probability.

   Args:
     symbols_to_logits_fn: A function that takes in ids, index, and cache as
       arguments. The passed in arguments will have shape:
-        ids -> [batch_size * beam_size, index]
-        index -> [] (scalar)
-        cache -> nested dictionary of tensors [batch_size * beam_size, ...]
-      The function must return logits and new cache.
-        logits -> [batch * beam_size, vocab_size]
-        new cache -> same shape/structure as inputted cache
-    initial_ids: Starting ids for each batch item.
-      int32 tensor with shape [batch_size]
-    initial_cache: dict containing starting decoder variables information
-    vocab_size: int size of tokens
-    beam_size: int number of beams
-    alpha: float defining the strength of length normalization
-    max_decode_length: maximum length to decoded sequence
-    eos_id: int id of eos token, used to determine when a sequence has finished,
-    dtype: The dtype to use.
+        ids -> A tensor with shape [batch_size * beam_size, index].
+        index -> A scalar.
+        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+      The function must return a tuple of logits and new cache:
+        logits -> A tensor with shape [batch * beam_size, vocab_size].
+        new cache -> A nested dictionary with the same shape/structure as the
+          inputted cache.
+    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
+      each batch item.
+    initial_cache: A dictionary, containing starting decoder variables
+      information.
+    vocab_size: An integer, the size of tokens.
+    beam_size: An integer, the number of beams.
+    alpha: A float, defining the strength of length normalization.
+    max_decode_length: An integer, the maximum length to decoded a sequence.
+    eos_id: An integer, ID of eos token, used to determine when a sequence has
+      finished.
+    padded_decode: A bool, indicating if max_sequence_length padding is used
+      for beam search.
+    dtype: A tensorflow data type used for score computation. The default is
+      tf.float32.

   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]
     sequence scores [batch_size, beam_size]
   """
-  batch_size = tf.shape(initial_ids)[0]
+  batch_size = (
+      initial_ids.shape.as_list()[0] if padded_decode
+      else tf.shape(initial_ids)[0])
   if misc.is_v2():
     sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
-                               beam_size, alpha, max_decode_length, eos_id,
-                               dtype)
+                               beam_size, alpha, max_decode_length, eos_id,
+                               padded_decode, dtype)
   else:
     sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
-                                beam_size, alpha, max_decode_length, eos_id,
-                                dtype)
+                                beam_size, alpha, max_decode_length, eos_id,
+                                padded_decode, dtype)
   return sbs.search(initial_ids, initial_cache)
official/transformer/v2/embedding_layer.py

@@ -24,24 +24,14 @@ import tensorflow as tf
 class EmbeddingSharedWeights(tf.keras.layers.Layer):
   """Calculates input embeddings and pre-softmax linear with shared weights."""

-  def __init__(self, vocab_size, hidden_size, dtype=None):
+  def __init__(self, vocab_size, hidden_size):
     """Specify characteristic parameters of embedding layer.

     Args:
       vocab_size: Number of tokens in the embedding. (Typically ~32,000)
       hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
-      dtype: The dtype of the layer: float16 or float32.
     """
-    if dtype == tf.float16:
-      # We cannot rely on the global policy of "infer_with_float32_vars", as
-      # this layer is called on both int64 inputs and floating-point inputs.
-      # If "infer_with_float32_vars" is used, the dtype will be inferred to be
-      # int64, which means floating-point inputs would not be casted.
-      # TODO(b/138859351): Remove this logic once we stop using the deprecated
-      # "infer_with_float32_vars" policy
-      dtype = tf.keras.mixed_precision.experimental.Policy(
-          "float16_with_float32_vars")
-    super(EmbeddingSharedWeights, self).__init__(dtype=dtype)
+    super(EmbeddingSharedWeights, self).__init__()
     self.vocab_size = vocab_size
     self.hidden_size = hidden_size

@@ -53,7 +43,6 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     self.shared_weights = self.add_weight(
         "weights",
         shape=[self.vocab_size, self.hidden_size],
-        dtype="float32",
         initializer=tf.random_normal_initializer(
             mean=0., stddev=self.hidden_size**-0.5))
     super(EmbeddingSharedWeights, self).build(input_shape)
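For context on what EmbeddingSharedWeights does (this diff only removes its special dtype handling): a single [vocab_size, hidden_size] matrix serves both as the embedding table and as the pre-softmax projection. A rough, self-contained sketch of the two modes; the dimensions are made up and the scaling factor follows the layer's usual convention.

import tensorflow as tf

vocab_size, hidden_size = 10, 4
shared_weights = tf.Variable(
    tf.random.normal([vocab_size, hidden_size], stddev=hidden_size ** -0.5))

ids = tf.constant([[1, 2, 3]])                       # [batch, length]

# "embedding" mode: gather rows of the shared matrix and scale.
embeddings = tf.gather(shared_weights, ids)
embeddings *= hidden_size ** 0.5

# "linear" mode: project hidden states back onto the vocabulary with the
# transpose of the same matrix.
logits = tf.matmul(embeddings, shared_weights, transpose_b=True)
print(logits.shape)                                  # (1, 3, 10)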
official/transformer/v2/misc.py

@@ -192,6 +192,29 @@ def define_transformer_flags():
       help=flags_core.help_wrap(
           'Whether the model runs in 2VM mode, Headless server and unit test '
           'all use 1VM config.'))
+  flags.DEFINE_integer(
+      name='decode_batch_size',
+      default=32,
+      help=flags_core.help_wrap(
+          'Global batch size used for Transformer autoregressive decoding on '
+          'TPU.'))
+  flags.DEFINE_integer(
+      name='decode_max_length',
+      default=97,
+      help=flags_core.help_wrap(
+          'Max sequence length of the decode/eval data. This is used by '
+          'Transformer autoregressive decoding on TPU to have minimum '
+          'paddings.'))
+  flags.DEFINE_bool(
+      name='padded_decode',
+      default=False,
+      help=flags_core.help_wrap(
+          'Whether the autoregressive decoding runs with input data padded to '
+          'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
+          'set due the static shape requirement. Although CPU/GPU could also '
+          'use padded_decode, it has not been tested. In addition, this method '
+          'will introduce unnecessary overheads which grow quadratically with '
+          'the max sequence length.'))

   flags_core.set_defaults(data_dir='/tmp/translate_ende',
                           model_dir='/tmp/transformer_model',
official/transformer/v2/transformer.py

@@ -49,8 +49,10 @@ def create_model(params, is_train):
     label_smoothing = params["label_smoothing"]
     if params["enable_metrics_in_training"]:
       logits = metrics.MetricLayer(vocab_size)([logits, targets])
-    logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits)
+    logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
+                                    dtype=tf.float32)(logits)
     model = tf.keras.Model([inputs, targets], logits)
+    # TODO(reedwm): Can we do this loss in float16 instead of float32?
     loss = metrics.transformer_loss(
         logits, targets, label_smoothing, vocab_size)
     model.add_loss(loss)

@@ -85,7 +87,7 @@ class Transformer(tf.keras.Model):
     super(Transformer, self).__init__(name=name)
     self.params = params
     self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
-        params["vocab_size"], params["hidden_size"], dtype=params["dtype"])
+        params["vocab_size"], params["hidden_size"])
     self.encoder_stack = EncoderStack(params)
     self.decoder_stack = DecoderStack(params)

@@ -112,11 +114,22 @@ class Transformer(tf.keras.Model):
         outputs: [batch_size, decoded length]
         scores: [batch_size, float]}
       Even when float16 is used, the output tensor(s) are always float32.
+
+    Raises:
+      NotImplementedError: If try to use padded decode method on CPU/GPUs.
     """
     if len(inputs) == 2:
       inputs, targets = inputs[0], inputs[1]
     else:
       inputs, targets = inputs[0], None
+      if self.params["padded_decode"]:
+        if not self.params["num_replicas"]:
+          raise NotImplementedError(
+              "Padded decoding on CPU/GPUs is not supported.")
+        decode_batch_size = int(self.params["decode_batch_size"] /
+                                self.params["num_replicas"])
+        inputs = tf.reshape(
+            inputs, [decode_batch_size, self.params["decode_max_length"]])

     # Variance scaling is used here because it seems to work in many problems.
     # Other reasonable initializers may also work just as well.

@@ -225,13 +238,14 @@ class Transformer(tf.keras.Model):
     decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
         max_decode_length, dtype=self.params["dtype"])

+    # TODO(b/139770046): Refactor code with better naming of i.
     def symbols_to_logits_fn(ids, i, cache):
       """Generate logits for next potential IDs.

       Args:
-        ids: Current decoded sequences. int tensor with shape [batch_size *
-          beam_size, i + 1]
-        i: Loop index
+        ids: Current decoded sequences. int tensor with shape [batch_size *
+          beam_size, i + 1].
+        i: Loop index.
         cache: dictionary of values storing the encoder output, encoder-decoder
           attention bias, and previous decoder attention values.

@@ -245,16 +259,29 @@ class Transformer(tf.keras.Model):
       # Preprocess decoder input by getting embeddings and adding timing signal.
       decoder_input = self.embedding_softmax_layer(decoder_input)
-      decoder_input += timing_signal[i:i + 1]
-
-      self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
+      if self.params["padded_decode"]:
+        timing_signal_shape = timing_signal.shape.as_list()
+        decoder_input += tf.slice(timing_signal, [i, 0],
+                                  [1, timing_signal_shape[1]])
+
+        bias_shape = decoder_self_attention_bias.shape.as_list()
+        self_attention_bias = tf.slice(
+            decoder_self_attention_bias, [0, 0, i, 0],
+            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
+      else:
+        decoder_input += timing_signal[i:i + 1]
+        self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
+
       decoder_outputs = self.decoder_stack(
           decoder_input,
           cache.get("encoder_outputs"),
           self_attention_bias,
           cache.get("encoder_decoder_attention_bias"),
           training=training,
-          cache=cache)
+          cache=cache,
+          decode_loop_step=i if self.params["padded_decode"] else None)
       logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
       logits = tf.squeeze(logits, axis=[1])
       return logits, cache

@@ -263,8 +290,12 @@ class Transformer(tf.keras.Model):
   def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
     """Return predicted sequence."""
-    batch_size = tf.shape(encoder_outputs)[0]
-    input_length = tf.shape(encoder_outputs)[1]
+    if self.params["padded_decode"]:
+      batch_size = encoder_outputs.shape.as_list()[0]
+      input_length = encoder_outputs.shape.as_list()[1]
+    else:
+      batch_size = tf.shape(encoder_outputs)[0]
+      input_length = tf.shape(encoder_outputs)[1]
     max_decode_length = input_length + self.params["extra_decode_length"]
     encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                              self.params["dtype"])

@@ -277,12 +308,20 @@ class Transformer(tf.keras.Model):
     # Create cache storing decoder attention values for each layer.
     # pylint: disable=g-complex-comprehension
+    init_decode_length = (
+        max_decode_length if self.params["padded_decode"] else 0)
     cache = {
         "layer_%d" % layer: {
-            "k": tf.zeros([batch_size, 0, self.params["hidden_size"]],
-                          dtype=self.params["dtype"]),
-            "v": tf.zeros([batch_size, 0, self.params["hidden_size"]],
-                          dtype=self.params["dtype"])
+            "k": tf.zeros(
+                [batch_size, init_decode_length, self.params["hidden_size"]],
+                dtype=self.params["dtype"]),
+            "v": tf.zeros(
+                [batch_size, init_decode_length, self.params["hidden_size"]],
+                dtype=self.params["dtype"])
         } for layer in range(self.params["num_hidden_layers"])
     }
     # pylint: enable=g-complex-comprehension

@@ -301,6 +340,7 @@ class Transformer(tf.keras.Model):
         alpha=self.params["alpha"],
         max_decode_length=max_decode_length,
         eos_id=EOS_ID,
+        padded_decode=self.params["padded_decode"],
         dtype=self.params["dtype"])

     # Get the top sequence for each batch element

@@ -505,22 +545,28 @@ class DecoderStack(tf.keras.layers.Layer):
   def call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias,
-           attention_bias, training, cache=None):
+           attention_bias, training, cache=None, decode_loop_step=None):
     """Return the output of the decoder layer stacks.

     Args:
-      decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
-      encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
-      decoder_self_attention_bias: bias for decoder self-attention layer. [1, 1,
-        target_len, target_length]
-      attention_bias: bias for encoder-decoder attention layer. [batch_size, 1,
-        1, input_length]
-      training: boolean, whether in training mode or not.
+      decoder_inputs: A tensor with shape
+        [batch_size, target_length, hidden_size].
+      encoder_outputs: A tensor with shape
+        [batch_size, input_length, hidden_size]
+      decoder_self_attention_bias: A tensor with shape
+        [1, 1, target_len, target_length], the bias for decoder self-attention
+        layer.
+      attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
+        the bias for encoder-decoder attention layer.
+      training: A bool, whether in training mode or not.
       cache: (Used for fast decoding) A nested dictionary storing previous
         decoder self-attention values. The items are:
-          {layer_n: {"k": tensor with shape [batch_size, i, key_channels],
-                     "v": tensor with shape [batch_size, i, value_channels]},
+          {layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
+                     "v": A tensor with shape [batch_size, i, value_channels]},
                      ...}
+      decode_loop_step: An integer, the step number of the decoding loop. Used
+        only for autoregressive inference on TPU.

     Returns:
       Output of decoder layer stack.

@@ -540,7 +586,8 @@ class DecoderStack(tf.keras.layers.Layer):
           decoder_inputs = self_attention_layer(
               decoder_inputs,
               decoder_self_attention_bias,
               training=training,
-              cache=layer_cache)
+              cache=layer_cache,
+              decode_loop_step=decode_loop_step)
         with tf.name_scope("encdec_attention"):
           decoder_inputs = enc_dec_attention_layer(
               decoder_inputs,
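The padded_decode branch in symbols_to_logits_fn above replaces Python-style slicing (timing_signal[i:i + 1]) with tf.slice using a fixed size, so the slice shape stays static even though the begin index i is a loop tensor. A tiny sketch of that difference with made-up dimensions (not the model's code):

import tensorflow as tf

max_decode_length, hidden_size = 6, 4
timing_signal = tf.random.normal([max_decode_length + 1, hidden_size])
i = tf.constant(2)                                    # decode loop step (a tensor)

# Fixed-size slice: shape is always [1, hidden_size], whatever i is.
begin = tf.stack([i, 0])
step_signal = tf.slice(timing_signal, begin, [1, hidden_size])
print(step_signal.shape)                              # (1, 4) regardless of i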
(Diff pagination: pages 1-3; the remaining changed files of this merge commit are shown on the other pages.)