ModelZoo / ResNet50_tensorflow · Commit 2e6ccf02

Move beam search to nlp/modeling/ops

Authored Jul 17, 2020 by Hongkun Yu; committed Jul 17, 2020 by A. Unique TensorFlower.
PiperOrigin-RevId: 321817352
Parent: 32ae035e

This commit moves the beam search implementation out of official/nlp/transformer and into official/nlp/modeling/ops, updating all call sites to import from the new location.
Showing 7 changed files with 722 additions and 765 deletions:
  official/nlp/modeling/ops/__init__.py          +1    −0
  official/nlp/modeling/ops/beam_search.py       +708  −0
  official/nlp/modeling/ops/beam_search_test.py  +5    −29
  official/nlp/nhnet/models.py                   +1    −1
  official/nlp/transformer/beam_search.py        +0    −132
  official/nlp/transformer/beam_search_v1.py     +6    −602
  official/nlp/transformer/transformer.py        +1    −1
official/nlp/modeling/ops/__init__.py (new file, mode 0 → 100644)
official/nlp/modeling/ops/beam_search.py (new file, mode 0 → 100644)
(Diff collapsed; the 708 added lines are not shown.)
official/nlp/transformer/beam_search_v1_test.py → official/nlp/modeling/ops/beam_search_test.py (renamed)
@@ -14,33 +14,19 @@
 # ==============================================================================
 """Test beam search helper methods."""
 
-import tensorflow.compat.v1 as tf
+import tensorflow as tf
 
-from official.nlp.transformer import beam_search_v1 as beam_search
+from official.nlp.modeling.ops import beam_search
 
 
 class BeamSearchHelperTests(tf.test.TestCase):
 
+  def setUp(self):
+    super(BeamSearchHelperTests, self).setUp()
+    tf.compat.v1.disable_eager_execution()
+
   def test_expand_to_beam_size(self):
     x = tf.ones([7, 4, 2, 5])
     x = beam_search._expand_to_beam_size(x, 3)
-    with self.session() as sess:
-      shape = sess.run(tf.shape(x))
+    shape = tf.shape(x)
     self.assertAllEqual([7, 3, 4, 2, 5], shape)
 
-  def test_shape_list(self):
-    y = tf.compat.v1.placeholder(dtype=tf.int32, shape=[])
-    x = tf.ones([7, y, 2, 5])
-    shape = beam_search._shape_list(x)
-    self.assertIsInstance(shape[0], int)
-    self.assertIsInstance(shape[1], tf.Tensor)
-    self.assertIsInstance(shape[2], int)
-    self.assertIsInstance(shape[3], int)
-
   def test_get_shape_keep_last_dim(self):
     y = tf.constant(4.0)
     x = tf.ones([7, tf.cast(tf.sqrt(y), tf.int32), 2, 5])
@@ -51,16 +37,12 @@ class BeamSearchHelperTests(tf.test.TestCase):
   def test_flatten_beam_dim(self):
     x = tf.ones([7, 4, 2, 5])
     x = beam_search._flatten_beam_dim(x)
-    with self.session() as sess:
-      shape = sess.run(tf.shape(x))
-    self.assertAllEqual([28, 2, 5], shape)
+    self.assertAllEqual([28, 2, 5], tf.shape(x))
 
   def test_unflatten_beam_dim(self):
     x = tf.ones([28, 2, 5])
     x = beam_search._unflatten_beam_dim(x, 7, 4)
-    with self.session() as sess:
-      shape = sess.run(tf.shape(x))
-    self.assertAllEqual([7, 4, 2, 5], shape)
+    self.assertAllEqual([7, 4, 2, 5], tf.shape(x))
 
   def test_gather_beams(self):
     x = tf.reshape(tf.range(24), [2, 3, 4])
@@ -73,9 +55,6 @@ class BeamSearchHelperTests(tf.test.TestCase):
     # [20 21 22 23]]]
 
     y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
-    with self.session() as sess:
-      y = sess.run(y)
-
     self.assertAllEqual([[[4, 5, 6, 7],
                           [8, 9, 10, 11]],
                          [[12, 13, 14, 15],
@@ -87,9 +66,6 @@ class BeamSearchHelperTests(tf.test.TestCase):
     x_scores = [[0, 1, 1], [1, 0, 1]]
 
     y = beam_search._gather_topk_beams(x, x_scores, 2, 2)
-    with self.session() as sess:
-      y = sess.run(y)
-
     self.assertAllEqual([[[4, 5, 6, 7],
                           [8, 9, 10, 11]],
                          [[12, 13, 14, 15],
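Note on the test changes above: the explicit with self.session() blocks were only needed to evaluate tensors under TF1 graph execution; tf.test.TestCase assertions such as assertAllEqual evaluate tensor arguments themselves, so the rewritten tests compare against tf.shape(x) directly. As a mental model for test_gather_beams, the same per-batch beam selection can be reproduced for that test's inputs with a stock tf.gather. The sketch below is illustrative only and is not part of the commit:

import tensorflow as tf

# x: [batch=2, beam=3, length=4], filled with 0..23 row by row.
x = tf.reshape(tf.range(24), [2, 3, 4])

# Keep beams (1, 2) of batch item 0 and beams (0, 2) of batch item 1,
# mirroring beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2).
beam_indices = tf.constant([[1, 2], [0, 2]])
y = tf.gather(x, beam_indices, axis=1, batch_dims=1)

# y == [[[ 4,  5,  6,  7], [ 8,  9, 10, 11]],
#       [[12, 13, 14, 15], [20, 21, 22, 23]]]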
official/nlp/nhnet/models.py

@@ -31,7 +31,7 @@ from official.nlp.modeling.layers import multi_channel_attention
 from official.nlp.nhnet import configs
 from official.nlp.nhnet import decoder
 from official.nlp.nhnet import utils
-from official.nlp.transformer import beam_search
+from official.nlp.modeling.ops import beam_search
 
 
 def embedding_linear(embedding_matrix, x):
official/nlp/transformer/beam_search.py (deleted, mode 100644 → 0)

The entire file below was removed:
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search in TF v2."""

import tensorflow as tf

from official.nlp.transformer import beam_search_v1 as v1

_StateKeys = v1._StateKeys  # pylint: disable=protected-access


class SequenceBeamSearchV2(v1.SequenceBeamSearch):
  """Implementation of beam search loop in v2."""

  def search(self, initial_ids, initial_cache):
    """Beam search for sequences with highest scores."""
    state, state_shapes = self._create_initial_state(initial_ids,
                                                     initial_cache)

    finished_state = tf.nest.map_structure(
        tf.stop_gradient,
        tf.while_loop(
            self._continue_search,
            self._search_step,
            loop_vars=[state],
            shape_invariants=[state_shapes],
            parallel_iterations=1))
    finished_state = finished_state[0]

    alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
    alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
    finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
    finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
    finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]

    # 2.0 changes tf.where behavior: the parameters must be broadcastable.
    finished_cond = tf.reduce_any(finished_flags, 1, name="finished_cond")
    seq_cond = _expand_to_same_rank(finished_cond, finished_seq)
    score_cond = _expand_to_same_rank(finished_cond, finished_scores)

    # Account for the corner case where there are no finished sequences for a
    # particular batch item. In that case, return alive sequences for that
    # batch item.
    finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
    finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
    return finished_seq, finished_scores


def sequence_beam_search(symbols_to_logits_fn,
                         initial_ids,
                         initial_cache,
                         vocab_size,
                         beam_size,
                         alpha,
                         max_decode_length,
                         eos_id,
                         padded_decode=False,
                         dtype="float32"):
  """Search for sequence of subtoken ids with the largest probability.

  Args:
    symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
        ids -> A tensor with shape [batch_size * beam_size, index].
        index -> A scalar.
        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
      The function must return a tuple of logits and the new cache:
        logits -> A tensor with shape [batch * beam_size, vocab_size].
        new cache -> A nested dictionary with the same shape/structure as the
          input cache.
    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
      each batch item.
    initial_cache: A dictionary containing starting decoder variables
      information.
    vocab_size: An integer, the size of the vocabulary.
    beam_size: An integer, the number of beams.
    alpha: A float, defining the strength of length normalization.
    max_decode_length: An integer, the maximum length to decode a sequence.
    eos_id: An integer, ID of the eos token, used to determine when a
      sequence has finished.
    padded_decode: A bool, indicating if max_sequence_length padding is used
      for beam search.
    dtype: A tensorflow data type used for score computation. The default is
      tf.float32.

  Returns:
    Top decoded sequences [batch_size, beam_size, max_decode_length]
    sequence scores [batch_size, beam_size]
  """
  batch_size = (
      initial_ids.shape.as_list()[0]
      if padded_decode else tf.shape(initial_ids)[0])
  sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
                             beam_size, alpha, max_decode_length, eos_id,
                             padded_decode, dtype)
  return sbs.search(initial_ids, initial_cache)


def _expand_to_same_rank(tensor, target):
  """Expands a given tensor to target's rank to be broadcastable.

  Args:
    tensor: input tensor to tile. Shape: [b, d1, ..., da]
    target: target tensor. Shape: [b, d1, ..., da, ..., dn]

  Returns:
    Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with the same rank as
    target.

  Raises:
    ValueError: if the shape rank of tensor or target is None.
  """
  if tensor.shape.rank is None:
    raise ValueError("Expect rank for tensor shape, but got None.")
  if target.shape.rank is None:
    raise ValueError("Expect rank for target shape, but got None.")

  with tf.name_scope("expand_rank"):
    diff_rank = target.shape.rank - tensor.shape.rank
    for _ in range(diff_rank):
      tensor = tf.expand_dims(tensor, -1)
    return tensor
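Two notes on the deleted module above, whose functionality now lives in official/nlp/modeling/ops/beam_search.py. First, _expand_to_same_rank exists because TF 2.x's tf.where requires broadcastable arguments: the per-batch finished_cond of shape [batch_size] is expanded to [batch_size, 1, ..., 1] before selecting between [batch_size, beam_size, ...] sequence tensors. Second, since every call site in this commit changes only its import path, the moved sequence_beam_search evidently keeps the signature documented above. The snippet below is a minimal usage sketch under that assumption; the toy symbols_to_logits_fn is hypothetical and returns uniform logits rather than running a real decoder step:

import tensorflow as tf

from official.nlp.modeling.ops import beam_search

VOCAB_SIZE = 8


def symbols_to_logits_fn(ids, index, cache):
  # ids: [batch_size * beam_size, index] tokens decoded so far.
  # A real model would score the next token from ids and the cache; this
  # toy version returns uniform (all-zero) logits and an unchanged cache.
  del index  # unused in this sketch
  logits = tf.zeros([tf.shape(ids)[0], VOCAB_SIZE])
  return logits, cache


# Decode a batch of 2 sequences with 3 beams each, starting from token 0.
seqs, scores = beam_search.sequence_beam_search(
    symbols_to_logits_fn=symbols_to_logits_fn,
    initial_ids=tf.zeros([2], dtype=tf.int32),
    initial_cache={},
    vocab_size=VOCAB_SIZE,
    beam_size=3,
    alpha=0.6,             # length-normalization strength
    max_decode_length=5,
    eos_id=1)
# seqs:   [batch_size, beam_size, decoded_length]
# scores: [batch_size, beam_size]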
official/nlp/transformer/beam_search_v1.py
(Diff collapsed; +6 −602 not shown.)
official/nlp/transformer/transformer.py

@@ -23,8 +23,8 @@ from __future__ import print_function
 import tensorflow as tf
 
 from official.nlp.modeling.layers import position_embedding
+from official.nlp.modeling.ops import beam_search
 from official.nlp.transformer import attention_layer
-from official.nlp.transformer import beam_search
 from official.nlp.transformer import embedding_layer
 from official.nlp.transformer import ffn_layer
 from official.nlp.transformer import metrics