Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a15ebc46
Commit
a15ebc46
authored
Jun 16, 2022
by
A. Unique TensorFlower
Browse files
Allow passing an optional name for the decoding loop tensors.
PiperOrigin-RevId: 455384659
parent
3c77e654
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
24 deletions
+38
-24
official/nlp/modeling/ops/beam_search.py
official/nlp/modeling/ops/beam_search.py
+18
-12
official/nlp/modeling/ops/beam_search_test.py
official/nlp/modeling/ops/beam_search_test.py
+7
-4
official/nlp/modeling/ops/decoding_module.py
official/nlp/modeling/ops/decoding_module.py
+7
-6
official/nlp/modeling/ops/sampling_module.py
official/nlp/modeling/ops/sampling_module.py
+6
-2
No files found.
official/nlp/modeling/ops/beam_search.py
View file @
a15ebc46
...
@@ -107,18 +107,18 @@ class SequenceBeamSearch(tf.Module):
...
@@ -107,18 +107,18 @@ class SequenceBeamSearch(tf.Module):
max_decode_length
,
max_decode_length
,
eos_id
,
eos_id
,
padded_decode
,
padded_decode
,
dtype
=
tf
.
float32
):
dtype
=
tf
.
float32
,
decoding_name
=
None
):
"""Initialize sequence beam search.
"""Initialize sequence beam search.
Args:
Args:
symbols_to_logits_fn: A function to provide logits, which is the
symbols_to_logits_fn: A function to provide logits, which is the interface
interface to the Transformer model. The passed in arguments are: ids ->
to the Transformer model. The passed in arguments are: ids -> A tensor
A tensor with shape [batch_size * beam_size, index]. index -> A
with shape [batch_size * beam_size, index]. index -> A scalar. cache ->
scalar. cache -> A nested dictionary of tensors [batch_size *
A nested dictionary of tensors [batch_size * beam_size, ...]. The
beam_size, ...].
function must return a tuple of logits and the updated cache: logits ->
The function must return a tuple of logits and the updated cache: logits
A tensor with shape [batch * beam_size, vocab_size]. updated cache -> A
-> A tensor with shape [batch * beam_size, vocab_size]. updated cache
nested dictionary with the same structure as the input cache.
-> A nested dictionary with the same structure as the input cache.
vocab_size: An integer, the size of the vocabulary, used for topk
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
computation.
beam_size: An integer, number of beams for beam search.
beam_size: An integer, number of beams for beam search.
...
@@ -130,6 +130,7 @@ class SequenceBeamSearch(tf.Module):
...
@@ -130,6 +130,7 @@ class SequenceBeamSearch(tf.Module):
for beam search.
for beam search.
dtype: A tensorflow data type used for score computation. The default is
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
tf.float32.
decoding_name: An optional name for the decoding loop tensors.
"""
"""
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
...
@@ -139,6 +140,7 @@ class SequenceBeamSearch(tf.Module):
...
@@ -139,6 +140,7 @@ class SequenceBeamSearch(tf.Module):
self
.
eos_id
=
eos_id
self
.
eos_id
=
eos_id
self
.
padded_decode
=
padded_decode
self
.
padded_decode
=
padded_decode
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
self
.
decoding_name
=
decoding_name
def
search
(
self
,
initial_ids
,
initial_cache
):
def
search
(
self
,
initial_ids
,
initial_cache
):
"""Beam search for sequences with highest scores.
"""Beam search for sequences with highest scores.
...
@@ -370,7 +372,8 @@ class SequenceBeamSearch(tf.Module):
...
@@ -370,7 +372,8 @@ class SequenceBeamSearch(tf.Module):
_search_step
,
_search_step
,
loop_vars
=
[
state
],
loop_vars
=
[
state
],
shape_invariants
=
[
state_shapes
],
shape_invariants
=
[
state_shapes
],
parallel_iterations
=
1
))
parallel_iterations
=
1
,
name
=
self
.
decoding_name
))
finished_state
=
finished_state
[
0
]
finished_state
=
finished_state
[
0
]
return
self
.
_process_finished_state
(
finished_state
)
return
self
.
_process_finished_state
(
finished_state
)
...
@@ -587,7 +590,8 @@ def sequence_beam_search(symbols_to_logits_fn,
...
@@ -587,7 +590,8 @@ def sequence_beam_search(symbols_to_logits_fn,
max_decode_length
,
max_decode_length
,
eos_id
,
eos_id
,
padded_decode
=
False
,
padded_decode
=
False
,
dtype
=
"float32"
):
dtype
=
"float32"
,
decoding_name
=
None
):
"""Search for sequence of subtoken ids with the largest probability.
"""Search for sequence of subtoken ids with the largest probability.
Args:
Args:
...
@@ -612,13 +616,15 @@ def sequence_beam_search(symbols_to_logits_fn,
...
@@ -612,13 +616,15 @@ def sequence_beam_search(symbols_to_logits_fn,
beam search.
beam search.
dtype: A tensorflow data type used for score computation. The default is
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
tf.float32.
decoding_name: An optional name for the decoding loop tensors.
Returns:
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
sequence scores [batch_size, beam_size]
"""
"""
sbs
=
SequenceBeamSearch
(
symbols_to_logits_fn
,
vocab_size
,
beam_size
,
alpha
,
sbs
=
SequenceBeamSearch
(
symbols_to_logits_fn
,
vocab_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
padded_decode
,
dtype
)
max_decode_length
,
eos_id
,
padded_decode
,
dtype
,
decoding_name
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
...
...
official/nlp/modeling/ops/beam_search_test.py
View file @
a15ebc46
...
@@ -60,10 +60,12 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
...
@@ -60,10 +60,12 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
y
)
y
)
@
parameterized
.
named_parameters
([
@
parameterized
.
named_parameters
([
(
'padded_decode_true'
,
True
),
(
'padded_decode_true_with_name'
,
True
,
'decoding'
),
(
'padded_decode_false'
,
False
),
(
'padded_decode_false_with_name'
,
False
,
'decoding'
),
(
'padded_decode_true_without_name'
,
True
,
None
),
(
'padded_decode_false_without_name'
,
False
,
None
),
])
])
def
test_sequence_beam_search
(
self
,
padded_decode
):
def
test_sequence_beam_search
(
self
,
padded_decode
,
name
):
# batch_size*beam_size, max_decode_length, vocab_size
# batch_size*beam_size, max_decode_length, vocab_size
probabilities
=
tf
.
constant
([[[
0.2
,
0.7
,
0.1
],
[
0.5
,
0.3
,
0.2
],
probabilities
=
tf
.
constant
([[[
0.2
,
0.7
,
0.1
],
[
0.5
,
0.3
,
0.2
],
[
0.1
,
0.8
,
0.1
]],
[
0.1
,
0.8
,
0.1
]],
...
@@ -91,7 +93,8 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
...
@@ -91,7 +93,8 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
max_decode_length
=
3
,
max_decode_length
=
3
,
eos_id
=
9
,
eos_id
=
9
,
padded_decode
=
padded_decode
,
padded_decode
=
padded_decode
,
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
,
decoding_name
=
name
)
self
.
assertAllEqual
([[[
0
,
1
,
0
,
1
],
[
0
,
1
,
1
,
2
]]],
predictions
)
self
.
assertAllEqual
([[[
0
,
1
,
0
,
1
],
[
0
,
1
,
1
,
2
]]],
predictions
)
...
...
official/nlp/modeling/ops/decoding_module.py
View file @
a15ebc46
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy)."""
"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy)."""
import
abc
import
abc
from
typing
import
Any
,
Callable
,
Dict
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Tuple
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -108,7 +108,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
...
@@ -108,7 +108,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
def
__init__
(
self
,
def
__init__
(
self
,
length_normalization_fn
:
Callable
[[
int
,
tf
.
DType
],
float
],
length_normalization_fn
:
Callable
[[
int
,
tf
.
DType
],
float
],
dtype
:
tf
.
DType
=
tf
.
float32
):
dtype
:
tf
.
DType
=
tf
.
float32
,
decoding_name
:
Optional
[
str
]
=
None
):
"""Initialize the Decoding Module.
"""Initialize the Decoding Module.
Args:
Args:
...
@@ -116,9 +117,11 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
...
@@ -116,9 +117,11 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
parameter. Function accepts input as length, dtype and returns float.
parameter. Function accepts input as length, dtype and returns float.
dtype: A tensorflow data type used for score computation. The default is
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
tf.float32.
decoding_name: An optional name for the decoding loop tensors.
"""
"""
self
.
length_normalization_fn
=
length_normalization_fn
self
.
length_normalization_fn
=
length_normalization_fn
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
self
.
decoding_name
=
decoding_name
def
generate
(
self
,
def
generate
(
self
,
initial_ids
:
tf
.
Tensor
,
initial_ids
:
tf
.
Tensor
,
...
@@ -169,7 +172,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
...
@@ -169,7 +172,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
_generate_step
,
_generate_step
,
loop_vars
=
[
state
],
loop_vars
=
[
state
],
shape_invariants
=
[
state_shapes
],
shape_invariants
=
[
state_shapes
],
parallel_iterations
=
1
))
parallel_iterations
=
1
,
name
=
self
.
decoding_name
))
final_state
=
self
.
_process_finished_state
(
finished_state
[
0
])
final_state
=
self
.
_process_finished_state
(
finished_state
[
0
])
return
final_state
return
final_state
...
@@ -277,6 +281,3 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
...
@@ -277,6 +281,3 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
return
dtypes
.
float16
.
max
return
dtypes
.
float16
.
max
else
:
else
:
raise
AssertionError
(
"Invalid dtype: %s"
%
self
.
dtype
)
raise
AssertionError
(
"Invalid dtype: %s"
%
self
.
dtype
)
official/nlp/modeling/ops/sampling_module.py
View file @
a15ebc46
...
@@ -162,7 +162,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -162,7 +162,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
top_p
=
1.0
,
top_p
=
1.0
,
sample_temperature
=
0.0
,
sample_temperature
=
0.0
,
enable_greedy
:
bool
=
True
,
enable_greedy
:
bool
=
True
,
dtype
:
tf
.
DType
=
tf
.
float32
):
dtype
:
tf
.
DType
=
tf
.
float32
,
decoding_name
:
Optional
[
str
]
=
None
):
"""Initialize sampling module."""
"""Initialize sampling module."""
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
length_normalization_fn
=
length_normalization_fn
self
.
length_normalization_fn
=
length_normalization_fn
...
@@ -176,8 +177,11 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -176,8 +177,11 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
self
.
sample_temperature
=
tf
.
convert_to_tensor
(
self
.
sample_temperature
=
tf
.
convert_to_tensor
(
sample_temperature
,
dtype
=
tf
.
float32
)
sample_temperature
,
dtype
=
tf
.
float32
)
self
.
enable_greedy
=
enable_greedy
self
.
enable_greedy
=
enable_greedy
self
.
decoding_name
=
decoding_name
super
(
SamplingModule
,
self
).
__init__
(
super
(
SamplingModule
,
self
).
__init__
(
length_normalization_fn
=
length_normalization_fn
,
dtype
=
dtype
)
length_normalization_fn
=
length_normalization_fn
,
dtype
=
dtype
,
decoding_name
=
decoding_name
)
def
_grow_alive_seq
(
self
,
def
_grow_alive_seq
(
self
,
state
:
Dict
[
str
,
Any
],
state
:
Dict
[
str
,
Any
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment