Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cdc4cad7
"tests/vscode:/vscode.git/clone" did not exist on "0af12f1f8a1682833c944354daeba0c9d9c0f342"
Commit
cdc4cad7
authored
Oct 22, 2020
by
Poorva Potdar
Committed by
A. Unique TensorFlower
Oct 22, 2020
Browse files
Internal change
PiperOrigin-RevId: 338521847
parent
4f50e2fc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
832 additions
and
0 deletions
+832
-0
official/nlp/modeling/ops/decoding_module.py
official/nlp/modeling/ops/decoding_module.py
+289
-0
official/nlp/modeling/ops/decoding_module_test.py
official/nlp/modeling/ops/decoding_module_test.py
+84
-0
official/nlp/modeling/ops/sampling_module.py
official/nlp/modeling/ops/sampling_module.py
+376
-0
official/nlp/modeling/ops/sampling_module_test.py
official/nlp/modeling/ops/sampling_module_test.py
+83
-0
No files found.
official/nlp/modeling/ops/decoding_module.py
0 → 100644
View file @
cdc4cad7
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy)."""
import
abc
from
typing
import
Any
,
Callable
,
Dict
,
Tuple
import
tensorflow
as
tf
from
tensorflow.python.framework
import
dtypes
# Type aliases for the decoding API.
# Output: (finished_sequence, finished_scores) returned by `generate`.
Output = Tuple[tf.Tensor, tf.Tensor]
# InternalState: (top sequences, scores, new ids, new alive cache) returned
# by `_grow_alive_seq`.
InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]
# InitialState: (state dictionary, shape-invariant dictionary) returned by
# `_create_initial_state` and fed to tf.while_loop.
InitialState = Tuple[Dict[str, Any], Dict[str, Any]]
class StateKeys:
  """Keys to dictionary storing the state of Decoding loop."""

  # Variable storing the loop index.
  CUR_INDEX = "CUR_INDEX"

  # Top sequences that are alive for each batch item. Alive sequences are ones
  # that have not generated an EOS token. Sequences that reach EOS are marked as
  # finished and moved to the FINISHED_SEQ tensor.
  # Has shape [batch_size, beam_size, CUR_INDEX + 1] for SequenceBeamSearch and
  # [batch_size, CUR_INDEX + 1] otherwise.
  ALIVE_SEQ = "ALIVE_SEQ"
  # Log probabilities of each alive sequence. Shape [batch_size, beam_size]
  ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS"
  # Dictionary of cached values for each alive sequence. The cache stores
  # the encoder output, attention bias, and the decoder attention output from
  # the previous iteration.
  ALIVE_CACHE = "ALIVE_CACHE"

  # Top finished sequences for each batch item.
  # Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
  # shorter than CUR_INDEX + 1 are padded with 0s.
  FINISHED_SEQ = "FINISHED_SEQ"
  # Scores for each finished sequence. Score = log probability / length norm
  # Shape [batch_size, beam_size]
  FINISHED_SCORES = "FINISHED_SCORES"
  # Flags indicating which sequences in the finished sequences are finished.
  # At the beginning, all of the sequences in FINISHED_SEQ are filler values.
  # True -> finished sequence, False -> filler. Shape [batch_size, beam_size]
  FINISHED_FLAGS = "FINISHED_FLAGS"
class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
  """A base class for the API required for decoding (go/decoding-tf-nlp).

  Subclasses implement a concrete decoding strategy (beam search, sampling)
  by providing the abstract hooks below; `generate` drives the shared
  tf.while_loop over the decoding state.
  """

  def __init__(self,
               length_normalization_fn: Callable[[int, tf.DType], float],
               dtype: tf.DType = tf.float32):
    """Initialize the Decoding Module.

    Args:
      length_normalization_fn: Closure for returning length normalization
        parameter. Function accepts input as length, dtype and returns float.
      dtype: A tensorflow data type used for score computation. The default is
        tf.float32.
    """
    self.length_normalization_fn = length_normalization_fn
    self.dtype = tf.as_dtype(dtype)

  def generate(self, initial_ids: tf.Tensor,
               initial_cache: Dict[str, tf.Tensor]) -> Output:
    """Implements the decoding strategy (beam_search or sampling).

    Args:
      initial_ids: initial ids to pass into the symbols_to_logits_fn.
        int tensor with shape [batch_size, 1]
      initial_cache: dictionary for caching model outputs from previous step.

    Returns:
      Tuple of tensors representing
        finished_sequence: shape [batch, max_seq_length]
        finished_scores: [batch]
    """
    # NOTE(review): `self.padded_decode` is never set in this base class; it
    # is expected to be defined by the concrete subclass (e.g. a sampling
    # module) before `generate` is called — confirm with subclasses.
    # With padded decode the batch size must be statically known; otherwise
    # it is read dynamically from the input tensor.
    batch_size = (
        initial_ids.shape.as_list()[0]
        if self.padded_decode else tf.shape(initial_ids)[0])
    state, state_shapes = self._create_initial_state(initial_ids,
                                                     initial_cache, batch_size)

    def _generate_step(state):
      # One decoding iteration: grow each alive sequence by a token, compute
      # which sequences just finished, and merge alive/finished sub-states.
      topk_seq, topk_log_probs, topk_ids, new_cache = self._grow_alive_seq(
          state, batch_size)
      new_finished_flags = self._finished_flags(topk_ids, state)
      alive_state = self._get_new_alive_state(topk_seq, topk_log_probs,
                                              new_finished_flags, new_cache)
      finished_state = self._get_new_finished_state(state, topk_seq,
                                                    topk_log_probs,
                                                    new_finished_flags,
                                                    batch_size)
      # Advance the loop index and combine the per-key sub-dictionaries into
      # a single new loop state.
      new_state = {StateKeys.CUR_INDEX: state[StateKeys.CUR_INDEX] + 1}
      new_state.update(alive_state)
      new_state.update(finished_state)
      # tf.while_loop expects loop_vars as a list.
      return [new_state]

    # Run the decoding loop; gradients are stopped through the final state
    # since decoding is not meant to be differentiated through.
    finished_state = tf.nest.map_structure(
        tf.stop_gradient,
        tf.while_loop(self._continue_search,
                      _generate_step,
                      loop_vars=[state],
                      shape_invariants=[state_shapes],
                      parallel_iterations=1))
    final_state = self._process_finished_state(finished_state[0])
    return final_state

  @abc.abstractmethod
  def _create_initial_state(self, initial_ids: tf.Tensor,
                            initial_cache: Dict[str, tf.Tensor],
                            batch_size: int) -> InitialState:
    """Return initial state dictionary and its shape invariants."""
    pass

  @abc.abstractmethod
  def _grow_alive_seq(self, state: Dict[str, Any],
                      batch_size: int) -> InternalState:
    """Grow alive sequences by one token.

    Args:
      state: A dictionary with the current loop state.
      batch_size: The given batch size

    Returns:
      Tuple of
      (Top sequences,
       Scores of returned sequences,
       New ids,
       New alive cache)
    """
    pass

  @abc.abstractmethod
  def _get_new_alive_state(
      self, new_seq: tf.Tensor, new_log_probs: tf.Tensor,
      new_finished_flags: tf.Tensor,
      new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
    """Gather the sequences that are still alive.

    Args:
      new_seq: New sequences generated by growing the current alive sequences
        int32 tensor with shape
      new_log_probs: Log probabilities of new sequences float32 tensor with
        shape
      new_finished_flags: A boolean Tensor indicates which sequences are live.
      new_cache: Dict of cached values for each sequence.

    Returns:
      Dictionary with alive keys from StateKeys.
    """
    pass

  @abc.abstractmethod
  def _get_new_finished_state(self, state: Dict[str, Any], new_seq: tf.Tensor,
                              new_log_probs: tf.Tensor,
                              new_finished_flags: tf.Tensor,
                              batch_size: int) -> Dict[str, tf.Tensor]:
    """Combine new and old finished sequences.

    Args:
      state: A dictionary with the current loop state.
      new_seq: New sequences generated by growing the current alive sequences
        int32 tensor.
      new_log_probs: Log probabilities of new sequences float32 tensor with
        shape.
      new_finished_flags: A boolean Tensor indicates which sequences are live.
      batch_size: The given batch size.

    Returns:
      Dictionary with finished keys from StateKeys.
    """
    pass

  @abc.abstractmethod
  def _process_finished_state(self, finished_state: Dict[str, Any]) -> Output:
    """Process the alive/finished state to return final sequences and scores."""
    pass

  @abc.abstractmethod
  def _continue_search(self, state: Dict[str, Any]) -> tf.Tensor:
    """Returns a bool tensor if the decoding loop should continue."""
    pass

  @abc.abstractmethod
  def _finished_flags(self, topk_ids: tf.Tensor,
                      state: Dict[str, Any]) -> tf.Tensor:
    """Calculate the finished flags."""
    pass

  def inf(self):
    """Returns a value close to infinity, but is still finite in `dtype`.

    This is useful to get a very large value that is still zero when multiplied
    by zero. The floating-point "Inf" value is NaN when multiplied by zero.

    Returns:
      A very large value.
    """
    if self.dtype == dtypes.float32 or self.dtype == dtypes.bfloat16:
      return 1e7
    elif self.dtype == dtypes.float16:
      # float16 overflows at 1e7, so use its largest representable value.
      return dtypes.float16.max
    else:
      raise AssertionError("Invalid dtype: %s" % self.dtype)

  @staticmethod
  def _log_prob_from_logits(logits):
    # Normalize logits into log-probabilities along the vocabulary axis.
    return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)

  @staticmethod
  def _shape_list(tensor):
    """Return a list of the tensor's shape, and ensure no None values in list."""
    # Get statically known shape (may contain None's for unknown dimensions)
    shape = tensor.get_shape().as_list()

    # Ensure that the shape values are not None
    dynamic_shape = tf.shape(tensor)
    for i in range(len(shape)):  # pylint: disable=consider-using-enumerate
      if shape[i] is None:
        shape[i] = dynamic_shape[i]
    return shape

  @staticmethod
  def _get_shape_keep_last_dim(tensor):
    # Build a TensorShape where only the last dimension is pinned; used as a
    # tf.while_loop shape invariant for tensors that grow over iterations.
    shape_list_obj = DecodingModule._shape_list(tensor)
    for i in range(len(shape_list_obj) - 1):
      shape_list_obj[i] = None
    if isinstance(shape_list_obj[-1], tf.Tensor):
      # A dynamic last dimension cannot be used in a static shape either.
      shape_list_obj[-1] = None
    return tf.TensorShape(shape_list_obj)

  @staticmethod
  def _expand_to_same_rank(tensor, target):
    """Expands a given tensor to target's rank to be broadcastable.

    Args:
      tensor: input tensor to tile. Shape: [b, d1, ..., da]
      target: target tensor. Shape: [b, d1, ..., da, ..., dn]

    Returns:
      Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with same rank of target

    Raises:
      ValueError, if the shape rank of rank tensor/target is None.
    """
    if tensor.shape.rank is None:
      raise ValueError("Expect rank for tensor shape, but got None.")
    if target.shape.rank is None:
      raise ValueError("Expect rank for target shape, but got None.")

    with tf.name_scope("expand_rank"):
      diff_rank = target.shape.rank - tensor.shape.rank
      for _ in range(diff_rank):
        tensor = tf.expand_dims(tensor, -1)
      return tensor
