ModelZoo / ResNet50_tensorflow · Commits · cb8ce606
"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "a5cfcb93ffbe52fef49ecdcfd1ce01974799694e"
Commit cb8ce606, authored Aug 09, 2019 by Nimit Nigania

    Merge remote-tracking branch 'upstream/master'

Parents: 52372782, 62184a96
Showing 7 changed files with 110 additions and 33 deletions.
official/recommendation/ncf_keras_benchmark.py    +16  -1
official/recommendation/ncf_keras_main.py          +4  -2
official/transformer/model/beam_search.py         +48 -15
official/transformer/v2/beam_search.py             +7  -4
official/transformer/v2/embedding_layer.py        +13  -3
official/transformer/v2/transformer.py             +16  -8
official/transformer/v2/transformer_main_test.py    +6  -0
official/recommendation/ncf_keras_benchmark.py  (+16, -1)

@@ -185,6 +185,13 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.early_stopping = True
     self._run_and_report_benchmark()
 
+  def benchmark_1_gpu_ctl_run_eagerly_early_stop(self):
+    self._setup()
+    FLAGS.keras_use_ctl = True
+    FLAGS.early_stopping = True
+    FLAGS.run_eagerly = True
+    self._run_and_report_benchmark()
+
   def benchmark_xla_1_gpu_ctl_early_stop(self):
     self._setup()
     FLAGS.keras_use_ctl = True

@@ -207,7 +214,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     self._run_and_report_benchmark()
 
   #############################################
-  # Tests below with mlperf in the test name are of two types
+  # Tests below with mlperf in the test name are of two types:
   # 1) 1 GPU tests are based on MLPerf 0.5 and the TensorFlow pulled submission.
   # 2) 8 GPU tests are based on MLPerf 0.5 and use NVIDIA's hyper parameters.
   #

@@ -258,6 +265,14 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.train_epochs = 7
     self._run_and_report_benchmark_mlperf_like()
 
+  def benchmark_1_gpu_ctl_run_eagerly_mlperf_like(self):
+    """1 GPU using CTL with eager and distribution strategy."""
+    self._setup()
+    FLAGS.keras_use_ctl = True
+    FLAGS.run_eagerly = True
+    FLAGS.train_epochs = 7
+    self._run_and_report_benchmark()
+
   def benchmark_xla_1_gpu_ctl_mlperf_like(self):
     """1 GPU using CTL with XLA."""
     self._setup()
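Every benchmark method added in this file follows the class's existing convention: call _setup() to reset state, flip the relevant absl flags, then delegate to the reporting helper. A stripped-down sketch of that convention, runnable outside the benchmark harness (FakeFlags and the print body are illustrative stand-ins, not the real FLAGS object or reporting code):

# Stripped-down sketch of the flag-toggling benchmark convention used above.
class FakeFlags(object):  # illustrative stand-in for absl.flags.FLAGS
  keras_use_ctl = False
  run_eagerly = False
  early_stopping = False

FLAGS = FakeFlags()

class NCFKerasBenchmarkSketch(object):

  def _setup(self):
    # Reset every flag to its default before configuring a benchmark variant.
    FLAGS.keras_use_ctl = False
    FLAGS.run_eagerly = False
    FLAGS.early_stopping = False

  def _run_and_report_benchmark(self):
    # The real helper runs training and reports timing/accuracy; here we only
    # show which flag combination the variant selected.
    print("keras_use_ctl=%s run_eagerly=%s early_stopping=%s" %
          (FLAGS.keras_use_ctl, FLAGS.run_eagerly, FLAGS.early_stopping))

  def benchmark_1_gpu_ctl_run_eagerly_early_stop(self):
    self._setup()
    FLAGS.keras_use_ctl = True
    FLAGS.early_stopping = True
    FLAGS.run_eagerly = True
    self._run_and_report_benchmark()

NCFKerasBenchmarkSketch().benchmark_1_gpu_ctl_run_eagerly_early_stop()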
official/recommendation/ncf_keras_main.py  (+4, -2)

@@ -285,7 +285,6 @@ def run_ncf(_):
     train_input_iterator = strategy.make_dataset_iterator(train_input_dataset)
     eval_input_iterator = strategy.make_dataset_iterator(eval_input_dataset)
 
-    @tf.function
     def train_step():
       """Called once per step to train the model."""
       def step_fn(features):

@@ -310,7 +309,6 @@ def run_ncf(_):
             tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
         return mean_loss
 
-    @tf.function
     def eval_step():
       """Called once per eval step to compute eval metrics."""
       def step_fn(features):

@@ -330,6 +328,10 @@ def run_ncf(_):
             tf.distribute.ReduceOp.SUM, per_replica_hr_count, axis=None)
         return hr_sum, hr_count
 
+    if not FLAGS.run_eagerly:
+      train_step = tf.function(train_step)
+      eval_step = tf.function(eval_step)
+
     time_callback.on_train_begin()
     for epoch in range(FLAGS.train_epochs):
       for cb in callbacks:
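The change in this file swaps the @tf.function decorators on train_step and eval_step for a conditional wrap after both functions are defined. Since tf.function can be applied as an ordinary call, graph compilation becomes something the --run_eagerly flag can switch off at runtime instead of being hard-coded. A minimal, self-contained sketch of that pattern (the toy train_step and the run_eagerly variable stand in for the real step function and FLAGS.run_eagerly):

# Minimal sketch: apply tf.function conditionally so an eager-execution flag
# can disable graph compilation without editing the step function itself.
import tensorflow as tf

run_eagerly = False  # stand-in for FLAGS.run_eagerly


def train_step(x):
  # Toy computation standing in for the real NCF training step.
  return tf.reduce_sum(x * x)


if not run_eagerly:
  # Calling tf.function directly is equivalent to the @tf.function decorator.
  train_step = tf.function(train_step)

print(train_step(tf.constant([1.0, 2.0, 3.0])))  # tf.Tensor(14.0, ...)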
official/transformer/model/beam_search.py  (+48, -15)

@@ -18,11 +18,31 @@ Source implementation from Tensor2Tensor:
   https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/beam_search.py
 """
 
+import numpy as np
 import tensorflow as tf
 
 from tensorflow.python.util import nest
 
-# Default value for INF
-INF = 1. * 1e7
+
+def inf(dtype):
+  """Returns a value close to infinity, but is still finite in `dtype`.
+
+  This is useful to get a very large value that is still zero when multiplied by
+  zero. The floating-point "Inf" value is NaN when multiplied by zero.
+
+  Args:
+    dtype: A dtype. The returned value will be finite when casted to this dtype.
+
+  Returns:
+    A very large value.
+  """
+  if dtype == "float32":
+    return 1e7
+  elif dtype == "float16":
+    # Disable no-member lint error, as the linter thinks np.float16 does not
+    # exist for some reason.
+    return np.finfo(np.float16).max  # pylint: disable=no-member
+  else:
+    raise AssertionError('Invalid dtype: %s' % dtype)
 
 
 class _StateKeys(object):

@@ -60,7 +80,7 @@ class SequenceBeamSearch(object):
   """Implementation of beam search loop."""
 
   def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
-               beam_size, alpha, max_decode_length, eos_id):
+               beam_size, alpha, max_decode_length, eos_id, dtype=tf.float32):
     self.symbols_to_logits_fn = symbols_to_logits_fn
     self.vocab_size = vocab_size
     self.batch_size = batch_size

@@ -68,6 +88,7 @@ class SequenceBeamSearch(object):
     self.alpha = alpha
     self.max_decode_length = max_decode_length
     self.eos_id = eos_id
+    self.dtype = tf.as_dtype(dtype)
 
   def search(self, initial_ids, initial_cache):
     """Beam search for sequences with highest scores."""

@@ -105,6 +126,14 @@ class SequenceBeamSearch(object):
     Returns:
       state and shape invariant dictionaries with keys from _StateKeys
     """
+    for key, value in initial_cache.items():
+      for inner_value in nest.flatten(value):
+        if inner_value.dtype != self.dtype:
+          raise TypeError(
+              "initial_cache element for key '%s' has dtype %s that does not "
+              "match SequenceBeamSearch's dtype of %s. Value: %s"
+              % (key, value.dtype.name, self.dtype.name, inner_value))
+
     # Current loop index (starts at 0)
     cur_index = tf.constant(0)

@@ -115,7 +144,7 @@ class SequenceBeamSearch(object):
     # Create tensor for storing initial log probabilities.
     # Assume initial_ids are prob 1.0
     initial_log_probs = tf.constant(
-        [[0.] + [-float("inf")] * (self.beam_size - 1)])
+        [[0.] + [-float("inf")] * (self.beam_size - 1)], dtype=self.dtype)
     alive_log_probs = tf.tile(initial_log_probs, [self.batch_size, 1])
 
     # Expand all values stored in the dictionary to the beam size, so that each

@@ -127,7 +156,8 @@ class SequenceBeamSearch(object):
     finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)
 
     # Set scores of the initial finished seqs to negative infinity.
-    finished_scores = tf.ones([self.batch_size, self.beam_size]) * -INF
+    finished_scores = tf.ones([self.batch_size, self.beam_size],
+                              dtype=self.dtype) * -inf(self.dtype)
 
     # Initialize finished flags with all False values.
     finished_flags = tf.zeros([self.batch_size, self.beam_size], tf.bool)

@@ -185,20 +215,22 @@ class SequenceBeamSearch(object):
     not_at_max_decode_length = tf.less(i, self.max_decode_length)
 
     # Calculate largest length penalty (the larger penalty, the better score).
-    max_length_norm = _length_normalization(self.alpha, self.max_decode_length)
+    max_length_norm = _length_normalization(self.alpha, self.max_decode_length,
+                                            dtype=self.dtype)
     # Get the best possible scores from alive sequences.
     best_alive_scores = alive_log_probs[:, 0] / max_length_norm
 
     # Compute worst score in finished sequences for each batch element
     finished_scores *= tf.cast(finished_flags,
-                               tf.float32)  # set filler scores to zero
+                               self.dtype)  # set filler scores to zero
     lowest_finished_scores = tf.reduce_min(finished_scores, axis=1)
 
     # If there are no finished sequences in a batch element, then set the lowest
     # finished score to -INF for that element.
     finished_batches = tf.reduce_any(finished_flags, 1)
-    lowest_finished_scores += (1.0 -
-                               tf.cast(finished_batches, tf.float32)) * -INF
+    lowest_finished_scores += ((1.0 -
+                                tf.cast(finished_batches, self.dtype))
+                               * -inf(self.dtype))
 
     worst_finished_score_better_than_best_alive_score = tf.reduce_all(
         tf.greater(lowest_finished_scores, best_alive_scores)

@@ -319,9 +351,9 @@ class SequenceBeamSearch(object):
         Log probabilities of top alive sequences
         Dict cache storing decoder states for top alive sequences}
     """
-    # To prevent finished sequences from being considered, set log probs to -INF
+    # To prevent finished sequences from being considered, set log probs to -inf
     new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
-    new_log_probs += tf.cast(new_finished_flags, tf.float32) * -INF
+    new_log_probs += tf.cast(new_finished_flags, self.dtype) * -inf(self.dtype)
 
     top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
         [new_seq, new_log_probs, new_cache], new_log_probs, self.batch_size,

@@ -361,12 +393,13 @@ class SequenceBeamSearch(object):
         tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2)
 
     # Calculate new seq scores from log probabilities.
-    length_norm = _length_normalization(self.alpha, i + 1)
+    length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)
     new_scores = new_log_probs / length_norm
 
     # Set the scores of the still-alive seq in new_seq to large negative values.
     new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
-    new_scores += (1. - tf.cast(new_finished_flags, tf.float32)) * -INF
+    new_scores += ((1. - tf.cast(new_finished_flags, self.dtype))
+                   * -inf(self.dtype))
 
     # Combine sequences, scores, and flags.
     finished_seq = tf.concat([finished_seq, new_seq], axis=1)

@@ -422,9 +455,9 @@ def _log_prob_from_logits(logits):
   return logits - tf.reduce_logsumexp(logits, axis=2, keepdims=True)
 
 
-def _length_normalization(alpha, length):
+def _length_normalization(alpha, length, dtype=tf.float32):
   """Return length normalization factor."""
-  return tf.pow(((5. + tf.cast(length, tf.float32)) / 6.), alpha)
+  return tf.pow(((5. + tf.cast(length, dtype)) / 6.), alpha)
 
 
 def _expand_to_beam_size(tensor, beam_size):
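The new inf(dtype) helper replaces the module-level INF constant because the sentinel must stay finite in the requested dtype: true floating-point infinity turns into NaN as soon as it is multiplied by a zero-valued mask, and 1e7 overflows float16. A small NumPy check of that reasoning (printed values in the comments are approximate):

# Why beam search uses a large finite value instead of float("inf"):
# multiplying the true Inf by a 0/1 mask produces NaN wherever the mask is 0,
# while a finite stand-in stays exactly 0 there.
import numpy as np

mask = np.array([0.0, 1.0], dtype=np.float16)
print(float("inf") * mask)              # [nan inf] -- NaN poisons the masked slot
print(np.finfo(np.float16).max * mask)  # [0. ~65500.] -- usable as a "-INF" filler
print(np.float16(1e7))                  # inf -- 1e7 overflows float16, hence finfo(...).max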
official/transformer/v2/beam_search.py  (+7, -4)

@@ -57,7 +57,7 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
 def sequence_beam_search(
     symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id):
+    alpha, max_decode_length, eos_id, dtype="float32"):
   """Search for sequence of subtoken ids with the largest probability.
 
   Args:

@@ -76,7 +76,8 @@ def sequence_beam_search(
     beam_size: int number of beams
     alpha: float defining the strength of length normalization
     max_decode_length: maximum length to decoded sequence
-    eos_id: int id of eos token, used to determine when a sequence has finished
+    eos_id: int id of eos token, used to determine when a sequence has finished,
+    dtype: The dtype to use.
 
   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]

@@ -85,10 +86,12 @@ def sequence_beam_search(
   batch_size = tf.shape(initial_ids)[0]
   if misc.is_v2():
     sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
-                               beam_size, alpha, max_decode_length, eos_id)
+                               beam_size, alpha, max_decode_length, eos_id,
+                               dtype)
   else:
     sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
-                                beam_size, alpha, max_decode_length, eos_id)
+                                beam_size, alpha, max_decode_length, eos_id,
+                                dtype)
   return sbs.search(initial_ids, initial_cache)
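sequence_beam_search now forwards an optional dtype (a string or a tf.DType) to whichever SequenceBeamSearch implementation is selected; the v1 constructor normalizes it with tf.as_dtype and then requires every tensor in initial_cache to match it. A small sketch of that normalization and check, using only the TensorFlow calls involved (the cache_entry tensor is an illustrative stand-in for a real decoder cache value):

# tf.as_dtype accepts both strings and tf.DType objects, so "float32",
# "float16" and tf.float16 all normalize to the same kind of dtype object.
import tensorflow as tf

for d in ("float32", "float16", tf.float16):
  print(tf.as_dtype(d).name)

# The cache-dtype validation added in model/beam_search.py compares each
# cache tensor's dtype against the normalized search dtype.
search_dtype = tf.as_dtype("float16")
cache_entry = tf.zeros([2, 0, 8], dtype=tf.float32)  # stand-in cache value
if cache_entry.dtype != search_dtype:
  print("cache dtype %s does not match search dtype %s" %
        (cache_entry.dtype.name, search_dtype.name))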
official/transformer/v2/embedding_layer.py  (+13, -3)

@@ -24,14 +24,24 @@ import tensorflow as tf
 class EmbeddingSharedWeights(tf.keras.layers.Layer):
   """Calculates input embeddings and pre-softmax linear with shared weights."""
 
-  def __init__(self, vocab_size, hidden_size):
+  def __init__(self, vocab_size, hidden_size, dtype=None):
     """Specify characteristic parameters of embedding layer.
 
     Args:
      vocab_size: Number of tokens in the embedding. (Typically ~32,000)
      hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
+      dtype: The dtype of the layer: float16 or float32.
     """
-    super(EmbeddingSharedWeights, self).__init__()
+    if dtype == tf.float16:
+      # We cannot rely on the global policy of "infer_with_float32_vars", as
+      # this layer is called on both int64 inputs and floating-point inputs.
+      # If "infer_with_float32_vars" is used, the dtype will be inferred to be
+      # int64, which means floating-point inputs would not be casted.
+      # TODO(b/138859351): Remove this logic once we stop using the deprecated
+      # "infer_with_float32_vars" policy
+      dtype = tf.keras.mixed_precision.experimental.Policy(
+          "float16_with_float32_vars")
+    super(EmbeddingSharedWeights, self).__init__(dtype=dtype)
     self.vocab_size = vocab_size
     self.hidden_size = hidden_size

@@ -78,8 +88,8 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     """Applies embedding based on inputs tensor."""
     with tf.name_scope("embedding"):
       # Create binary mask of size [batch_size, length]
-      mask = tf.cast(tf.not_equal(inputs, 0), tf.float32)
       embeddings = tf.gather(self.shared_weights, inputs)
+      mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
       embeddings *= tf.expand_dims(mask, -1)
       # Scale embedding by the sqrt of the hidden size
       embeddings *= self.hidden_size ** 0.5
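The second hunk reorders the masking so the padding mask is built after the gather and cast to embeddings.dtype rather than a hard-coded tf.float32; under the float16 policy the gathered embeddings are half precision, and multiplying them by a float32 mask would mix dtypes. A minimal sketch of the reordered pattern with made-up sizes (shared_weights and the padding id 0 mirror the layer above, everything else is illustrative):

# Build the mask after the gather so it can take whatever dtype the embedding
# variable actually has (float16 here), keeping the multiply dtype-consistent.
import tensorflow as tf

vocab_size, hidden_size = 10, 4
shared_weights = tf.random.normal([vocab_size, hidden_size], dtype=tf.float16)
inputs = tf.constant([[3, 5, 0, 0]])  # id 0 marks padding positions

embeddings = tf.gather(shared_weights, inputs)
mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
embeddings *= tf.expand_dims(mask, -1)
# Scale embedding by the sqrt of the hidden size, as in the layer above.
embeddings *= hidden_size ** 0.5
print(embeddings.dtype, embeddings.shape)  # float16, (1, 4, 4)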
official/transformer/v2/transformer.py  (+16, -8)

@@ -32,6 +32,11 @@ from official.transformer.v2 import ffn_layer
 from official.transformer.v2 import metrics
 
+# Disable the not-callable lint error, since it claims many objects are not
+# callable when they actually are.
+# pylint: disable=not-callable
+
+
 def create_model(params, is_train):
   """Creates transformer model."""
   with tf.name_scope("model"):

@@ -80,7 +85,7 @@ class Transformer(tf.keras.Model):
     super(Transformer, self).__init__(name=name)
     self.params = params
     self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
-        params["vocab_size"], params["hidden_size"])
+        params["vocab_size"], params["hidden_size"], dtype=params["dtype"])
     self.encoder_stack = EncoderStack(params)
     self.decoder_stack = DecoderStack(params)

@@ -216,8 +221,9 @@ class Transformer(tf.keras.Model):
     timing_signal = model_utils.get_position_encoding(
         max_decode_length + 1, self.params["hidden_size"])
+    timing_signal = tf.cast(timing_signal, self.params["dtype"])
     decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
-        max_decode_length)
+        max_decode_length, dtype=self.params["dtype"])
 
     def symbols_to_logits_fn(ids, i, cache):
       """Generate logits for next potential IDs.

@@ -257,12 +263,11 @@ class Transformer(tf.keras.Model):
   def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
     """Return predicted sequence."""
-    # Currently, we always do prediction in float32.
-    # TODO(reedwm): Add float16 support.
-    encoder_outputs = tf.cast(encoder_outputs, tf.float32)
     batch_size = tf.shape(encoder_outputs)[0]
     input_length = tf.shape(encoder_outputs)[1]
     max_decode_length = input_length + self.params["extra_decode_length"]
+    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
+                                             self.params["dtype"])
 
     symbols_to_logits_fn = self._get_symbols_to_logits_fn(
         max_decode_length, training)

@@ -274,8 +279,10 @@ class Transformer(tf.keras.Model):
     # pylint: disable=g-complex-comprehension
     cache = {
         "layer_%d" % layer: {
-            "k": tf.zeros([batch_size, 0, self.params["hidden_size"]]),
-            "v": tf.zeros([batch_size, 0, self.params["hidden_size"]])
+            "k": tf.zeros([batch_size, 0, self.params["hidden_size"]],
+                          dtype=self.params["dtype"]),
+            "v": tf.zeros([batch_size, 0, self.params["hidden_size"]],
+                          dtype=self.params["dtype"])
         } for layer in range(self.params["num_hidden_layers"])
     }
     # pylint: enable=g-complex-comprehension

@@ -293,7 +300,8 @@ class Transformer(tf.keras.Model):
         beam_size=self.params["beam_size"],
         alpha=self.params["alpha"],
         max_decode_length=max_decode_length,
-        eos_id=EOS_ID)
+        eos_id=EOS_ID,
+        dtype=self.params["dtype"])
 
     # Get the top sequence for each batch element
     top_decoded_ids = decoded_ids[:, 0, 1:]
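The decode cache is now zero-initialized in the model's configured dtype so that it passes the cache-dtype validation added in model/beam_search.py. A self-contained sketch of just that cache construction (the params values and batch size are illustrative, not the model's defaults):

# Create the per-layer "k"/"v" decode cache in the model's dtype so every
# cache tensor matches the dtype SequenceBeamSearch was constructed with.
import tensorflow as tf

params = {"hidden_size": 8, "num_hidden_layers": 2, "dtype": tf.float16}
batch_size = 4

cache = {
    "layer_%d" % layer: {
        "k": tf.zeros([batch_size, 0, params["hidden_size"]],
                      dtype=params["dtype"]),
        "v": tf.zeros([batch_size, 0, params["hidden_size"]],
                      dtype=params["dtype"]),
    } for layer in range(params["num_hidden_layers"])
}
print(cache["layer_0"]["k"].dtype)  # float16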
official/transformer/v2/transformer_main_test.py  (+6, -0)

@@ -95,6 +95,12 @@ class TransformerTaskTest(tf.test.TestCase):
     t = tm.TransformerTask(FLAGS)
     t.train()
 
+  @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+  def test_train_fp16(self):
+    FLAGS.dtype = 'fp16'
+    t = tm.TransformerTask(FLAGS)
+    t.train()
+
   @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
   def test_train_2_gpu(self):
     if context.num_gpus() < 2:
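The new test gates the fp16 training run on a CUDA build, so CPU-only test environments report it as skipped rather than failed. A small standalone sketch of that gating pattern (DtypeFlagTest and its body are hypothetical; only the skipUnless condition mirrors the test above):

# unittest.skipUnless marks the test as skipped when TensorFlow was not built
# with CUDA, instead of letting an fp16 GPU test fail on CPU-only machines.
import unittest
import tensorflow as tf


class DtypeFlagTest(unittest.TestCase):

  @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
  def test_train_fp16(self):
    # On a CUDA build the real test sets FLAGS.dtype = 'fp16' and trains;
    # here we only assert that the gate evaluated to True.
    self.assertTrue(tf.test.is_built_with_cuda())


if __name__ == '__main__':
  unittest.main()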