ModelZoo / ResNet50_tensorflow / Commits

Commit df2e30cd
Authored May 11, 2021 by Bruce Fontaine
Committed by A. Unique TensorFlower, May 11, 2021

Move _gather_beams into the SequenceBeamSearch class.

PiperOrigin-RevId: 373191989

Parent: a82f0b56
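In practice the move means callers reach the helper through the class rather than the module. A minimal sketch of the new call path, assuming the usual `from official.nlp.modeling.ops import beam_search` import that the updated test below relies on:

    import tensorflow as tf
    from official.nlp.modeling.ops import beam_search

    x = tf.reshape(tf.range(24), [2, 3, 4])
    # Before this commit the helper was a module-level function:
    #   y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
    # After it, the same helper is a staticmethod on the class:
    y = beam_search.SequenceBeamSearch._gather_beams(x, [[1, 2], [0, 2]], 2, 2)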
Showing 2 changed files with 48 additions and 61 deletions.

  official/nlp/modeling/ops/beam_search.py        +47  -51
  official/nlp/modeling/ops/beam_search_test.py    +1  -10
official/nlp/modeling/ops/beam_search.py
@@ -218,7 +218,7 @@ class SequenceBeamSearch(tf.Module):
     # Extract the alive sequences that generate the highest log probabilities
     # after being extended.
     topk_beam_indices = topk_indices // self.vocab_size
-    topk_seq, new_cache = _gather_beams([alive_seq, new_cache],
-                                        topk_beam_indices, batch_size,
-                                        beams_to_keep)
+    topk_seq, new_cache = self._gather_beams([alive_seq, new_cache],
+                                             topk_beam_indices, batch_size,
+                                             beams_to_keep)
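For context on the hunk above: topk_indices index into the flattened [beam_size * vocab_size] candidate scores, so integer division by vocab_size recovers which beam each candidate was extended from (the modulo, not part of this hunk, gives the new token id). A tiny illustrative sketch with toy numbers, not taken from the module:

    import tensorflow as tf

    vocab_size = 6
    topk_indices = tf.constant([[14, 3]])            # flattened candidate ids
    topk_beam_indices = topk_indices // vocab_size   # -> [[2, 0]]: source beams
    topk_token_ids = topk_indices % vocab_size       # -> [[2, 3]]: token ids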
@@ -259,9 +259,10 @@ class SequenceBeamSearch(tf.Module):
     new_log_probs += tf.cast(new_finished_flags, self.dtype) * -inf(self.dtype)
 
-    top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
-        [new_seq, new_log_probs, new_cache], new_log_probs, batch_size,
-        self.beam_size)
+    _, topk_indexes = tf.nn.top_k(new_log_probs, k=self.beam_size)
+    top_alive_seq, top_alive_log_probs, top_alive_cache = (
+        self._gather_beams([new_seq, new_log_probs, new_cache], topk_indexes,
+                           batch_size, self.beam_size))
 
     return {
         _StateKeys.ALIVE_SEQ: top_alive_seq,
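The hunk above inlines what the module-level _gather_topk_beams helper used to do (its removal appears further down in this diff): pick the indices of the best-scoring beams with tf.nn.top_k, then gather those beams from the nested state. A small, purely illustrative sketch of that two-step pattern on toy scores:

    import tensorflow as tf

    scores = tf.constant([[0.1, 0.9, 0.5],
                          [0.7, 0.2, 0.8]])      # [batch=2, beam=3] log probs
    _, topk_indexes = tf.nn.top_k(scores, k=2)   # indices of the 2 best beams
    print(topk_indexes.numpy())                  # [[1 2]
                                                 #  [2 0]]
    # In the class, these indices are handed to self._gather_beams, which pulls
    # the matching entries out of the nested sequence / log-prob / cache tensors.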
@@ -316,9 +317,10 @@ class SequenceBeamSearch(tf.Module):
     finished_flags = tf.concat([finished_flags, new_finished_flags], axis=1)
 
     # Return the finished sequences with the best scores.
+    _, topk_indexes = tf.nn.top_k(finished_scores, k=self.beam_size)
     top_finished_seq, top_finished_scores, top_finished_flags = (
-        _gather_topk_beams([finished_seq, finished_scores, finished_flags],
-                           finished_scores, batch_size, self.beam_size))
+        self._gather_beams([finished_seq, finished_scores, finished_flags],
+                           topk_indexes, batch_size, self.beam_size))
 
     return {
         _StateKeys.FINISHED_SEQ: top_finished_seq,
@@ -538,6 +540,43 @@ class SequenceBeamSearch(tf.Module):
         not_at_max_decode_length,
         tf.logical_not(worst_finished_score_better_than_best_alive_score))
 
+  @staticmethod
+  def _gather_beams(nested, beam_indices, batch_size, new_beam_size):
+    """Gather beams from nested structure of tensors.
+
+    Each tensor in nested represents a batch of beams, where beam refers to a
+    single search state (beam search involves searching through multiple states
+    in parallel).
+
+    This function is used to gather the top beams, specified by
+    beam_indices, from the nested tensors.
+
+    Args:
+      nested: Nested structure (tensor, list, tuple or dict) containing tensors
+        with shape [batch_size, beam_size, ...].
+      beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each
+        value in beam_indices must be between [0, beam_size), and are not
+        necessarily unique.
+      batch_size: int size of batch
+      new_beam_size: int number of beams to be pulled from the nested tensors.
+
+    Returns:
+      Nested structure containing tensors with shape
+        [batch_size, new_beam_size, ...]
+    """
+    # Computes the i'th coodinate that contains the batch index for gather_nd.
+    # Batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..].
+    batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size
+    batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size])
+
+    # Create coordinates to be passed to tf.gather_nd. Stacking creates a tensor
+    # with shape [batch_size, beam_size, 2], where the last dimension contains
+    # the (i, j) gathering coordinates.
+    coordinates = tf.stack([batch_pos, beam_indices], axis=2)
+
+    return tf.nest.map_structure(
+        lambda state: tf.gather_nd(state, coordinates), nested)
+
 
 def sequence_beam_search(symbols_to_logits_fn,
                          initial_ids,
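For reference, the gather logic of the new staticmethod can be exercised on its own. The following is a standalone, illustrative restatement (not the Model Garden module itself) that follows the docstring's contract, applied to the same tiny tensor the test file below uses:

    import tensorflow as tf

    def gather_beams(nested, beam_indices, batch_size, new_beam_size):
      # Illustrative restatement of SequenceBeamSearch._gather_beams.
      # batch_pos looks like [[0, 0, ...], [1, 1, ...], ...]: the batch index
      # for every gathered beam.
      batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size
      batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size])
      # (i, j) coordinates for tf.gather_nd: batch index i, beam index j.
      coordinates = tf.stack([batch_pos, beam_indices], axis=2)
      return tf.nest.map_structure(
          lambda state: tf.gather_nd(state, coordinates), nested)

    x = tf.reshape(tf.range(24), [2, 3, 4])      # [batch=2, beam=3, length=4]
    y = gather_beams(x, [[1, 2], [0, 2]], 2, 2)  # keep beams (1, 2) and (0, 2)
    print(y.numpy())
    # [[[ 4  5  6  7]
    #   [ 8  9 10 11]]
    #  [[12 13 14 15]
    #   [20 21 22 23]]]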
@@ -663,46 +702,3 @@ def _unflatten_beam_dim(tensor, batch_size, beam_size):
   shape = _shape_list(tensor)
   new_shape = [batch_size, beam_size] + shape[1:]
   return tf.reshape(tensor, new_shape)
-
-
-def _gather_beams(nested, beam_indices, batch_size, new_beam_size):
-  """Gather beams from nested structure of tensors.
-
-  Each tensor in nested represents a batch of beams, where beam refers to a
-  single search state (beam search involves searching through multiple states
-  in parallel).
-
-  This function is used to gather the top beams, specified by
-  beam_indices, from the nested tensors.
-
-  Args:
-    nested: Nested structure (tensor, list, tuple or dict) containing tensors
-      with shape [batch_size, beam_size, ...].
-    beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each
-      value in beam_indices must be between [0, beam_size), and are not
-      necessarily unique.
-    batch_size: int size of batch
-    new_beam_size: int number of beams to be pulled from the nested tensors.
-
-  Returns:
-    Nested structure containing tensors with shape
-      [batch_size, new_beam_size, ...]
-  """
-  # Computes the i'th coodinate that contains the batch index for gather_nd.
-  # Batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..].
-  batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size
-  batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size])
-
-  # Create coordinates to be passed to tf.gather_nd. Stacking creates a tensor
-  # with shape [batch_size, beam_size, 2], where the last dimension contains
-  # the (i, j) gathering coordinates.
-  coordinates = tf.stack([batch_pos, beam_indices], axis=2)
-
-  return tf.nest.map_structure(
-      lambda state: tf.gather_nd(state, coordinates), nested)
-
-
-def _gather_topk_beams(nested, score_or_log_prob, batch_size, beam_size):
-  """Gather top beams from nested structure."""
-  _, topk_indexes = tf.nn.top_k(score_or_log_prob, k=beam_size)
-  return _gather_beams(nested, topk_indexes, batch_size, beam_size)
official/nlp/modeling/ops/beam_search_test.py
@@ -54,16 +54,7 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
     #                  [16 17 18 19]
     #                  [20 21 22 23]]]
-    y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
+    y = beam_search.SequenceBeamSearch._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
     self.assertAllEqual([[[4, 5, 6, 7], [8, 9, 10, 11]],
                          [[12, 13, 14, 15], [20, 21, 22, 23]]],
                         y)
 
-  def test_gather_topk_beams(self):
-    x = tf.reshape(tf.range(24), [2, 3, 4])
-    x_scores = [[0, 1, 1], [1, 0, 1]]
-    y = beam_search._gather_topk_beams(x, x_scores, 2, 2)
-    self.assertAllEqual([[[4, 5, 6, 7], [8, 9, 10, 11]],
-                         [[12, 13, 14, 15], [20, 21, 22, 23]]],
-                        y)
-