official/nlp/modeling/ops/decoding_module_test.py
0 → 100644
View file @
cdc4cad7
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test decoding utility methods."""
import
abc
import
tensorflow
as
tf
from
official.nlp.modeling.ops
import
decoding_module
def length_normalization(length, dtype):
  """Return length normalization factor.

  The exponent is 0.0, so the factor is always 1.0 (no normalization); the
  formula is kept in the standard `((5 + length) / 6) ** alpha` form.
  """
  base = (5. + tf.cast(length, dtype)) / 6.
  alpha = 0.0
  return tf.pow(base, alpha)
class TestSubclass(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
  """Minimal concrete DecodingModule used to exercise base-class helpers.

  All abstract hooks are stubbed out with `pass`; only the concrete helpers
  inherited from DecodingModule (`inf`, `length_normalization_fn`, etc.) are
  meaningful on instances of this class.
  """

  def __init__(self,
               length_normalization_fn=length_normalization,
               dtype=tf.float32):
    # Bug fix: forward the constructor argument instead of unconditionally
    # passing the module-level `length_normalization` function, which silently
    # ignored any `length_normalization_fn` supplied by the caller. The
    # default value is unchanged, so default construction behaves as before.
    super(TestSubclass, self).__init__(
        length_normalization_fn=length_normalization_fn, dtype=dtype)

  def _create_initial_state(self, initial_ids, initial_cache, batch_size):
    pass

  def _grow_alive_seq(self, state, batch_size):
    pass

  def _process_finished_state(self, finished_state):
    pass

  def _get_new_finished_state(self, state, new_seq, new_log_probs,
                              new_finished_flags, batch_size):
    pass

  def _finished_flags(self, topk_ids, state):
    pass

  def _continue_search(self, state):
    pass

  def _get_new_alive_state(self, new_seq, new_log_probs, new_finished_flags,
                           new_cache):
    pass
class DecodingModuleTest(tf.test.TestCase):
  """Unit tests for the static helpers and concrete methods of DecodingModule."""

  def test_get_shape_keep_last_dim(self):
    # Build a tensor whose second dimension is only known dynamically.
    scalar = tf.constant(4.0)
    dynamic_dim = tf.cast(tf.sqrt(scalar), tf.int32)
    tensor = tf.ones([7, dynamic_dim, 2, 5])
    result = decoding_module.DecodingModule._get_shape_keep_last_dim(tensor)
    # Every dimension except the last should be erased to None.
    self.assertAllEqual([None, None, None, 5], result.as_list())

  def test_shape_list(self):
    tensor = tf.ones([7, 1])
    result = decoding_module.DecodingModule._shape_list(tensor)
    self.assertAllEqual([7, 1], result)

  def test_inf(self):
    module = TestSubclass()
    large_value = module.inf()
    # For float32 the pseudo-infinity is 1e7.
    self.assertAllEqual(large_value, tf.constant(10000000., tf.float32))

  def test_length_normalization(self):
    module = TestSubclass()
    factor = module.length_normalization_fn(32, tf.float32)
    # Exponent 0.0 means the factor is always exactly 1.0.
    self.assertAllEqual(factor, tf.constant(1.0, tf.float32))
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
official/nlp/modeling/ops/sampling_module.py
0 → 100644
View file @
cdc4cad7
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sampling module for top_k, top_p and greedy decoding."""
import
abc
from
typing
import
Any
,
Callable
,
Dict
import
numpy
as
np
import
tensorflow
as
tf
from
official.nlp.modeling.ops
import
decoding_module
class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
  """Implementation for sampling strategies (go/decoding-tf-nlp).

  Supports greedy decoding and temperature/top_k sampling. The decoding loop
  itself lives in the DecodingModule base class; this class supplies the
  state management and the per-step token-selection logic.
  """

  def __init__(self,
               symbols_to_logits_fn,
               length_normalization_fn: Callable[[int, tf.DType], float],
               vocab_size: int,
               max_decode_length: int,
               eos_id: int,
               padded_decode: bool,
               top_k: tf.Tensor = None,
               sample_temperature: tf.Tensor = None,
               dtype: tf.DType = tf.float32):
    """Initialize sampling module.

    Args:
      symbols_to_logits_fn: Callable mapping (ids, index, cache) to
        (logits, updated cache) for the next token.
      length_normalization_fn: Closure for returning length normalization
        parameter; accepts (length, dtype) and returns a float.
      vocab_size: Size of the output vocabulary.
      max_decode_length: Maximum number of decoding steps.
      eos_id: Token id that marks end of sequence.
      padded_decode: If True, sequences are pre-padded to
        max_decode_length + 1 and updated in place (required for TPU).
      top_k: Optional scalar tensor; if set, enables top-k sampling.
      sample_temperature: Optional scalar tensor; if > 0, logits are divided
        by this temperature before sampling.
      dtype: A tensorflow data type used for score computation.
    """
    self.symbols_to_logits_fn = symbols_to_logits_fn
    self.vocab_size = vocab_size
    self.length_normalization_fn = length_normalization_fn
    self.max_decode_length = max_decode_length
    self.eos_id = eos_id
    self.padded_decode = padded_decode
    self.dtype = tf.as_dtype(dtype)
    self.top_k = top_k
    self.sample_temperature = sample_temperature
    super(SamplingModule, self).__init__(
        length_normalization_fn=length_normalization_fn, dtype=dtype)

  def _grow_alive_seq(self, state: Dict[str, Any],
                      batch_size: int) -> decoding_module.InternalState:
    """Grow alive sequences by one token.

    This function will implement the decoding strategies like top_p, top_k
    and greedy for the choosing the next logit.

    Args:
      state: A dictionary with the current loop state.
      batch_size: The given batch size

    Returns:
      Tuple of
      (Top sequences [batch, curr_index + 1] or [batch, max_decode_length + 1],
       Scores of returned sequences [batch, 1],
       New ids [batch, 1],
       New alive cache)
    """
    i = state[decoding_module.StateKeys.CUR_INDEX]
    alive_seq = state[decoding_module.StateKeys.ALIVE_SEQ]
    alive_log_probs = state[decoding_module.StateKeys.ALIVE_LOG_PROBS]
    alive_cache = state[decoding_module.StateKeys.ALIVE_CACHE]

    if self.padded_decode:
      # With in-place updates, feed only the token at the current position.
      ids = tf.slice(alive_seq, [0, i], [batch_size, 1])
    else:
      ids = alive_seq

    new_logits, new_cache = self.symbols_to_logits_fn(ids, i, alive_cache)
    candidate_log_probs = decoding_module.DecodingModule._log_prob_from_logits(
        new_logits)
    original_log_probs = candidate_log_probs + alive_log_probs

    probs = original_log_probs
    topk_log_probs, topk_ids = None, None
    if not self.do_sample:
      # Greedy decoding: take the single highest-probability token.
      topk_log_probs, topk_ids = self._greedy(probs)
    else:
      # Optionally apply temperature, then top-k filtering, then sample.
      temperature_fn = SamplingModule.sample_logits_with_temperature
      probs = tf.cond(self.sample_temperature > 0.0,
                      lambda: temperature_fn(probs, self.sample_temperature),
                      lambda: probs)
      probs = tf.cond(self.top_k is not None and self.top_k > 1,
                      lambda: SamplingModule._sample_top_k(probs, self.top_k),
                      lambda: probs)
      topk_ids = tf.random.categorical(probs, dtype=tf.int32, num_samples=1)
      topk_log_probs = tf.gather(
          original_log_probs, topk_ids, axis=1, batch_dims=1)

    if self.padded_decode:
      # Scatter the sampled token into position i + 1 of the padded sequence.
      topk_seq = tf.transpose(alive_seq, perm=[1, 0])
      topk_seq = tf.tensor_scatter_nd_update(
          topk_seq, [[i + 1]], tf.expand_dims(tf.squeeze(topk_ids, -1), 0))
      topk_seq = tf.transpose(topk_seq, perm=[1, 0])
    else:
      topk_seq = tf.concat([alive_seq, topk_ids], axis=-1)
    return topk_seq, topk_log_probs, topk_ids, new_cache

  def _create_initial_state(self, initial_ids: tf.Tensor,
                            initial_cache: Dict[str, tf.Tensor],
                            batch_size: int) -> decoding_module.InitialState:
    """Return initial state dictionary and its shape invariants.

    Raises:
      TypeError: If any cache entry's dtype does not match `self.dtype`.
    """
    for key, value in initial_cache.items():
      for inner_value in tf.nest.flatten(value):
        if inner_value.dtype != self.dtype:
          raise TypeError(
              "initial_cache element for key '%s' has dtype %s that does not "
              "match SequenceBeamSearch's dtype of %s. Value: %s" %
              (key, value.dtype.name, self.dtype.name, inner_value))

    # Current loop index (starts at 0)
    cur_index = tf.constant(0)

    # Alive sequence with shape [batch_size, 1]
    alive_seq = initial_ids
    alive_seq = tf.expand_dims(alive_seq, axis=-1)
    if self.padded_decode:
      # Pre-pad sequences to full length so they can be updated in place.
      alive_seq = tf.tile(alive_seq, [1, self.max_decode_length + 1])

    # Initial log probabilities with shape [batch_size, 1].
    initial_log_probs = tf.constant([[0.]], dtype=self.dtype)
    alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])
    alive_cache = initial_cache

    # Initialize tensor storing finished sequences [batch_size, 1, 1].
    finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)

    # Set scores of the initial finished seqs to negative infinity.
    finished_scores = tf.zeros([batch_size, 1], dtype=self.dtype)

    # Initialize finished flags with all False values.
    finished_flags = tf.zeros([batch_size, 1], tf.bool)

    # Create state dictionary and state shapes.
    state = {
        decoding_module.StateKeys.CUR_INDEX: cur_index,
        decoding_module.StateKeys.ALIVE_SEQ: alive_seq,
        decoding_module.StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
        decoding_module.StateKeys.ALIVE_CACHE: alive_cache,
        decoding_module.StateKeys.FINISHED_SEQ: finished_seq,
        decoding_module.StateKeys.FINISHED_SCORES: finished_scores,
        decoding_module.StateKeys.FINISHED_FLAGS: finished_flags
    }

    if self.padded_decode:
      # All shapes are fully static when decoding with in-place updates.
      state_shape_invariants = {
          decoding_module.StateKeys.CUR_INDEX:
              tf.TensorShape([]),
          decoding_module.StateKeys.ALIVE_SEQ:
              tf.TensorShape([batch_size, self.max_decode_length + 1]),
          decoding_module.StateKeys.ALIVE_LOG_PROBS:
              tf.TensorShape([batch_size, 1]),
          decoding_module.StateKeys.ALIVE_CACHE:
              tf.nest.map_structure(lambda state: state.get_shape(),
                                    alive_cache),
          decoding_module.StateKeys.FINISHED_SEQ:
              tf.TensorShape([batch_size, self.max_decode_length + 1]),
          decoding_module.StateKeys.FINISHED_SCORES:
              tf.TensorShape([batch_size, 1]),
          decoding_module.StateKeys.FINISHED_FLAGS:
              tf.TensorShape([batch_size, 1])
      }
    else:
      # Sequences grow by one token per step, so their length is unknown.
      state_shape_invariants = {
          decoding_module.StateKeys.CUR_INDEX:
              tf.TensorShape([]),
          decoding_module.StateKeys.ALIVE_SEQ:
              tf.TensorShape([None, None]),
          decoding_module.StateKeys.ALIVE_LOG_PROBS:
              tf.TensorShape([None, 1]),
          decoding_module.StateKeys.ALIVE_CACHE:
              tf.nest.map_structure(
                  decoding_module.DecodingModule._get_shape_keep_last_dim,
                  alive_cache),
          decoding_module.StateKeys.FINISHED_SEQ:
              tf.TensorShape([None, None]),
          decoding_module.StateKeys.FINISHED_SCORES:
              tf.TensorShape([None, 1]),
          decoding_module.StateKeys.FINISHED_FLAGS:
              tf.TensorShape([None, 1])
      }
    return state, state_shape_invariants

  def _get_new_alive_state(
      self, new_seq: tf.Tensor, new_log_probs: tf.Tensor,
      new_finished_flags: tf.Tensor,
      new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
    """Gather the sequences that are still alive.

    This function resets the sequences in the alive_state that are finished.

    Args:
      new_seq: New sequences generated by growing the current alive sequences
        int32 tensor with shape [batch_size, cur_index + 1]
      new_log_probs: Log probabilities of new sequences float32 tensor with
        shape [batch_size, 1]
      new_finished_flags: A boolean Tensor indicates which sequences are live
        inside the beam.
      new_cache: Dict of cached values for each sequence.

    Returns:
      Dictionary with alive keys.
    """
    # Zero out sequences that just finished so they drop out of the alive set.
    new_seq = tf.multiply(
        new_seq, tf.cast(tf.logical_not(new_finished_flags), new_seq.dtype))
    return {
        decoding_module.StateKeys.ALIVE_SEQ: new_seq,
        decoding_module.StateKeys.ALIVE_LOG_PROBS: new_log_probs,
        decoding_module.StateKeys.ALIVE_CACHE: new_cache
    }

  def _get_new_finished_state(self, state: Dict[str, Any], new_seq: tf.Tensor,
                              new_log_probs: tf.Tensor,
                              new_finished_flags: tf.Tensor,
                              batch_size: int) -> Dict[str, tf.Tensor]:
    """Combine new and old finished sequences.

    Args:
      state: A dictionary with the current loop state.
      new_seq: New sequences generated by growing the current alive sequences
        int32 tensor [batch, curr_index + 1] or [batch, max_decode_length + 1].
      new_log_probs: Log probabilities of new sequences float32 tensor with
        shape [batch, 1].
      new_finished_flags: A boolean Tensor indicates which sequences are live.
      batch_size: The given batch size.

    Returns:
      Dictionary with finished keys from StateKeys.
    """
    i = state[decoding_module.StateKeys.CUR_INDEX]
    finished_seq = state[decoding_module.StateKeys.FINISHED_SEQ]
    finished_scores = state[decoding_module.StateKeys.FINISHED_SCORES]
    finished_flags = state[decoding_module.StateKeys.FINISHED_FLAGS]

    if not self.padded_decode:
      # Growing sequences: pad stored finished sequences to the new length.
      finished_seq = tf.concat(
          [finished_seq, tf.zeros([batch_size, 1], tf.int32)], axis=-1)

    new_scores = new_log_probs
    if self.length_normalization_fn is not None:
      length_norm = self.length_normalization_fn(i + 1, self.dtype)
      new_scores = new_log_probs / length_norm

    # Mask out entries for sequences that were already finished earlier so
    # they are not written to again.
    new_seq = tf.multiply(
        new_seq, tf.cast(tf.logical_not(finished_flags), new_seq.dtype))
    new_scores = tf.multiply(
        new_scores, tf.cast(tf.logical_not(finished_flags), new_scores.dtype))

    # Accumulate sequences/scores that finished on this step.
    finished_seq += tf.multiply(
        new_seq, tf.cast(new_finished_flags, new_seq.dtype))
    finished_scores += tf.multiply(
        new_scores, tf.cast(new_finished_flags, new_scores.dtype))
    new_finished_flags = tf.logical_or(new_finished_flags, finished_flags)
    return {
        decoding_module.StateKeys.FINISHED_SEQ: finished_seq,
        decoding_module.StateKeys.FINISHED_SCORES: finished_scores,
        decoding_module.StateKeys.FINISHED_FLAGS: new_finished_flags
    }

  def _process_finished_state(
      self, finished_state: Dict[str, Any]) -> decoding_module.Output:
    """Process the alive/finished state to return final sequences and scores."""
    alive_seq = finished_state[decoding_module.StateKeys.ALIVE_SEQ]
    alive_log_probs = finished_state[decoding_module.StateKeys.ALIVE_LOG_PROBS]
    finished_seq = finished_state[decoding_module.StateKeys.FINISHED_SEQ]
    finished_scores = finished_state[decoding_module.StateKeys.FINISHED_SCORES]
    finished_flags = finished_state[decoding_module.StateKeys.FINISHED_FLAGS]
    finished_cond = tf.reduce_any(finished_flags, 1, name="finished_cond")
    if self.length_normalization_fn is not None:
      length_norm = self.length_normalization_fn(self.max_decode_length + 1,
                                                 self.dtype)
      alive_log_probs = alive_log_probs / length_norm
    seq_cond = decoding_module.DecodingModule._expand_to_same_rank(
        finished_cond, finished_seq)
    score_cond = decoding_module.DecodingModule._expand_to_same_rank(
        finished_cond, finished_scores)
    # Bug fix: the original call passed `finished_scores` as a spurious 4th
    # positional argument to tf.where (which takes condition, x, y), making
    # it land in the `name` parameter and fail at runtime.
    finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
    finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
    return finished_seq, finished_scores

  def _continue_search(self, state) -> tf.Tensor:
    # Decode until the maximum number of steps is reached.
    i = state[decoding_module.StateKeys.CUR_INDEX]
    return tf.less(i, self.max_decode_length)

  def _finished_flags(self, topk_ids, state) -> tf.Tensor:
    # A sequence is finished once it emits EOS, and stays finished thereafter.
    new_finished_flags = tf.equal(topk_ids, self.eos_id)
    new_finished_flags = tf.logical_or(
        new_finished_flags, state[decoding_module.StateKeys.FINISHED_FLAGS])
    return new_finished_flags

  @property
  def do_sample(self) -> bool:
    """Returns True if top_p or top_k is enabled."""
    # TODO(poorvap) : Add the check for top_p.
    if self.top_k is not None:
      return True
    return False

  @staticmethod
  def _greedy(log_probs):
    """Returns the top ids and scores based on greedy decoding."""
    log_probs, ids = tf.nn.top_k(log_probs, k=1)
    return log_probs, ids

  @staticmethod
  def sample_logits_with_temperature(logits, temperature):
    """Applies a sampling temperature.

    Temperature of [0, 1) skews the distribution towards high probability
    tokens and lowers the mass in tail distribution.

    Args:
      logits: Input logits for next token.
      temperature: Tensor for specifying the sampling temperature.

    Returns:
      Logits with applied temperature.
    """
    return logits / temperature

  @staticmethod
  def _sample_top_k(logits, top_k):
    """Chooses top_k logits and sets the others to negative infinity.

    Args:
      logits: Input logits for next token.
      top_k: Tensor to specify the top_k values.

    Returns:
      Logits with top_k filtering applied.
    """
    top_k_logits = tf.math.top_k(logits, k=top_k)
    # Anything below the k-th largest logit is removed from consideration.
    indices_to_remove = logits < top_k_logits[0][..., -1, None]
    # np.NINF was removed in NumPy 2.0; -np.inf is the equivalent value.
    top_k_logits = SamplingModule._set_tensor_by_indices_to_value(
        logits, indices_to_remove, -np.inf)
    return top_k_logits

  @staticmethod
  def _set_tensor_by_indices_to_value(input_tensor, indices, value):
    """Where indices is True, set the value in input_tensor to value.

    Args:
      input_tensor: float (batch_size, dim)
      indices: bool (batch_size, dim)
      value: float scalar

    Returns:
      output_tensor: same shape as input_tensor.
    """
    value_tensor = tf.zeros_like(input_tensor) + value
    output_tensor = tf.where(indices, value_tensor, input_tensor)
    return output_tensor
official/nlp/modeling/ops/sampling_module_test.py
0 → 100644
View file @
cdc4cad7
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Sampling Strategies."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.nlp.modeling.ops
import
sampling_module
def length_norm(length, dtype):
  """Return length normalization factor.

  The exponent is 0.0, so the factor is always 1.0 (no normalization); the
  formula is kept in the standard `((5 + length) / 6) ** alpha` form.
  """
  base = (5. + tf.cast(length, dtype)) / 6.
  alpha = 0.0
  return tf.pow(base, alpha)
class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end tests for greedy decoding and top-k sampling."""

  # Dummy two-layer decoder cache: zero-valued key/value tensors of shape
  # [batch=2, heads=2, length=2, dim=2] per layer.
  cache = {
      'layer_%d' % layer: {
          'k': tf.zeros([2, 2, 2, 2], dtype=tf.float32),
          'v': tf.zeros([2, 2, 2, 2], dtype=tf.float32)
      } for layer in range(2)
  }

  # Fixed next-token distributions, shape [batch=2, step=4, vocab=3]; token 2
  # dominates at later steps, so greedy decoding converges to it.
  probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],
                                [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                               [[0.2, 0.4, 0.4], [0.2, 0.7, 0.1],
                                [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])

  def _get_test_symbols_to_logits_fn(self):
    """Calculates logits of the next tokens."""

    def symbols_to_logits_fn(ids, i, cache):
      # Ids are ignored; logits come directly from the fixed probability
      # table for step i.
      del ids
      logits = tf.cast(tf.math.log(self.probabilities[:, i, :]), tf.float32)
      return logits, cache

    return symbols_to_logits_fn

  @parameterized.named_parameters([
      ('padded_decode_true', True),
      ('padded_decode_false', False),
  ])
  def test_greedy(self, padded_decode):
    # eos_id=10 is outside the vocab, so decoding always runs the full length.
    greedy_obj = sampling_module.SamplingModule(
        length_normalization_fn=None,
        dtype=tf.float32,
        symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
        vocab_size=3,
        max_decode_length=4,
        eos_id=10,
        padded_decode=padded_decode)
    ids, _ = greedy_obj.generate(
        initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
    # Greedy picks the argmax of each step's distribution.
    self.assertAllEqual([[9, 1, 2, 2, 2], [1, 1, 1, 2, 2]], ids)

  @parameterized.named_parameters([
      ('padded_decode_true', True),
      ('padded_decode_false', False),
  ])
  def test_topk(self, padded_decode):
    top_k_obj = sampling_module.SamplingModule(
        length_normalization_fn=length_norm,
        dtype=tf.float32,
        symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
        vocab_size=3,
        max_decode_length=4,
        eos_id=10,
        sample_temperature=tf.constant(0.1),
        top_k=tf.constant(3),
        padded_decode=padded_decode)
    ids, _ = top_k_obj.generate(
        initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
    # Sampling is stochastic, so only the output shape is checked.
    self.assertAllEqual([2, 5], ids.shape)
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment