Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
83704fde
"references/vscode:/vscode.git/clone" did not exist on "601ce5fc7a313641d82553625d46f26a93f10fd9"
Commit
83704fde
authored
Sep 21, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 332953126
parent
f620075d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
2 deletions
+38
-2
official/nlp/modeling/ops/beam_search_test.py
official/nlp/modeling/ops/beam_search_test.py
+38
-2
No files found.
official/nlp/modeling/ops/beam_search_test.py
View file @
83704fde
...
...
@@ -14,12 +14,13 @@
# ==============================================================================
"""Test beam search helper methods."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.nlp.modeling.ops
import
beam_search
class
BeamSearch
Helper
Tests
(
tf
.
test
.
TestCase
):
class
BeamSearchTests
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
test_expand_to_beam_size
(
self
):
x
=
tf
.
ones
([
7
,
4
,
2
,
5
])
...
...
@@ -67,6 +68,41 @@ class BeamSearchHelperTests(tf.test.TestCase):
[[[
4
,
5
,
6
,
7
],
[
8
,
9
,
10
,
11
]],
[[
12
,
13
,
14
,
15
],
[
20
,
21
,
22
,
23
]]],
y
)
@
parameterized
.
named_parameters
([
(
'padded_decode_true'
,
True
),
(
'padded_decode_false'
,
False
),
])
def
test_sequence_beam_search
(
self
,
padded_decode
):
# batch_size*beam_size, max_decode_length, vocab_size
probabilities
=
tf
.
constant
([[[
0.2
,
0.7
,
0.1
],
[
0.5
,
0.3
,
0.2
],
[
0.1
,
0.8
,
0.1
]],
[[
0.1
,
0.8
,
0.1
],
[
0.3
,
0.4
,
0.3
],
[
0.2
,
0.1
,
0.7
]]])
# batch_size, max_decode_length, num_heads, embed_size per head
x
=
tf
.
zeros
([
1
,
3
,
2
,
32
],
dtype
=
tf
.
float32
)
cache
=
{
'layer_%d'
%
layer
:
{
'k'
:
x
,
'v'
:
x
}
for
layer
in
range
(
2
)}
if
__name__
==
"__main__"
:
def
_get_test_symbols_to_logits_fn
():
"""Test function that returns logits for next token."""
def
symbols_to_logits_fn
(
_
,
i
,
cache
):
logits
=
tf
.
cast
(
probabilities
[:,
i
,
:],
tf
.
float32
)
return
logits
,
cache
return
symbols_to_logits_fn
predictions
,
_
=
beam_search
.
sequence_beam_search
(
symbols_to_logits_fn
=
_get_test_symbols_to_logits_fn
(),
initial_ids
=
tf
.
zeros
([
1
],
dtype
=
tf
.
int32
),
initial_cache
=
cache
,
vocab_size
=
3
,
beam_size
=
2
,
alpha
=
0.6
,
max_decode_length
=
3
,
eos_id
=
9
,
padded_decode
=
padded_decode
,
dtype
=
tf
.
float32
)
self
.
assertAllEqual
([[[
0
,
1
,
0
,
1
],
[
0
,
1
,
1
,
2
]]],
predictions
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment