ModelZoo / ResNet50_tensorflow / Commits / f8ec01ae

Commit f8ec01ae (unverified), authored Jun 19, 2019 by Reed, committed by GitHub Jun 19, 2019
parent 269581dc

Add mixed precision support to Transformer (#7011)

Showing 7 changed files with 102 additions and 16 deletions (+102 -16)
official/transformer/model/model_utils.py         +18 -7
official/transformer/v2/attention_layer.py        +21 -3
official/transformer/v2/misc.py                   +2  -1
official/transformer/v2/transformer.py            +26 -5
official/transformer/v2/transformer_main.py       +10 -0
official/transformer/v2/transformer_main_test.py  +24 -0
official/transformer/v2/transformer_test.py       +1  -0
official/transformer/model/model_utils.py

@@ -20,9 +20,13 @@ from __future__ import print_function

 import math

 import numpy as np
 import tensorflow as tf

-_NEG_INF = -1e9
+# Very low numbers to represent -infinity. We do not actually use -Inf, since we
+# want to be able to multiply these values by zero to get zero. (-Inf * 0 = NaN)
+_NEG_INF_FP32 = -1e9
+_NEG_INF_FP16 = np.finfo(np.float16).min


 def get_position_encoding(
@@ -42,6 +46,9 @@ def get_position_encoding(
   Returns:
     Tensor with shape [length, hidden_size]
   """
+  # We compute the positional encoding in float32 even if the model uses
+  # float16, as many of the ops used, like log and exp, are numerically unstable
+  # in float16.
   position = tf.cast(tf.range(length), tf.float32)
   num_timescales = hidden_size // 2
   log_timescale_increment = (
@@ -54,7 +61,7 @@ def get_position_encoding(
   return signal


-def get_decoder_self_attention_bias(length):
+def get_decoder_self_attention_bias(length, dtype=tf.float32):
   """Calculate bias for decoder that maintains model's autoregressive property.

   Creates a tensor that masks out locations that correspond to illegal
@@ -63,30 +70,34 @@ def get_decoder_self_attention_bias(length):
   Args:
     length: int length of sequences in batch.
+    dtype: The dtype of the return value.

   Returns:
     float tensor of shape [1, 1, length, length]
   """
+  neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32
   with tf.name_scope("decoder_self_attention_bias"):
-    valid_locs = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
+    valid_locs = tf.linalg.band_part(tf.ones([length, length], dtype=dtype),
+                                     -1, 0)
     valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
-    decoder_bias = _NEG_INF * (1.0 - valid_locs)
+    decoder_bias = neg_inf * (1.0 - valid_locs)
   return decoder_bias


-def get_padding(x, padding_value=0):
+def get_padding(x, padding_value=0, dtype=tf.float32):
   """Return float tensor representing the padding values in x.

   Args:
     x: int tensor with any shape
     padding_value: int value that
+    dtype: The dtype of the return value.

   Returns:
     float tensor with same shape as x containing values 0 or 1.
       0 -> non-padding, 1 -> padding
   """
   with tf.name_scope("padding"):
-    return tf.cast(tf.equal(x, padding_value), tf.float32)
+    return tf.cast(tf.equal(x, padding_value), dtype)


 def get_padding_bias(x):
@@ -104,7 +115,7 @@ def get_padding_bias(x):
   """
   with tf.name_scope("attention_bias"):
     padding = get_padding(x)
-    attention_bias = padding * _NEG_INF
+    attention_bias = padding * _NEG_INF_FP32
     attention_bias = tf.expand_dims(
         tf.expand_dims(attention_bias, axis=1), axis=1)
   return attention_bias
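
Why the bias constant now depends on dtype: -1e9 cannot be represented in float16, so it overflows to -inf, and multiplying -inf by a zero mask entry produces NaN, whereas np.finfo(np.float16).min is the most negative finite float16 value and multiplies cleanly by zero. A minimal standalone sketch (not part of the diff) that demonstrates the difference:

import numpy as np

# -1e9 is outside the float16 range, so it overflows to -inf,
# and -inf times a zero mask entry yields NaN.
overflowed = np.float16(-1e9)
print(overflowed)                      # -inf
print(overflowed * np.float16(0.0))    # nan

# The most negative finite float16 stays finite and masks safely.
neg_inf_fp16 = np.finfo(np.float16).min
print(neg_inf_fp16)                    # -65504.0
print(neg_inf_fp16 * np.float16(0.0))  # 0.0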
official/transformer/v2/attention_layer.py

@@ -21,6 +21,24 @@ from __future__ import print_function

 import tensorflow as tf


+def _float32_softmax(logits, name=None):
+  """Computes a softmax activation in float32.
+
+  When training a model using float16, softmax is still done in float32 for
+  numeric stability.
+
+  Args:
+    logits: A tensor, with any shape accepted by `tf.nn.softmax`.
+
+  Returns:
+    A tensor with the same dtype as `logits`.
+  """
+  input_dtype = logits.dtype
+  logits = tf.cast(logits, tf.float32)
+  output = tf.nn.softmax(logits, name=name)
+  return tf.cast(output, input_dtype)
+
+
 class Attention(tf.keras.layers.Layer):
   """Multi-headed attention layer."""
@@ -129,8 +147,8 @@ class Attention(tf.keras.layers.Layer):
     if cache is not None:
       # Combine cached keys and values with new keys and values.
-      k = tf.concat([cache["k"], k], axis=1)
-      v = tf.concat([cache["v"], v], axis=1)
+      k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
+      v = tf.concat([tf.cast(cache["v"], k.dtype), v], axis=1)

       # Update cache
       cache["k"] = k
@@ -148,7 +166,7 @@ class Attention(tf.keras.layers.Layer):
     # Calculate dot product attention
     logits = tf.matmul(q, k, transpose_b=True)
     logits += bias
-    weights = tf.nn.softmax(logits, name="attention_weights")
+    weights = _float32_softmax(logits, name="attention_weights")
     if training:
       weights = tf.nn.dropout(weights, rate=self.attention_dropout)
     attention_output = tf.matmul(weights, v)
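
The new _float32_softmax helper is the core float16 trick in this file: cast up to float32, run the softmax, cast back to the input dtype. A standalone sketch of the same pattern (the helper is redefined here for illustration and is not imported from the diff):

import tensorflow as tf

def float32_softmax(logits, name=None):
  # Softmax is computed in float32 for numeric stability, then cast back so
  # the surrounding float16 computation keeps its dtype.
  input_dtype = logits.dtype
  output = tf.nn.softmax(tf.cast(logits, tf.float32), name=name)
  return tf.cast(output, input_dtype)

logits = tf.constant([[10.0, 0.0, -10.0]], dtype=tf.float16)
weights = float32_softmax(logits, name="attention_weights")
print(weights.dtype)  # float16, matching the input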
official/transformer/v2/misc.py

@@ -68,7 +68,8 @@ def define_transformer_flags():
       intra_op=False,
       synthetic_data=True,
       max_train_steps=False,
-      dtype=False,
+      dtype=True,
+      loss_scale=True,
       all_reduce_alg=True,
       enable_xla=True)
official/transformer/v2/transformer.py

@@ -102,6 +102,7 @@ class Transformer(tf.keras.Model):
         returns a dictionary {
           outputs: [batch_size, decoded length]
           scores: [batch_size, float]}
+      Even when float16 is used, the output tensor(s) are always float32.
     """
     if len(inputs) == 2:
       inputs, targets = inputs[0], inputs[1]
@@ -141,12 +142,15 @@ class Transformer(tf.keras.Model):
     # Prepare inputs to the layer stack by adding positional encodings and
     # applying dropout.
     embedded_inputs = self.embedding_softmax_layer(inputs)
+    embedded_inputs = tf.cast(embedded_inputs, self.params["dtype"])
     inputs_padding = model_utils.get_padding(inputs)
+    attention_bias = tf.cast(attention_bias, self.params["dtype"])

     with tf.name_scope("add_pos_encoding"):
       length = tf.shape(embedded_inputs)[1]
       pos_encoding = model_utils.get_position_encoding(
           length, self.params["hidden_size"])
+      pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
       encoder_inputs = embedded_inputs + pos_encoding

     if training:
@@ -174,21 +178,25 @@ class Transformer(tf.keras.Model):
     # Prepare inputs to decoder layers by shifting targets, adding positional
     # encoding and applying dropout.
     decoder_inputs = self.embedding_softmax_layer(targets)
+    decoder_inputs = tf.cast(decoder_inputs, self.params['dtype'])
+    attention_bias = tf.cast(attention_bias, self.params["dtype"])
     with tf.name_scope("shift_targets"):
       # Shift targets to the right, and remove the last element
       decoder_inputs = tf.pad(decoder_inputs,
                               [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
     with tf.name_scope("add_pos_encoding"):
       length = tf.shape(decoder_inputs)[1]
-      decoder_inputs += model_utils.get_position_encoding(
+      pos_encoding = model_utils.get_position_encoding(
           length, self.params["hidden_size"])
+      pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
+      decoder_inputs += pos_encoding
     if training:
       decoder_inputs = tf.nn.dropout(
           decoder_inputs, rate=self.params["layer_postprocess_dropout"])

     # Run values
     decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
-        length)
+        length, dtype=self.params['dtype'])
     outputs = self.decoder_stack(
         decoder_inputs,
         encoder_outputs,
@@ -196,6 +204,7 @@ class Transformer(tf.keras.Model):
         attention_bias,
         training=training)
     logits = self.embedding_softmax_layer(outputs, mode="linear")
+    logits = tf.cast(logits, tf.float32)
     return logits

   def _get_symbols_to_logits_fn(self, max_decode_length, training):
@@ -244,6 +253,9 @@ class Transformer(tf.keras.Model):
   def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
     """Return predicted sequence."""
+    # Currently, we always do prediction in float32.
+    # TODO(reedwm): Add float16 support.
+    encoder_outputs = tf.cast(encoder_outputs, tf.float32)
     batch_size = tf.shape(encoder_outputs)[0]
     input_length = tf.shape(encoder_outputs)[1]
     max_decode_length = input_length + self.params["extra_decode_length"]
@@ -295,16 +307,22 @@ class LayerNormalization(tf.keras.layers.Layer):
   def build(self, input_shape):
     """Builds the layer."""
+    # Passing experimental_autocast=False causes these variables to not be
+    # automatically casted to fp16 when mixed precision is used. Since we use
+    # float32 in call() for numeric stability, we do not want variables to be
+    # casted to fp16.
     self.scale = self.add_weight(
         "layer_norm_scale",
         shape=[self.hidden_size],
         dtype="float32",
-        initializer=tf.ones_initializer())
+        initializer=tf.ones_initializer(),
+        experimental_autocast=False)
     self.bias = self.add_weight(
         "layer_norm_bias",
         shape=[self.hidden_size],
         dtype="float32",
-        initializer=tf.zeros_initializer())
+        initializer=tf.zeros_initializer(),
+        experimental_autocast=False)
     super(LayerNormalization, self).build(input_shape)

   def get_config(self):
@@ -313,10 +331,13 @@ class LayerNormalization(tf.keras.layers.Layer):
     }

   def call(self, x, epsilon=1e-6):
+    input_dtype = x.dtype
+    if input_dtype == tf.float16:
+      x = tf.cast(x, tf.float32)
     mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
     variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
     norm_x = (x - mean) * tf.math.rsqrt(variance + epsilon)
-    return norm_x * self.scale + self.bias
+    return tf.cast(norm_x * self.scale + self.bias, input_dtype)


 class PrePostProcessingWrapper(tf.keras.layers.Layer):
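
The LayerNormalization changes follow the same recipe as the softmax helper: keep the scale/bias variables in float32 (experimental_autocast=False prevents the mixed-precision policy from casting them), do the reductions in float32, and cast only the final result back to the input dtype. A minimal standalone sketch of that call path, assuming plain float32 scale and bias tensors rather than layer weights:

import tensorflow as tf

def layer_norm(x, scale, bias, epsilon=1e-6):
  # x may be float16 under mixed precision; scale and bias stay float32.
  input_dtype = x.dtype
  if input_dtype == tf.float16:
    x = tf.cast(x, tf.float32)
  mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
  variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
  norm_x = (x - mean) * tf.math.rsqrt(variance + epsilon)
  # Cast the normalized output back to the caller's dtype.
  return tf.cast(norm_x * scale + bias, input_dtype)

x = tf.random.normal([2, 4], dtype=tf.float16)
print(layer_norm(x, tf.ones([4]), tf.zeros([4])).dtype)  # float16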
official/transformer/v2/transformer_main.py

@@ -118,6 +118,7 @@ class TransformerTask(object):
     params["use_synthetic_data"] = flags_obj.use_synthetic_data
     params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
     params["repeat_dataset"] = None
+    params["dtype"] = flags_core.get_tf_dtype(flags_obj)

   def train(self):
     """Trains the model."""
@@ -246,6 +247,10 @@ class TransformerTask(object):
         params["optimizer_adam_beta1"],
         params["optimizer_adam_beta2"],
         epsilon=params["optimizer_adam_epsilon"])
+    if params["dtype"] == tf.float16:
+      opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
+          opt, loss_scale=flags_core.get_loss_scale(
+              self.flags_obj, default_for_fp16="dynamic"))
     return opt
@@ -258,6 +263,11 @@ def _ensure_dir(log_dir):
 def main(_):
   flags_obj = flags.FLAGS
   with logger.benchmark_context(flags_obj):
+    if flags_core.get_tf_dtype(flags_obj) == 'float16':
+      policy = tf.keras.mixed_precision.experimental.Policy(
+          'infer_float32_vars')
+      tf.keras.mixed_precision.experimental.set_policy(policy)
+
     task = TransformerTask(flags_obj)

     if flags_obj.mode == "train":
       task.train()
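
The training-side plumbing has two pieces: a global Keras mixed-precision policy set before the model is built (so layers keep float32 variables while computing in float16), and a loss-scale wrapper around the optimizer so small float16 gradients do not underflow to zero. A minimal sketch using the experimental API names that appear in this diff; later TF releases replaced them with tf.keras.mixed_precision.Policy('mixed_float16') and a non-experimental LossScaleOptimizer:

import tensorflow as tf

# Set the policy before constructing the model so layers pick it up.
policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
tf.keras.mixed_precision.experimental.set_policy(policy)

# Wrap the optimizer with dynamic loss scaling: the loss is multiplied by a
# scale factor before backprop and the gradients are divided by it afterwards.
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    opt, loss_scale='dynamic')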
official/transformer/v2/transformer_main_test.py

@@ -51,12 +51,17 @@ class TransformerTaskTest(tf.test.TestCase):
     FLAGS.batch_size = 8
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = "off"
+    FLAGS.dtype = "fp32"
     self.model_dir = FLAGS.model_dir
     self.temp_dir = temp_dir
     self.vocab_file = os.path.join(temp_dir, "vocab")
     self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)["vocab_size"]
     self.bleu_source = os.path.join(temp_dir, "bleu_source")
     self.bleu_ref = os.path.join(temp_dir, "bleu_ref")
+    self.orig_policy = tf.keras.mixed_precision.experimental.global_policy()
+
+  def tearDown(self):
+    tf.keras.mixed_precision.experimental.set_policy(self.orig_policy)

   def _assert_exists(self, filepath):
     self.assertTrue(os.path.exists(filepath))
@@ -82,6 +87,17 @@ class TransformerTaskTest(tf.test.TestCase):
     t = tm.TransformerTask(FLAGS)
     t.train()

+  def test_train_2_gpu_fp16(self):
+    FLAGS.distribution_strategy = "mirrored"
+    FLAGS.num_gpus = 2
+    FLAGS.param_set = "base"
+    FLAGS.dtype = "fp16"
+    policy = tf.keras.mixed_precision.experimental.Policy(
+        'infer_float32_vars')
+    tf.keras.mixed_precision.experimental.set_policy(policy)
+    t = tm.TransformerTask(FLAGS)
+    t.train()
+
   def _prepare_files_and_flags(self, *extra_flags):
     # Make log dir.
     if not os.path.exists(self.temp_dir):
@@ -113,6 +129,14 @@ class TransformerTaskTest(tf.test.TestCase):
     t = tm.TransformerTask(FLAGS)
     t.predict()

+  def test_predict_fp16(self):
+    self._prepare_files_and_flags("--dtype=fp16")
+    policy = tf.keras.mixed_precision.experimental.Policy(
+        'infer_float32_vars')
+    tf.keras.mixed_precision.experimental.set_policy(policy)
+    t = tm.TransformerTask(FLAGS)
+    t.predict()
+
   def test_eval(self):
     self._prepare_files_and_flags()
     t = tm.TransformerTask(FLAGS)
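
The fp16 tests set a global policy, so setUp records the active policy and the new tearDown restores it; otherwise an fp16 test would leak mixed precision into later fp32 tests. A condensed sketch of that save/restore pattern (the class name and structure here are illustrative only):

import tensorflow as tf

class PolicyIsolationTest(tf.test.TestCase):

  def setUp(self):
    super(PolicyIsolationTest, self).setUp()
    # Remember whatever policy was active before this test ran.
    self.orig_policy = tf.keras.mixed_precision.experimental.global_policy()

  def tearDown(self):
    # Restore it so later tests see the original global state.
    tf.keras.mixed_precision.experimental.set_policy(self.orig_policy)
    super(PolicyIsolationTest, self).tearDown()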
official/transformer/v2/transformer_test.py

@@ -37,6 +37,7 @@ class TransformerV2Test(tf.test.TestCase):
     params["vocab_size"] = 41
     params["extra_decode_length"] = 2
     params["beam_size"] = 3
+    params["dtype"] = tf.float32

   def test_create_model_train(self):
     model = transformer.create_model(self.params, True)