Commit 640ff472
Authored Aug 22, 2019 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 264853703
Parent: 4a1354fe
Showing 7 changed files with 415 additions and 135 deletions:

official/transformer/model/beam_search.py    +126  -35
official/transformer/v2/attention_layer.py    +45  -28
official/transformer/v2/beam_search.py        +36  -21
official/transformer/v2/misc.py               +23   -0
official/transformer/v2/transformer.py        +67  -22
official/transformer/v2/transformer_main.py   +55  -10
official/transformer/v2/translate.py          +63  -19
official/transformer/model/beam_search.py

@@ -79,8 +79,41 @@ class _StateKeys(object):
 class SequenceBeamSearch(object):
   """Implementation of beam search loop."""
 
-  def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
-               beam_size, alpha, max_decode_length, eos_id, dtype=tf.float32):
+  def __init__(self,
+               symbols_to_logits_fn,
+               vocab_size,
+               batch_size,
+               beam_size,
+               alpha,
+               max_decode_length,
+               eos_id,
+               padded_decode,
+               dtype=tf.float32):
+    """Initialize sequence beam search.
+
+    Args:
+      symbols_to_logits_fn: A function to provide logits, which is the
+        interface to the Transformer model. The passed in arguments are:
+          ids -> A tensor with shape [batch_size * beam_size, index].
+          index -> A scalar.
+          cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+        The function must return a tuple of logits and the updated cache:
+          logits -> A tensor with shape [batch * beam_size, vocab_size].
+          updated cache -> A nested dictionary with the same structure as the
+            input cache.
+      vocab_size: An integer, the size of the vocabulary, used for topk
+        computation.
+      batch_size: An integer, the decode batch size.
+      beam_size: An integer, number of beams for beam search.
+      alpha: A float, defining the strength of length normalization.
+      max_decode_length: An integer, the maximum number of steps to decode
+        a sequence.
+      eos_id: An integer. ID of end of sentence token.
+      padded_decode: A bool, indicating if max_sequence_length padding is used
+        for beam search.
+      dtype: A tensorflow data type used for score computation. The default is
+        tf.float32.
+    """
     self.symbols_to_logits_fn = symbols_to_logits_fn
     self.vocab_size = vocab_size
     self.batch_size = batch_size
@@ -88,6 +121,7 @@ class SequenceBeamSearch(object):
     self.alpha = alpha
     self.max_decode_length = max_decode_length
     self.eos_id = eos_id
+    self.padded_decode = padded_decode
     self.dtype = tf.as_dtype(dtype)
 
   def search(self, initial_ids, initial_cache):
@@ -140,6 +174,8 @@ class SequenceBeamSearch(object):
     # Create alive sequence with shape [batch_size, beam_size, 1]
     alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
     alive_seq = tf.expand_dims(alive_seq, axis=2)
+    if self.padded_decode:
+      alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])
 
     # Create tensor for storing initial log probabilities.
     # Assume initial_ids are prob 1.0
@@ -178,16 +214,44 @@ class SequenceBeamSearch(object):
     # 1) the dimension's value is a tensor that remains the same but may
     # depend on the input sequence to the model (e.g. batch size).
     # 2) the dimension may have different values on different iterations.
-    state_shape_invariants = {
-        _StateKeys.CUR_INDEX: tf.TensorShape([]),
-        _StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]),
-        _StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]),
-        _StateKeys.ALIVE_CACHE: nest.map_structure(
-            _get_shape_keep_last_dim, alive_cache),
-        _StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]),
-        _StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]),
-        _StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size])
-    }
+    if self.padded_decode:
+      state_shape_invariants = {
+          _StateKeys.CUR_INDEX:
+              tf.TensorShape([]),
+          _StateKeys.ALIVE_SEQ:
+              tf.TensorShape(
+                  [self.batch_size, self.beam_size,
+                   self.max_decode_length + 1]),
+          _StateKeys.ALIVE_LOG_PROBS:
+              tf.TensorShape([self.batch_size, self.beam_size]),
+          _StateKeys.ALIVE_CACHE:
+              nest.map_structure(_get_shape, alive_cache),
+          _StateKeys.FINISHED_SEQ:
+              tf.TensorShape(
+                  [self.batch_size, self.beam_size,
+                   self.max_decode_length + 1]),
+          _StateKeys.FINISHED_SCORES:
+              tf.TensorShape([self.batch_size, self.beam_size]),
+          _StateKeys.FINISHED_FLAGS:
+              tf.TensorShape([self.batch_size, self.beam_size])
+      }
+    else:
+      state_shape_invariants = {
+          _StateKeys.CUR_INDEX:
+              tf.TensorShape([]),
+          _StateKeys.ALIVE_SEQ:
+              tf.TensorShape([None, self.beam_size, None]),
+          _StateKeys.ALIVE_LOG_PROBS:
+              tf.TensorShape([None, self.beam_size]),
+          _StateKeys.ALIVE_CACHE:
+              nest.map_structure(_get_shape_keep_last_dim, alive_cache),
+          _StateKeys.FINISHED_SEQ:
+              tf.TensorShape([None, self.beam_size, None]),
+          _StateKeys.FINISHED_SCORES:
+              tf.TensorShape([None, self.beam_size]),
+          _StateKeys.FINISHED_FLAGS:
+              tf.TensorShape([None, self.beam_size])
+      }
 
     return state, state_shape_invariants
@@ -297,7 +361,12 @@ class SequenceBeamSearch(object):
     # Get logits for the next candidate IDs for the alive sequences. Get the new
     # cache values at the same time.
-    flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
+    if self.padded_decode:
+      flat_ids = tf.reshape(
+          tf.slice(alive_seq, [0, 0, i],
+                   [self.batch_size, self.beam_size, 1]),
+          [self.batch_size * self.beam_size, -1])
+    else:
+      flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
     flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)
 
     flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)
@@ -331,8 +400,13 @@ class SequenceBeamSearch(object):
     # Append the most probable IDs to the topk sequences
     topk_ids = topk_indices % self.vocab_size
-    topk_ids = tf.expand_dims(topk_ids, axis=2)
-    topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
+    if self.padded_decode:
+      topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
+      topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
+      topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
+    else:
+      topk_ids = tf.expand_dims(topk_ids, axis=2)
+      topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
     return topk_seq, topk_log_probs, new_cache
 
   def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
@@ -388,9 +462,12 @@ class SequenceBeamSearch(object):
     # First append a column of 0-ids to finished_seq to increment the length.
     # New shape of finished_seq: [batch_size, beam_size, i + 1]
-    finished_seq = tf.concat(
-        [finished_seq,
-         tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2)
+    if not self.padded_decode:
+      finished_seq = tf.concat([
+          finished_seq,
+          tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)
+      ], axis=2)
 
     # Calculate new seq scores from log probabilities.
     length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)
@@ -420,34 +497,43 @@ class SequenceBeamSearch(object):
 def sequence_beam_search(
     symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id):
+    alpha, max_decode_length, eos_id, padded_decode=False):
   """Search for sequence of subtoken ids with the largest probability.
 
   Args:
     symbols_to_logits_fn: A function that takes in ids, index, and cache as
       arguments. The passed in arguments will have shape:
-        ids -> [batch_size * beam_size, index]
-        index -> [] (scalar)
-        cache -> nested dictionary of tensors [batch_size * beam_size, ...]
-      The function must return logits and new cache.
-        logits -> [batch * beam_size, vocab_size]
-        new cache -> same shape/structure as inputted cache
-    initial_ids: Starting ids for each batch item.
-      int32 tensor with shape [batch_size]
-    initial_cache: dict containing starting decoder variables information
-    vocab_size: int size of tokens
-    beam_size: int number of beams
-    alpha: float defining the strength of length normalization
-    max_decode_length: maximum length to decoded sequence
-    eos_id: int id of eos token, used to determine when a sequence has finished
+        ids -> A tensor with shape [batch_size * beam_size, index].
+        index -> A scalar.
+        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+      The function must return a tuple of logits and new cache:
+        logits -> A tensor with shape [batch * beam_size, vocab_size].
+        new cache -> A nested dictionary with the same shape/structure as the
+          inputted cache.
+    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
+      each batch item.
+    initial_cache: A dictionary, containing starting decoder variables
+      information.
+    vocab_size: An integer, the size of the vocabulary, used for topk
+      computation.
+    beam_size: An integer, the number of beams.
+    alpha: A float, defining the strength of length normalization.
+    max_decode_length: An integer, the maximum length to decoded a sequence.
+    eos_id: An integer, ID of eos token, used to determine when a sequence has
+      finished.
+    padded_decode: A bool, indicating if max_sequence_length padding is used
+      for beam search.
 
   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]
     sequence scores [batch_size, beam_size]
   """
-  batch_size = tf.shape(initial_ids)[0]
+  batch_size = (
+      initial_ids.shape.as_list()[0] if padded_decode else
+      tf.shape(initial_ids)[0])
   sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
-                           beam_size, alpha, max_decode_length, eos_id)
+                           beam_size, alpha, max_decode_length, eos_id,
+                           padded_decode)
   return sbs.search(initial_ids, initial_cache)
@@ -502,6 +588,11 @@ def _get_shape_keep_last_dim(tensor):
   return tf.TensorShape(shape_list)
 
 
+def _get_shape(tensor):
+  """Return the shape of the input tensor."""
+  return tf.TensorShape(_shape_list(tensor))
+
+
 def _flatten_beam_dim(tensor):
   """Reshapes first two dimensions in to single dimension.
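
Note on the padded-decode update above: instead of growing alive_seq by one token per step with tf.concat, the padded path preallocates the buffer at max_decode_length + 1 and overwrites a single time slice each iteration, so every loop variable keeps a static shape. A minimal sketch of the same pattern, with made-up sizes and the public tf.tensor_scatter_nd_update in place of the tf.tensor_scatter_update name used in the diff:

    import tensorflow as tf

    batch, beam, max_len = 2, 3, 6  # illustrative sizes only
    i = 1                           # current decode step

    # Preallocated buffer: shape is [batch, beam, max_len + 1] on every step.
    seq = tf.zeros([batch, beam, max_len + 1], tf.int32)
    new_ids = tf.ones([batch, beam], tf.int32)  # stand-in for topk_ids

    # Move time to the leading axis so a single scatter row writes step i + 1,
    # then move it back; the result has the same static shape as `seq`.
    seq_t = tf.transpose(seq, perm=[2, 0, 1])  # [max_len + 1, batch, beam]
    seq_t = tf.tensor_scatter_nd_update(seq_t, [[i + 1]], [new_ids])
    seq = tf.transpose(seq_t, perm=[1, 2, 0])  # [batch, beam, max_len + 1]

This is also why the shape invariants switch to fully static tf.TensorShape values when padded_decode is set: nothing in the loop changes shape anymore.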
official/transformer/v2/attention_layer.py

@@ -102,51 +102,67 @@ class Attention(tf.keras.layers.Layer):
     x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
     return tf.reshape(x, [batch_size, length, self.hidden_size])
 
-  def call(self, x, y, bias, training, cache=None):
+  def call(self, x, y, bias, training, cache=None, decode_loop_step=None):
     """Apply attention mechanism to x and y.
 
     Args:
-      x: a tensor with shape [batch_size, length_x, hidden_size]
-      y: a tensor with shape [batch_size, length_y, hidden_size]
-      bias: attention bias that will be added to the result of the dot product.
-      training: boolean, whether in training mode or not.
-      cache: (Used during prediction) dictionary with tensors containing results
-        of previous attentions. The dictionary must have the items:
+      x: A tensor with shape [batch_size, length_x, hidden_size].
+      y: A tensor with shape [batch_size, length_y, hidden_size].
+      bias: A bool, the attention bias that will be added to the result of the
+        dot product.
+      training: A bool, whether in training mode or not.
+      cache: (Used during prediction) A dictionary with tensors containing
+        results of previous attentions. The dictionary must have the items:
             {"k": tensor with shape [batch_size, i, key_channels],
              "v": tensor with shape [batch_size, i, value_channels]}
         where i is the current decoded length.
+      decode_loop_step: An integer, step number of the decoding loop. Used only
+        for autoregressive inference on TPU.
 
     Returns:
       Attention layer output with shape [batch_size, length_x, hidden_size]
     """
-    # Linearly project the query (q), key (k) and value (v) using different
-    # learned projections. This is in preparation of splitting them into
-    # multiple heads. Multi-head attention uses multiple queries, keys, and
-    # values rather than regular attention (which uses a single q, k, v).
-    q = self.q_dense_layer(x)
-    k = self.k_dense_layer(y)
-    v = self.v_dense_layer(y)
+    # Linearly project the query, key and value using different learned
+    # projections. This is in preparation of splitting them into multiple
+    # heads. Multi-head attention uses multiple queries, keys, and values
+    # rather than regular attention (which uses a single query, key, value).
+    query = self.q_dense_layer(x)
+    key = self.k_dense_layer(y)
+    value = self.v_dense_layer(y)
 
     if cache is not None:
       # Combine cached keys and values with new keys and values.
-      k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
-      v = tf.concat([tf.cast(cache["v"], k.dtype), v], axis=1)
+      if decode_loop_step is not None:
+        cache_k_shape = cache["k"].shape.as_list()
+        indices = tf.reshape(
+            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
+            [1, cache_k_shape[1], 1])
+        key = cache["k"] + key * indices
+        cache_v_shape = cache["v"].shape.as_list()
+        indices = tf.reshape(
+            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
+            [1, cache_v_shape[1], 1])
+        value = cache["v"] + value * indices
+      else:
+        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
+        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)
 
       # Update cache
-      cache["k"] = k
-      cache["v"] = v
+      cache["k"] = key
+      cache["v"] = value
 
-    # Split q, k, v into heads.
-    q = self.split_heads(q)
-    k = self.split_heads(k)
-    v = self.split_heads(v)
+    # Split query, key, value into heads.
+    query = self.split_heads(query)
+    key = self.split_heads(key)
+    value = self.split_heads(value)
 
-    # Scale q to prevent the dot product between q and k from growing too large.
+    # Scale query to prevent the dot product between query and key from growing
+    # too large.
     depth = (self.hidden_size // self.num_heads)
-    q *= depth ** -0.5
+    query *= depth ** -0.5
 
     # Calculate dot product attention
-    logits = tf.matmul(q, k, transpose_b=True)
+    logits = tf.matmul(query, key, transpose_b=True)
     logits += bias
     # Note that softmax internally performs math operations using float32
     # for numeric stability. When training with float16, we keep the input
@@ -154,7 +170,7 @@ class Attention(tf.keras.layers.Layer):
     weights = tf.nn.softmax(logits, name="attention_weights")
     if training:
       weights = tf.nn.dropout(weights, rate=self.attention_dropout)
-    attention_output = tf.matmul(weights, v)
+    attention_output = tf.matmul(weights, value)
 
     # Recombine heads --> [batch_size, length, hidden_size]
     attention_output = self.combine_heads(attention_output)
@@ -167,5 +183,6 @@ class Attention(tf.keras.layers.Layer):
 class SelfAttention(Attention):
   """Multiheaded self-attention layer."""
 
-  def call(self, x, bias, training, cache=None):
-    return super(SelfAttention, self).call(x, x, bias, training, cache)
+  def call(self, x, bias, training, cache=None, decode_loop_step=None):
+    return super(SelfAttention, self).call(x, x, bias, training, cache,
+                                           decode_loop_step)
official/transformer/v2/beam_search.py

@@ -55,43 +55,58 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
     return finished_seq, finished_scores
 
 
-def sequence_beam_search(symbols_to_logits_fn,
-                         initial_ids, initial_cache, vocab_size, beam_size,
-                         alpha, max_decode_length, eos_id, dtype="float32"):
+def sequence_beam_search(symbols_to_logits_fn,
+                         initial_ids,
+                         initial_cache,
+                         vocab_size,
+                         beam_size,
+                         alpha,
+                         max_decode_length,
+                         eos_id,
+                         padded_decode,
+                         dtype="float32"):
   """Search for sequence of subtoken ids with the largest probability.
 
   Args:
     symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
-        ids -> [batch_size * beam_size, index]
-        index -> [] (scalar)
-        cache -> nested dictionary of tensors [batch_size * beam_size, ...]
-      The function must return logits and new cache.
-        logits -> [batch * beam_size, vocab_size]
-        new cache -> same shape/structure as inputted cache
-    initial_ids: Starting ids for each batch item.
-      int32 tensor with shape [batch_size]
-    initial_cache: dict containing starting decoder variables information
-    vocab_size: int size of tokens
-    beam_size: int number of beams
-    alpha: float defining the strength of length normalization
-    max_decode_length: maximum length to decoded sequence
-    eos_id: int id of eos token, used to determine when a sequence has finished,
-    dtype: The dtype to use.
+        ids -> A tensor with shape [batch_size * beam_size, index].
+        index -> A scalar.
+        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+      The function must return a tuple of logits and new cache:
+        logits -> A tensor with shape [batch * beam_size, vocab_size].
+        new cache -> A nested dictionary with the same shape/structure as the
+          inputted cache.
+    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
+      each batch item.
+    initial_cache: A dictionary, containing starting decoder variables
+      information.
+    vocab_size: An integer, the size of tokens.
+    beam_size: An integer, the number of beams.
+    alpha: A float, defining the strength of length normalization.
+    max_decode_length: An integer, the maximum length to decoded a sequence.
+    eos_id: An integer, ID of eos token, used to determine when a sequence has
+      finished.
+    padded_decode: A bool, indicating if max_sequence_length padding is used
+      for beam search.
+    dtype: A tensorflow data type used for score computation. The default is
+      tf.float32.
 
   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]
     sequence scores [batch_size, beam_size]
   """
-  batch_size = tf.shape(initial_ids)[0]
+  batch_size = (
+      initial_ids.shape.as_list()[0] if padded_decode else
+      tf.shape(initial_ids)[0])
   if misc.is_v2():
     sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
                                beam_size, alpha, max_decode_length, eos_id,
-                               dtype)
+                               padded_decode, dtype)
   else:
     sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
                                 beam_size, alpha, max_decode_length, eos_id,
-                                dtype)
+                                padded_decode, dtype)
   return sbs.search(initial_ids, initial_cache)
official/transformer/v2/misc.py

@@ -191,6 +191,29 @@ def define_transformer_flags():
       help=flags_core.help_wrap(
           'Whether the model runs in 2VM mode, Headless server and unit test '
           'all use 1VM config.'))
+  flags.DEFINE_integer(
+      name='decode_batch_size',
+      default=32,
+      help=flags_core.help_wrap(
+          'Global batch size used for Transformer autoregressive decoding on '
+          'TPU.'))
+  flags.DEFINE_integer(
+      name='decode_max_length',
+      default=97,
+      help=flags_core.help_wrap(
+          'Max sequence length of the decode/eval data. This is used by '
+          'Transformer autoregressive decoding on TPU to have minimum '
+          'paddings.'))
+  flags.DEFINE_bool(
+      name='padded_decode',
+      default=False,
+      help=flags_core.help_wrap(
+          'Whether the autoregressive decoding runs with input data padded to '
+          'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
+          'set due the static shape requirement. Although CPU/GPU could also '
+          'use padded_decode, it has not been tested. In addition, this method '
+          'will introduce unnecessary overheads which grow quadratically with '
+          'the max sequence length.'))
 
   flags_core.set_defaults(data_dir='/tmp/translate_ende',
                           model_dir='/tmp/transformer_model',
official/transformer/v2/transformer.py

@@ -112,11 +112,22 @@ class Transformer(tf.keras.Model):
         outputs: [batch_size, decoded length]
         scores: [batch_size, float]}
       Even when float16 is used, the output tensor(s) are always float32.
+
+    Raises:
+      NotImplementedError: If try to use padded decode method on CPU/GPUs.
     """
     if len(inputs) == 2:
       inputs, targets = inputs[0], inputs[1]
     else:
       inputs, targets = inputs[0], None
+      if self.params["padded_decode"]:
+        if not self.params["num_replicas"]:
+          raise NotImplementedError(
+              "Padded decoding on CPU/GPUs is not supported.")
+        decode_batch_size = int(self.params["decode_batch_size"] /
+                                self.params["num_replicas"])
+        inputs = tf.reshape(
+            inputs, [decode_batch_size, self.params["decode_max_length"]])
 
     # Variance scaling is used here because it seems to work in many problems.
     # Other reasonable initializers may also work just as well.
@@ -225,13 +236,14 @@ class Transformer(tf.keras.Model):
     decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
         max_decode_length, dtype=self.params["dtype"])
 
+    # TODO(b/139770046): Refactor code with better naming of i.
     def symbols_to_logits_fn(ids, i, cache):
       """Generate logits for next potential IDs.
 
       Args:
         ids: Current decoded sequences. int tensor with shape [batch_size *
-          beam_size, i + 1]
-        i: Loop index
+          beam_size, i + 1].
+        i: Loop index.
         cache: dictionary of values storing the encoder output, encoder-decoder
           attention bias, and previous decoder attention values.
@@ -245,16 +257,29 @@ class Transformer(tf.keras.Model):
       # Preprocess decoder input by getting embeddings and adding timing signal.
       decoder_input = self.embedding_softmax_layer(decoder_input)
-      decoder_input += timing_signal[i:i + 1]
-      self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
+      if self.params["padded_decode"]:
+        timing_signal_shape = timing_signal.shape.as_list()
+        decoder_input += tf.slice(timing_signal, [i, 0],
+                                  [1, timing_signal_shape[1]])
+
+        bias_shape = decoder_self_attention_bias.shape.as_list()
+        self_attention_bias = tf.slice(
+            decoder_self_attention_bias, [0, 0, i, 0],
+            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
+      else:
+        decoder_input += timing_signal[i:i + 1]
+        self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1,
+                                                          :i + 1]
+
       decoder_outputs = self.decoder_stack(
           decoder_input,
           cache.get("encoder_outputs"),
           self_attention_bias,
          cache.get("encoder_decoder_attention_bias"),
           training=training,
-          cache=cache)
+          cache=cache,
+          decode_loop_step=i if self.params["padded_decode"] else None)
       logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
       logits = tf.squeeze(logits, axis=[1])
       return logits, cache
@@ -263,8 +288,12 @@ class Transformer(tf.keras.Model):
   def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
     """Return predicted sequence."""
-    batch_size = tf.shape(encoder_outputs)[0]
-    input_length = tf.shape(encoder_outputs)[1]
+    if self.params["padded_decode"]:
+      batch_size = encoder_outputs.shape.as_list()[0]
+      input_length = encoder_outputs.shape.as_list()[1]
+    else:
+      batch_size = tf.shape(encoder_outputs)[0]
+      input_length = tf.shape(encoder_outputs)[1]
     max_decode_length = input_length + self.params["extra_decode_length"]
     encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                              self.params["dtype"])
@@ -277,12 +306,20 @@ class Transformer(tf.keras.Model):
     # Create cache storing decoder attention values for each layer.
     # pylint: disable=g-complex-comprehension
+    init_decode_length = (
+        max_decode_length if self.params["padded_decode"] else 0)
     cache = {
         "layer_%d" % layer: {
-            "k": tf.zeros([batch_size, 0, self.params["hidden_size"]],
-                          dtype=self.params["dtype"]),
-            "v": tf.zeros([batch_size, 0, self.params["hidden_size"]],
-                          dtype=self.params["dtype"])
+            "k":
+                tf.zeros([
+                    batch_size, init_decode_length, self.params["hidden_size"]
+                ],
+                         dtype=self.params["dtype"]),
+            "v":
+                tf.zeros([
+                    batch_size, init_decode_length, self.params["hidden_size"]
+                ],
+                         dtype=self.params["dtype"])
         } for layer in range(self.params["num_hidden_layers"])
     }
     # pylint: enable=g-complex-comprehension
@@ -301,6 +338,7 @@ class Transformer(tf.keras.Model):
         alpha=self.params["alpha"],
         max_decode_length=max_decode_length,
         eos_id=EOS_ID,
+        padded_decode=self.params["padded_decode"],
         dtype=self.params["dtype"])
 
     # Get the top sequence for each batch element
@@ -505,22 +543,28 @@ class DecoderStack(tf.keras.layers.Layer):
            decoder_self_attention_bias,
            attention_bias,
            training,
-           cache=None):
+           cache=None,
+           decode_loop_step=None):
     """Return the output of the decoder layer stacks.
 
     Args:
-      decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
-      encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
-      decoder_self_attention_bias: bias for decoder self-attention layer. [1, 1,
-        target_len, target_length]
-      attention_bias: bias for encoder-decoder attention layer. [batch_size, 1,
-        1, input_length]
-      training: boolean, whether in training mode or not.
+      decoder_inputs: A tensor with shape
+        [batch_size, target_length, hidden_size].
+      encoder_outputs: A tensor with shape
+        [batch_size, input_length, hidden_size]
+      decoder_self_attention_bias: A tensor with shape
+        [1, 1, target_len, target_length], the bias for decoder self-attention
+        layer.
+      attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
+        the bias for encoder-decoder attention layer.
+      training: A bool, whether in training mode or not.
       cache: (Used for fast decoding) A nested dictionary storing previous
         decoder self-attention values. The items are:
-        {layer_n: {"k": tensor with shape [batch_size, i, key_channels],
-                   "v": tensor with shape [batch_size, i, value_channels]},
+        {layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
+                   "v": A tensor with shape [batch_size, i, value_channels]},
                    ...}
+      decode_loop_step: An integer, the step number of the decoding loop. Used
+        only for autoregressive inference on TPU.
 
     Returns:
       Output of decoder layer stack.
@@ -540,7 +584,8 @@ class DecoderStack(tf.keras.layers.Layer):
               decoder_inputs,
               decoder_self_attention_bias,
               training=training,
-              cache=layer_cache)
+              cache=layer_cache,
+              decode_loop_step=decode_loop_step)
         with tf.name_scope("encdec_attention"):
           decoder_inputs = enc_dec_attention_layer(
               decoder_inputs,
official/transformer/v2/transformer_main.py

@@ -52,18 +52,40 @@ BLEU_DIR = "bleu"
 _SINGLE_SAMPLE = 1
 
 
-def translate_and_compute_bleu(model, subtokenizer, bleu_source, bleu_ref):
-  """Translate file and report the cased and uncased bleu scores."""
+def translate_and_compute_bleu(model,
+                               params,
+                               subtokenizer,
+                               bleu_source,
+                               bleu_ref,
+                               distribution_strategy=None):
+  """Translate file and report the cased and uncased bleu scores.
+
+  Args:
+    model: A Keras model, used to generate the translations.
+    params: A dictionary, containing the translation related parameters.
+    subtokenizer: A subtokenizer object, used for encoding and decoding source
+      and translated lines.
+    bleu_source: A file containing source sentences for translation.
+    bleu_ref: A file containing the reference for the translated sentences.
+    distribution_strategy: A platform distribution strategy, used for TPU based
+      translation.
+
+  Returns:
+    uncased_score: A float, the case insensitive BLEU score.
+    cased_score: A float, the case sensitive BLEU score.
+  """
   # Create temporary file to store translation.
   tmp = tempfile.NamedTemporaryFile(delete=False)
   tmp_filename = tmp.name
 
   translate.translate_file(
       model,
+      params,
       subtokenizer,
       bleu_source,
       output_file=tmp_filename,
-      print_all_translations=False)
+      print_all_translations=False,
+      distribution_strategy=distribution_strategy)
 
   # Compute uncased and cased bleu scores.
   uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
@@ -72,12 +94,31 @@ def translate_and_compute_bleu(model, subtokenizer, bleu_source, bleu_ref):
   return uncased_score, cased_score
 
 
-def evaluate_and_log_bleu(model, bleu_source, bleu_ref, vocab_file):
-  """Calculate and record the BLEU score."""
+def evaluate_and_log_bleu(model,
+                          params,
+                          bleu_source,
+                          bleu_ref,
+                          vocab_file,
+                          distribution_strategy=None):
+  """Calculate and record the BLEU score.
+
+  Args:
+    model: A Keras model, used to generate the translations.
+    params: A dictionary, containing the translation related parameters.
+    bleu_source: A file containing source sentences for translation.
+    bleu_ref: A file containing the reference for the translated sentences.
+    vocab_file: A file containing the vocabulary for translation.
+    distribution_strategy: A platform distribution strategy, used for TPU based
+      translation.
+
+  Returns:
+    uncased_score: A float, the case insensitive BLEU score.
+    cased_score: A float, the case sensitive BLEU score.
+  """
   subtokenizer = tokenizer.Subtokenizer(vocab_file)
 
   uncased_score, cased_score = translate_and_compute_bleu(
-      model, subtokenizer, bleu_source, bleu_ref)
+      model, params, subtokenizer, bleu_source, bleu_ref,
+      distribution_strategy)
 
   logging.info("Bleu score (uncased): %s", uncased_score)
   logging.info("Bleu score (cased): %s", cased_score)
@@ -110,6 +151,9 @@ class TransformerTask(object):
     params["model_dir"] = flags_obj.model_dir
     params["static_batch"] = flags_obj.static_batch
     params["max_length"] = flags_obj.max_length
+    params["decode_batch_size"] = flags_obj.decode_batch_size
+    params["decode_max_length"] = flags_obj.decode_max_length
+    params["padded_decode"] = flags_obj.padded_decode
     params["num_parallel_calls"] = (
         flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
@@ -133,6 +177,7 @@ class TransformerTask(object):
         num_gpus=num_gpus,
         tpu_address=flags_obj.tpu or "")
     if self.use_tpu:
+      params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
      if not params["static_batch"]:
         raise ValueError("TPU requires static batch for input data.")
     else:
@@ -306,10 +351,10 @@ class TransformerTask(object):
           self.predict_model,
           tf.train.latest_checkpoint(self.flags_obj.model_dir))
       self.predict_model.summary()
-    return evaluate_and_log_bleu(self.predict_model,
-                                 self.flags_obj.bleu_source,
-                                 self.flags_obj.bleu_ref,
-                                 self.flags_obj.vocab_file)
+    return evaluate_and_log_bleu(
+        self.predict_model, self.params, self.flags_obj.bleu_source,
+        self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
+        self.distribution_strategy if self.use_tpu else None)
 
   def predict(self):
     """Predicts result from the model."""
official/transformer/v2/translate.py

@@ -18,11 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
 import tensorflow as tf
 
+from tensorflow.python.distribute import values
 from official.transformer.utils import tokenizer
 
-_DECODE_BATCH_SIZE = 32
 _EXTRA_DECODE_LENGTH = 100
 _BEAM_SIZE = 4
 _ALPHA = 0.6
@@ -68,23 +69,31 @@ def _trim_and_decode(ids, subtokenizer):
   return subtokenizer.decode(ids)
 
 
-def translate_file(
-    model, subtokenizer, input_file, output_file=None,
-    print_all_translations=True):
+def translate_file(model,
+                   params,
+                   subtokenizer,
+                   input_file,
+                   output_file=None,
+                   print_all_translations=True,
+                   distribution_strategy=None):
   """Translate lines in file, and save to output file if specified.
 
   Args:
-    model: Keras model used to generate the translations.
-    subtokenizer: Subtokenizer object for encoding and decoding source and
-      translated lines.
-    input_file: file containing lines to translate
-    output_file: file that stores the generated translations.
-    print_all_translations: If true, all translations are printed to stdout.
+    model: A Keras model, used to generate the translations.
+    params: A dictionary, containing the translation related parameters.
+    subtokenizer: A subtokenizer object, used for encoding and decoding source
+      and translated lines.
+    input_file: A file containing lines to translate.
+    output_file: A file that stores the generated translations.
+    print_all_translations: A bool. If true, all translations are printed to
+      stdout.
+    distribution_strategy: A distribution strategy, used to perform inference
+      directly with tf.function instead of Keras model.predict().
 
   Raises:
     ValueError: if output file is invalid.
   """
-  batch_size = _DECODE_BATCH_SIZE
+  batch_size = params["decode_batch_size"]
 
   # Read and sort inputs by length. Keep dictionary (original index-->new index
   # in sorted list) to write translations in the original order.
@@ -101,24 +110,59 @@ def translate_file(
           if j + i * batch_size < total_samples
       ]
       lines = [_encode_and_add_eos(l, subtokenizer) for l in lines]
+      if distribution_strategy:
+        for j in range(batch_size - len(lines)):
+          lines.append([tokenizer.EOS_ID])
       batch = tf.keras.preprocessing.sequence.pad_sequences(
-          lines, dtype="int64", padding="post")
+          lines,
+          maxlen=params["decode_max_length"],
+          dtype="int32",
+          padding="post")
       tf.compat.v1.logging.info("Decoding batch %d out of %d.", i,
                                 num_decode_batches)
       yield batch
 
+  @tf.function
+  def predict_step(inputs):
+    """Decoding step function for TPU runs."""
+
+    def _step_fn(inputs):
+      """Per replica step function."""
+      val_outputs, _ = model([inputs], training=False)
+      return val_outputs
+
+    return distribution_strategy.experimental_run_v2(_step_fn, args=(inputs,))
+
   translations = []
+  if distribution_strategy:
+    num_replicas = distribution_strategy.num_replicas_in_sync
+    local_batch_size = params["decode_batch_size"] // num_replicas
   for i, text in enumerate(input_generator()):
-    val_outputs, _ = model.predict(text)
+    if distribution_strategy:
+      text = np.reshape(text, [num_replicas, local_batch_size, -1])
+      text = [
+          tf.convert_to_tensor(per_replica_text) for per_replica_text in text
+      ]
+      # pylint: disable=protected-access
+      text = values.PerReplica(distribution_strategy.extended._device_map,
+                               text)
+      # pylint: enable=protected-access
+      val_outputs = distribution_strategy.experimental_local_results(
+          predict_step(text))
+      val_outputs = np.reshape(
+          [val_output.numpy() for val_output in val_outputs],
+          [params["decode_batch_size"], -1])
+    else:
+      val_outputs, _ = model.predict(text)
 
     length = len(val_outputs)
     for j in range(length):
-      translation = _trim_and_decode(val_outputs[j], subtokenizer)
-      translations.append(translation)
-      if print_all_translations:
-        tf.compat.v1.logging.info(
-            "Translating:\n\tInput: %s\n\tOutput: %s" %
-            (sorted_inputs[j + i * batch_size], translation))
+      if j + i * batch_size < total_samples:
+        translation = _trim_and_decode(val_outputs[j], subtokenizer)
+        translations.append(translation)
+        if print_all_translations:
+          tf.compat.v1.logging.info(
+              "Translating:\n\tInput: %s\n\tOutput: %s" %
+              (sorted_inputs[j + i * batch_size], translation))
 
   # Write translations in the order they appeared in the original file.
   if output_file is not None:
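
The distribution path in translate_file shards one global batch across replicas by hand before wrapping it as a PerReplica value (via the private distribution_strategy.extended._device_map, an API of the TF version this commit targets). The reshape itself is easy to sanity-check with plain NumPy (illustrative sizes):

    import numpy as np

    decode_batch_size, num_replicas, seq_len = 8, 2, 5
    local_batch_size = decode_batch_size // num_replicas

    batch = np.arange(decode_batch_size * seq_len,
                      dtype=np.int32).reshape(decode_batch_size, seq_len)

    # [global_batch, seq_len] -> [num_replicas, local_batch, seq_len]:
    # replica r receives rows r*local_batch ... (r+1)*local_batch - 1.
    shards = np.reshape(batch, [num_replicas, local_batch_size, -1])
    assert shards.shape == (num_replicas, local_batch_size, seq_len)

This is also why the input generator pads the final batch up to decode_batch_size with EOS-only lines: every replica must receive a full, identically shaped shard.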