Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
913640d4
Commit
913640d4
authored
Dec 16, 2019
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Dec 16, 2019
Browse files
Internal change
PiperOrigin-RevId: 285765110
parent
722d9e57
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
14 deletions
+20
-14
official/transformer/model/beam_search.py
official/transformer/model/beam_search.py
+20
-14
No files found.
official/transformer/model/beam_search.py
View file @
913640d4
...
@@ -323,13 +323,16 @@ class SequenceBeamSearch(object):
...
@@ -323,13 +323,16 @@ class SequenceBeamSearch(object):
new state dictionary.
new state dictionary.
"""
"""
# Grow alive sequences by one token.
# Grow alive sequences by one token.
new_seq
,
new_log_probs
,
new_cache
=
self
.
_grow_alive_seq
(
state
)
new_seq
,
new_log_probs
,
topk_ids
,
new_cache
=
self
.
_grow_alive_seq
(
state
)
new_finished_flags
=
tf
.
equal
(
topk_ids
,
self
.
eos_id
)
# Collect top beam_size alive sequences
# Collect top beam_size alive sequences
alive_state
=
self
.
_get_new_alive_state
(
new_seq
,
new_log_probs
,
new_cache
)
alive_state
=
self
.
_get_new_alive_state
(
new_seq
,
new_log_probs
,
new_finished_flags
,
new_cache
)
# Combine newly finished sequences with existing finished sequences, and
# Combine newly finished sequences with existing finished sequences, and
# collect the top k scoring sequences.
# collect the top k scoring sequences.
finished_state
=
self
.
_get_new_finished_state
(
state
,
new_seq
,
new_log_probs
)
finished_state
=
self
.
_get_new_finished_state
(
state
,
new_seq
,
new_log_probs
,
new_finished_flags
)
# Increment loop index and create new state dictionary
# Increment loop index and create new state dictionary
new_state
=
{
_StateKeys
.
CUR_INDEX
:
state
[
_StateKeys
.
CUR_INDEX
]
+
1
}
new_state
=
{
_StateKeys
.
CUR_INDEX
:
state
[
_StateKeys
.
CUR_INDEX
]
+
1
}
...
@@ -407,18 +410,20 @@ class SequenceBeamSearch(object):
...
@@ -407,18 +410,20 @@ class SequenceBeamSearch(object):
tf
.
expand_dims
(
topk_ids
,
axis
=
0
))
tf
.
expand_dims
(
topk_ids
,
axis
=
0
))
topk_seq
=
tf
.
transpose
(
topk_seq
,
perm
=
[
1
,
2
,
0
])
topk_seq
=
tf
.
transpose
(
topk_seq
,
perm
=
[
1
,
2
,
0
])
else
:
else
:
topk_ids
=
tf
.
expand_dims
(
topk_ids
,
axis
=
2
)
topk_seq
=
tf
.
concat
([
topk_seq
,
tf
.
expand_dims
(
topk_ids
,
axis
=
2
)],
axis
=
2
)
topk_seq
=
tf
.
concat
([
topk_seq
,
topk_ids
],
axis
=
2
)
return
topk_seq
,
topk_log_probs
,
topk_ids
,
new_cache
return
topk_seq
,
topk_log_probs
,
new_cache
def
_get_new_alive_state
(
self
,
new_seq
,
new_log_probs
,
new_cache
):
def
_get_new_alive_state
(
self
,
new_seq
,
new_log_probs
,
new_finished_flags
,
new_cache
):
"""Gather the top k sequences that are still alive.
"""Gather the top k sequences that are still alive.
Args:
Args:
new_seq: New sequences generated by growing the current alive sequences
new_seq: New sequences generated by growing the current alive sequences
int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1]
int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1]
new_log_probs: Log probabilities of new sequences
new_log_probs: Log probabilities of new sequences float32 tensor with
float32 tensor with shape [batch_size, beam_size]
shape [batch_size, beam_size]
new_finished_flags: A boolean Tensor that is True for each new sequence
whose most recent token is EOS, i.e. finished rather than still alive.
new_cache: Dict of cached values for each sequence.
new_cache: Dict of cached values for each sequence.
Returns:
Returns:
...
@@ -428,7 +433,6 @@ class SequenceBeamSearch(object):
...
@@ -428,7 +433,6 @@ class SequenceBeamSearch(object):
Dict cache storing decoder states for top alive sequences}
Dict cache storing decoder states for top alive sequences}
"""
"""
# To prevent finished sequences from being considered, set log probs to -inf
# To prevent finished sequences from being considered, set log probs to -inf
new_finished_flags
=
tf
.
equal
(
new_seq
[:,
:,
-
1
],
self
.
eos_id
)
new_log_probs
+=
tf
.
cast
(
new_finished_flags
,
self
.
dtype
)
*
-
inf
(
self
.
dtype
)
new_log_probs
+=
tf
.
cast
(
new_finished_flags
,
self
.
dtype
)
*
-
inf
(
self
.
dtype
)
top_alive_seq
,
top_alive_log_probs
,
top_alive_cache
=
_gather_topk_beams
(
top_alive_seq
,
top_alive_log_probs
,
top_alive_cache
=
_gather_topk_beams
(
...
@@ -441,15 +445,18 @@ class SequenceBeamSearch(object):
...
@@ -441,15 +445,18 @@ class SequenceBeamSearch(object):
_StateKeys
.
ALIVE_CACHE
:
top_alive_cache
_StateKeys
.
ALIVE_CACHE
:
top_alive_cache
}
}
def
_get_new_finished_state
(
self
,
state
,
new_seq
,
new_log_probs
):
def
_get_new_finished_state
(
self
,
state
,
new_seq
,
new_log_probs
,
new_finished_flags
):
"""Combine new and old finished sequences, and gather the top k sequences.
"""Combine new and old finished sequences, and gather the top k sequences.
Args:
Args:
state: A dictionary with the current loop state.
state: A dictionary with the current loop state.
new_seq: New sequences generated by growing the current alive sequences
new_seq: New sequences generated by growing the current alive sequences
int32 tensor with shape [batch_size, beam_size, i + 1]
int32 tensor with shape [batch_size, beam_size, i + 1]
new_log_probs: Log probabilities of new sequences
new_log_probs: Log probabilities of new sequences float32 tensor with
float32 tensor with shape [batch_size, beam_size]
shape [batch_size, beam_size]
new_finished_flags: A boolean Tensor that is True for each new sequence
whose most recent token is EOS, i.e. finished rather than still alive.
Returns:
Returns:
Dictionary with finished keys from _StateKeys:
Dictionary with finished keys from _StateKeys:
...
@@ -476,7 +483,6 @@ class SequenceBeamSearch(object):
...
@@ -476,7 +483,6 @@ class SequenceBeamSearch(object):
new_scores
=
new_log_probs
/
length_norm
new_scores
=
new_log_probs
/
length_norm
# Set the scores of the still-alive seq in new_seq to large negative values.
# Set the scores of the still-alive seq in new_seq to large negative values.
new_finished_flags
=
tf
.
equal
(
new_seq
[:,
:,
-
1
],
self
.
eos_id
)
new_scores
+=
((
1.
-
tf
.
cast
(
new_finished_flags
,
self
.
dtype
))
*
new_scores
+=
((
1.
-
tf
.
cast
(
new_finished_flags
,
self
.
dtype
))
*
-
inf
(
self
.
dtype
))
-
inf
(
self
.
dtype
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment