Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5a4a3ac3
Commit
5a4a3ac3
authored
Nov 24, 2020
by
Poorva Potdar
Committed by
A. Unique TensorFlower
Nov 24, 2020
Browse files
Internal change
PiperOrigin-RevId: 344068616
parent
d9541052
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
307 additions
and
142 deletions
+307
-142
official/nlp/modeling/ops/decoding_module.py
official/nlp/modeling/ops/decoding_module.py
+52
-52
official/nlp/modeling/ops/decoding_module_test.py
official/nlp/modeling/ops/decoding_module_test.py
+2
-2
official/nlp/modeling/ops/sampling_module.py
official/nlp/modeling/ops/sampling_module.py
+154
-83
official/nlp/modeling/ops/sampling_module_test.py
official/nlp/modeling/ops/sampling_module_test.py
+99
-5
No files found.
official/nlp/modeling/ops/decoding_module.py
View file @
5a4a3ac3
...
@@ -58,6 +58,58 @@ class StateKeys:
...
@@ -58,6 +58,58 @@ class StateKeys:
FINISHED_FLAGS
=
"FINISHED_FLAGS"
FINISHED_FLAGS
=
"FINISHED_FLAGS"
def
log_prob_from_logits
(
logits
):
return
logits
-
tf
.
reduce_logsumexp
(
logits
,
axis
=-
1
,
keepdims
=
True
)
def
shape_list
(
tensor
):
"""Return a list of the tensor's shape, and ensure no None values in list."""
# Get statically known shape (may contain None's for unknown dimensions)
shape
=
tensor
.
get_shape
().
as_list
()
# Ensure that the shape values are not None
dynamic_shape
=
tf
.
shape
(
tensor
)
for
i
in
range
(
len
(
shape
)):
# pylint: disable=consider-using-enumerate
if
shape
[
i
]
is
None
:
shape
[
i
]
=
dynamic_shape
[
i
]
return
shape
def
get_shape_keep_last_dim
(
tensor
):
shape_list_obj
=
shape_list
(
tensor
)
for
i
in
range
(
len
(
shape_list_obj
)
-
1
):
shape_list_obj
[
i
]
=
None
if
isinstance
(
shape_list_obj
[
-
1
],
tf
.
Tensor
):
shape_list_obj
[
-
1
]
=
None
return
tf
.
TensorShape
(
shape_list_obj
)
def
expand_to_same_rank
(
tensor
,
target
):
"""Expands a given tensor to target's rank to be broadcastable.
Args:
tensor: input tensor to tile. Shape: [b, d1, ..., da]
target: target tensor. Shape: [b, d1, ..., da, ..., dn]
Returns:
Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with same rank of target
Raises:
ValueError, if the shape rank of rank tensor/target is None.
"""
if
tensor
.
shape
.
rank
is
None
:
raise
ValueError
(
"Expect rank for tensor shape, but got None."
)
if
target
.
shape
.
rank
is
None
:
raise
ValueError
(
"Expect rank for target shape, but got None."
)
with
tf
.
name_scope
(
"expand_rank"
):
diff_rank
=
target
.
shape
.
rank
-
tensor
.
shape
.
rank
for
_
in
range
(
diff_rank
):
tensor
=
tf
.
expand_dims
(
tensor
,
-
1
)
return
tensor
class
DecodingModule
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
class
DecodingModule
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""A base class for the API required for decoding (go/decoding-tf-nlp)."""
"""A base class for the API required for decoding (go/decoding-tf-nlp)."""
...
@@ -233,57 +285,5 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
...
@@ -233,57 +285,5 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
else
:
else
:
raise
AssertionError
(
"Invalid dtype: %s"
%
self
.
dtype
)
raise
AssertionError
(
"Invalid dtype: %s"
%
self
.
dtype
)
@
staticmethod
def
_log_prob_from_logits
(
logits
):
return
logits
-
tf
.
reduce_logsumexp
(
logits
,
axis
=-
1
,
keepdims
=
True
)
@
staticmethod
def
_shape_list
(
tensor
):
"""Return a list of the tensor's shape, and ensure no None values in list."""
# Get statically known shape (may contain None's for unknown dimensions)
shape
=
tensor
.
get_shape
().
as_list
()
# Ensure that the shape values are not None
dynamic_shape
=
tf
.
shape
(
tensor
)
for
i
in
range
(
len
(
shape
)):
# pylint: disable=consider-using-enumerate
if
shape
[
i
]
is
None
:
shape
[
i
]
=
dynamic_shape
[
i
]
return
shape
@
staticmethod
def
_get_shape_keep_last_dim
(
tensor
):
shape_list_obj
=
DecodingModule
.
_shape_list
(
tensor
)
for
i
in
range
(
len
(
shape_list_obj
)
-
1
):
shape_list_obj
[
i
]
=
None
if
isinstance
(
shape_list_obj
[
-
1
],
tf
.
Tensor
):
shape_list_obj
[
-
1
]
=
None
return
tf
.
TensorShape
(
shape_list_obj
)
@
staticmethod
def
_expand_to_same_rank
(
tensor
,
target
):
"""Expands a given tensor to target's rank to be broadcastable.
Args:
tensor: input tensor to tile. Shape: [b, d1, ..., da]
target: target tensor. Shape: [b, d1, ..., da, ..., dn]
Returns:
Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with same rank of target
Raises:
ValueError, if the shape rank of rank tensor/target is None.
"""
if
tensor
.
shape
.
rank
is
None
:
raise
ValueError
(
"Expect rank for tensor shape, but got None."
)
if
target
.
shape
.
rank
is
None
:
raise
ValueError
(
"Expect rank for target shape, but got None."
)
with
tf
.
name_scope
(
"expand_rank"
):
diff_rank
=
target
.
shape
.
rank
-
tensor
.
shape
.
rank
for
_
in
range
(
diff_rank
):
tensor
=
tf
.
expand_dims
(
tensor
,
-
1
)
return
tensor
official/nlp/modeling/ops/decoding_module_test.py
View file @
5a4a3ac3
...
@@ -62,12 +62,12 @@ class DecodingModuleTest(tf.test.TestCase):
...
@@ -62,12 +62,12 @@ class DecodingModuleTest(tf.test.TestCase):
def
test_get_shape_keep_last_dim
(
self
):
def
test_get_shape_keep_last_dim
(
self
):
y
=
tf
.
constant
(
4.0
)
y
=
tf
.
constant
(
4.0
)
x
=
tf
.
ones
([
7
,
tf
.
cast
(
tf
.
sqrt
(
y
),
tf
.
int32
),
2
,
5
])
x
=
tf
.
ones
([
7
,
tf
.
cast
(
tf
.
sqrt
(
y
),
tf
.
int32
),
2
,
5
])
shape
=
decoding_module
.
DecodingModule
.
_
get_shape_keep_last_dim
(
x
)
shape
=
decoding_module
.
get_shape_keep_last_dim
(
x
)
self
.
assertAllEqual
([
None
,
None
,
None
,
5
],
shape
.
as_list
())
self
.
assertAllEqual
([
None
,
None
,
None
,
5
],
shape
.
as_list
())
def
test_shape_list
(
self
):
def
test_shape_list
(
self
):
x
=
tf
.
ones
([
7
,
1
])
x
=
tf
.
ones
([
7
,
1
])
shape
=
decoding_module
.
DecodingModule
.
_
shape_list
(
x
)
shape
=
decoding_module
.
shape_list
(
x
)
self
.
assertAllEqual
([
7
,
1
],
shape
)
self
.
assertAllEqual
([
7
,
1
],
shape
)
def
test_inf
(
self
):
def
test_inf
(
self
):
...
...
official/nlp/modeling/ops/sampling_module.py
View file @
5a4a3ac3
...
@@ -23,6 +23,127 @@ import tensorflow as tf
...
@@ -23,6 +23,127 @@ import tensorflow as tf
from
official.nlp.modeling.ops
import
decoding_module
from
official.nlp.modeling.ops
import
decoding_module
def
greedy
(
log_probs
):
"""Returns the top ids and scores based on greedy decoding."""
log_probs
,
ids
=
tf
.
nn
.
top_k
(
log_probs
,
k
=
1
)
return
log_probs
,
ids
def
sample_logits_with_temperature
(
logits
,
temperature
):
"""Applies a sampling temperature.
Temperature skews the distribution towards high probability
tokens and lowers the mass in tail distribution.
Args:
logits: Input logits for next token.
temperature: Tensor for specifying the sampling temperature.
Returns:
Logits with applied temperature.
"""
return
logits
/
temperature
def
sample_top_k
(
logits
,
top_k
):
"""Chooses top_k logits and sets the others to negative infinity.
Args:
logits: Input logits for next token.
top_k: Tensor to specify the top_k values.
Returns:
Logits with top_k filtering applied.
"""
top_k_logits
=
tf
.
math
.
top_k
(
logits
,
k
=
top_k
)
indices_to_remove
=
logits
<
top_k_logits
[
0
][...,
-
1
,
None
]
top_k_logits
=
set_tensor_by_indices_to_value
(
logits
,
indices_to_remove
,
np
.
NINF
)
return
top_k_logits
def
sample_top_p
(
logits
,
top_p
):
"""Chooses most probable logits with cumulative probabilities upto top_p.
Sets the remaining logits to negative infinity.
Args:
logits: Input logits for next token.
top_p: Float tensor with a value >=0 and < 1.0
Returns:
Logits with top_p filtering applied.
"""
sorted_indices
=
tf
.
argsort
(
logits
,
direction
=
"DESCENDING"
)
# Flatten logits as tf.gather on TPU needs axis to be compile time constant.
range_for_gather
=
tf
.
expand_dims
(
tf
.
range
(
0
,
logits
.
shape
[
0
]),
axis
=
1
)
range_for_gather
=
tf
.
tile
(
range_for_gather
*
logits
.
shape
[
1
],
[
1
,
logits
.
shape
[
1
]])
+
sorted_indices
flattened_logits
=
tf
.
reshape
(
logits
,
[
-
1
])
flattened_sorted_indices
=
tf
.
reshape
(
range_for_gather
,
[
-
1
])
sorted_logits
=
tf
.
reshape
(
tf
.
gather
(
flattened_logits
,
flattened_sorted_indices
),
[
logits
.
shape
[
0
],
logits
.
shape
[
1
]])
cumulative_probs
=
tf
.
cumsum
(
tf
.
nn
.
softmax
(
sorted_logits
,
axis
=-
1
),
axis
=-
1
)
# Remove tokens with cumulative probability above the threshold.
sorted_indices_to_remove
=
cumulative_probs
>
top_p
# Shift the indices to the right to keep the first token above threshold.
sorted_indices_to_remove
=
tf
.
roll
(
sorted_indices_to_remove
,
1
,
axis
=-
1
)
sorted_indices_to_remove
=
tf
.
concat
([
tf
.
zeros_like
(
sorted_indices_to_remove
[:,
:
1
]),
sorted_indices_to_remove
[:,
1
:]
],
-
1
)
# Scatter sorted indices to original indexes.
indices_to_remove
=
scatter_values_on_batch_indices
(
sorted_indices_to_remove
,
sorted_indices
)
top_p_logits
=
set_tensor_by_indices_to_value
(
logits
,
indices_to_remove
,
np
.
NINF
)
return
top_p_logits
def
scatter_values_on_batch_indices
(
values
,
batch_indices
):
"""Scatter `values` into a tensor using `batch_indices`.
Args:
values: tensor of shape [batch_size, vocab_size] containing the values to
scatter
batch_indices: tensor of shape [batch_size, vocab_size] containing the
indices to insert (should be a permutation in range(0, n))
Returns:
Tensor of shape [batch_size, vocab_size] with values inserted at
batch_indices
"""
tensor_shape
=
decoding_module
.
shape_list
(
batch_indices
)
broad_casted_batch_dims
=
tf
.
reshape
(
tf
.
broadcast_to
(
tf
.
expand_dims
(
tf
.
range
(
tensor_shape
[
0
]),
axis
=-
1
),
tensor_shape
),
[
1
,
-
1
])
pair_indices
=
tf
.
transpose
(
tf
.
concat
([
broad_casted_batch_dims
,
tf
.
reshape
(
batch_indices
,
[
1
,
-
1
])],
0
))
return
tf
.
scatter_nd
(
pair_indices
,
tf
.
reshape
(
values
,
[
-
1
]),
tensor_shape
)
def
set_tensor_by_indices_to_value
(
input_tensor
,
indices
,
value
):
"""Where indices is True, set the value in input_tensor to value.
Args:
input_tensor: float (batch_size, dim)
indices: bool (batch_size, dim)
value: float scalar
Returns:
output_tensor: same shape as input_tensor.
"""
value_tensor
=
tf
.
zeros_like
(
input_tensor
)
+
value
output_tensor
=
tf
.
where
(
indices
,
value_tensor
,
input_tensor
)
return
output_tensor
class
SamplingModule
(
decoding_module
.
DecodingModule
,
metaclass
=
abc
.
ABCMeta
):
class
SamplingModule
(
decoding_module
.
DecodingModule
,
metaclass
=
abc
.
ABCMeta
):
"""Implementation for sampling stratgies (go/decoding-tf-nlp)."""
"""Implementation for sampling stratgies (go/decoding-tf-nlp)."""
...
@@ -33,19 +154,25 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -33,19 +154,25 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
max_decode_length
:
int
,
max_decode_length
:
int
,
eos_id
:
int
,
eos_id
:
int
,
padded_decode
:
bool
,
padded_decode
:
bool
,
top_k
:
tf
.
Tensor
=
None
,
top_k
=
0
,
sample_temperature
:
tf
.
Tensor
=
None
,
top_p
=
1.0
,
sample_temperature
=
0.0
,
enable_greedy
:
bool
=
True
,
dtype
:
tf
.
DType
=
tf
.
float32
):
dtype
:
tf
.
DType
=
tf
.
float32
):
"""Initialize sampling module."""
"""Initialize sampling module."""
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
vocab_size
=
vocab_size
self
.
length_normalization_fn
=
length_normalization_fn
self
.
length_normalization_fn
=
length_normalization_fn
self
.
max_decode_length
=
max_decode_length
self
.
eos_id
=
eos_id
self
.
eos_id
=
eos_id
self
.
padded_decode
=
padded_decode
self
.
padded_decode
=
padded_decode
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
self
.
top_k
=
top_k
self
.
vocab_size
=
tf
.
convert_to_tensor
(
vocab_size
,
dtype
=
tf
.
int32
)
self
.
sample_temperature
=
sample_temperature
self
.
max_decode_length
=
tf
.
convert_to_tensor
(
max_decode_length
,
dtype
=
tf
.
int32
)
self
.
top_k
=
tf
.
convert_to_tensor
(
top_k
,
dtype
=
tf
.
int32
)
self
.
top_p
=
tf
.
convert_to_tensor
(
top_p
,
dtype
=
tf
.
float32
)
self
.
sample_temperature
=
tf
.
convert_to_tensor
(
sample_temperature
,
dtype
=
tf
.
float32
)
self
.
enable_greedy
=
enable_greedy
super
(
SamplingModule
,
self
).
__init__
(
super
(
SamplingModule
,
self
).
__init__
(
length_normalization_fn
=
length_normalization_fn
,
dtype
=
dtype
)
length_normalization_fn
=
length_normalization_fn
,
dtype
=
dtype
)
...
@@ -79,23 +206,29 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -79,23 +206,29 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
ids
=
alive_seq
ids
=
alive_seq
new_logits
,
new_cache
=
self
.
symbols_to_logits_fn
(
ids
,
i
,
alive_cache
)
new_logits
,
new_cache
=
self
.
symbols_to_logits_fn
(
ids
,
i
,
alive_cache
)
candidate_log_probs
=
decoding_module
.
DecodingModule
.
_
log_prob_from_logits
(
candidate_log_probs
=
decoding_module
.
log_prob_from_logits
(
new_logits
)
new_logits
)
original_log_probs
=
candidate_log_probs
+
alive_log_probs
original_log_probs
=
candidate_log_probs
+
alive_log_probs
probs
=
original_log_probs
topk_log_probs
,
topk_ids
=
None
,
None
topk_log_probs
,
topk_ids
=
None
,
None
if
not
self
.
do_sample
:
if
self
.
enable_greedy
:
topk_log_probs
,
topk_ids
=
self
.
_
greedy
(
probs
)
topk_log_probs
,
topk_ids
=
greedy
(
original_log_
probs
)
else
:
else
:
temperature_fn
=
self
.
sample_logits_with_temperature
temperature_fn
=
sample_logits_with_temperature
probs
=
tf
.
cond
(
self
.
sample_temperature
>
0.0
,
sampled_logits
=
tf
.
cond
(
lambda
:
temperature_fn
(
probs
,
self
.
sample_temperature
),
self
.
sample_temperature
>
0.0
,
lambda
:
probs
)
lambda
:
temperature_fn
(
new_logits
,
self
.
sample_temperature
),
probs
=
tf
.
cond
(
self
.
top_k
is
not
None
and
self
.
top_k
>
1
,
lambda
:
new_logits
)
lambda
:
self
.
_sample_top_k
(
probs
,
self
.
top_k
),
sampled_logits
=
tf
.
cond
(
lambda
:
probs
)
self
.
top_k
>
0
,
topk_ids
=
tf
.
random
.
categorical
(
probs
,
dtype
=
tf
.
int32
,
num_samples
=
1
)
lambda
:
sample_top_k
(
sampled_logits
,
self
.
top_k
),
lambda
:
sampled_logits
)
sampled_logits
=
tf
.
cond
(
self
.
top_p
<
1
,
lambda
:
sample_top_p
(
sampled_logits
,
self
.
top_p
),
lambda
:
sampled_logits
)
topk_ids
=
tf
.
random
.
categorical
(
sampled_logits
,
dtype
=
tf
.
int32
,
num_samples
=
1
)
topk_log_probs
=
tf
.
gather
(
topk_log_probs
=
tf
.
gather
(
original_log_probs
,
topk_ids
,
axis
=
1
,
batch_dims
=
1
)
original_log_probs
,
topk_ids
,
axis
=
1
,
batch_dims
=
1
)
if
self
.
padded_decode
:
if
self
.
padded_decode
:
...
@@ -185,7 +318,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -185,7 +318,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
tf
.
TensorShape
([
None
,
1
]),
tf
.
TensorShape
([
None
,
1
]),
decoding_module
.
StateKeys
.
ALIVE_CACHE
:
decoding_module
.
StateKeys
.
ALIVE_CACHE
:
tf
.
nest
.
map_structure
(
tf
.
nest
.
map_structure
(
decoding_module
.
DecodingModule
.
_
get_shape_keep_last_dim
,
decoding_module
.
get_shape_keep_last_dim
,
alive_cache
),
alive_cache
),
decoding_module
.
StateKeys
.
FINISHED_SEQ
:
decoding_module
.
StateKeys
.
FINISHED_SEQ
:
tf
.
TensorShape
([
None
,
None
]),
tf
.
TensorShape
([
None
,
None
]),
...
@@ -288,9 +421,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -288,9 +421,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
length_norm
=
self
.
length_normalization_fn
(
self
.
max_decode_length
+
1
,
length_norm
=
self
.
length_normalization_fn
(
self
.
max_decode_length
+
1
,
self
.
dtype
)
self
.
dtype
)
alive_log_probs
=
alive_log_probs
/
length_norm
alive_log_probs
=
alive_log_probs
/
length_norm
seq_cond
=
decoding_module
.
DecodingModule
.
_
expand_to_same_rank
(
seq_cond
=
decoding_module
.
expand_to_same_rank
(
finished_cond
,
finished_seq
)
finished_cond
,
finished_seq
)
score_cond
=
decoding_module
.
DecodingModule
.
_
expand_to_same_rank
(
score_cond
=
decoding_module
.
expand_to_same_rank
(
finished_cond
,
finished_scores
)
finished_cond
,
finished_scores
)
finished_seq
=
tf
.
where
(
seq_cond
,
finished_seq
,
alive_seq
,
finished_scores
)
finished_seq
=
tf
.
where
(
seq_cond
,
finished_seq
,
alive_seq
,
finished_scores
)
finished_scores
=
tf
.
where
(
score_cond
,
finished_scores
,
alive_log_probs
)
finished_scores
=
tf
.
where
(
score_cond
,
finished_scores
,
alive_log_probs
)
...
@@ -306,68 +439,6 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -306,68 +439,6 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
new_finished_flags
,
state
[
decoding_module
.
StateKeys
.
FINISHED_FLAGS
])
new_finished_flags
,
state
[
decoding_module
.
StateKeys
.
FINISHED_FLAGS
])
return
new_finished_flags
return
new_finished_flags
@
property
def
do_sample
(
self
)
->
bool
:
"""Returns True if top_p or top_k is enabled."""
# TODO(poorvap) : Add the check for top_p.
if
self
.
top_k
is
not
None
:
return
True
return
False
@
staticmethod
def
_greedy
(
log_probs
):
"""Returns the top ids and scores based on greedy decoding."""
log_probs
,
ids
=
tf
.
nn
.
top_k
(
log_probs
,
k
=
1
)
return
log_probs
,
ids
@
staticmethod
def
sample_logits_with_temperature
(
logits
,
temperature
):
"""Applies a sampling temperature.
Temperature of [0, 1) skews the distribution towards high probability
tokens and lowers the mass in tail distribution.
Args:
logits: Input logits for next token.
temperature: Tensor for specifying the sampling temperature.
Returns:
Logits with applied temperature.
"""
return
logits
/
temperature
@
staticmethod
def
_sample_top_k
(
logits
,
top_k
):
"""Chooses top_k logits and sets the others to negative infinity.
Args:
logits: Input logits for next token.
top_k: Tensor to specify the top_k values.
Returns:
Logits with top_k filtering apploed.
"""
top_k_logits
=
tf
.
math
.
top_k
(
logits
,
k
=
top_k
)
indices_to_remove
=
logits
<
top_k_logits
[
0
][...,
-
1
,
None
]
top_k_logits
=
SamplingModule
.
_set_tensor_by_indices_to_value
(
logits
,
indices_to_remove
,
np
.
NINF
)
return
top_k_logits
@
staticmethod
def
_set_tensor_by_indices_to_value
(
input_tensor
,
indices
,
value
):
"""Where indices is True, set the value in input_tensor to value.
Args:
input_tensor: float (batch_size, dim)
indices: bool (batch_size, dim)
value: float scalar
Returns:
output_tensor: same shape as input_tensor.
"""
value_tensor
=
tf
.
zeros_like
(
input_tensor
)
+
value
output_tensor
=
tf
.
where
(
indices
,
value_tensor
,
input_tensor
)
return
output_tensor
...
...
official/nlp/modeling/ops/sampling_module_test.py
View file @
5a4a3ac3
...
@@ -24,6 +24,8 @@ def length_norm(length, dtype):
...
@@ -24,6 +24,8 @@ def length_norm(length, dtype):
"""Return length normalization factor."""
"""Return length normalization factor."""
return
tf
.
pow
(((
5.
+
tf
.
cast
(
length
,
dtype
))
/
6.
),
0.0
)
return
tf
.
pow
(((
5.
+
tf
.
cast
(
length
,
dtype
))
/
6.
),
0.0
)
greedy_expected
=
tf
.
constant
([[
9
,
1
,
2
,
2
,
2
],
[
1
,
1
,
1
,
2
,
2
]])
class
SamplingModuleTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
SamplingModuleTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
...
@@ -32,7 +34,7 @@ class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -32,7 +34,7 @@ class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
}
for
layer
in
range
(
2
)}
}
for
layer
in
range
(
2
)}
probabilities
=
tf
.
constant
([[[
0.3
,
0.4
,
0.3
],
[
0.3
,
0.3
,
0.4
],
probabilities
=
tf
.
constant
([[[
0.3
,
0.4
,
0.3
],
[
0.3
,
0.3
,
0.4
],
[
0.1
,
0.1
,
0.8
],
[
0.1
,
0.1
,
0.8
]],
[
0.1
,
0.1
,
0.8
],
[
0.1
,
0.1
,
0.8
]],
[[
0.2
,
0.
4
,
0.
4
],
[
0.2
,
0.7
,
0.1
],
[[
0.2
,
0.
5
,
0.
3
],
[
0.2
,
0.7
,
0.1
],
[
0.1
,
0.1
,
0.8
],
[
0.1
,
0.1
,
0.8
]]])
[
0.1
,
0.1
,
0.8
],
[
0.1
,
0.1
,
0.8
]]])
def
_get_test_symbols_to_logits_fn
(
self
):
def
_get_test_symbols_to_logits_fn
(
self
):
...
@@ -58,7 +60,7 @@ class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -58,7 +60,7 @@ class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
padded_decode
=
padded_decode
)
padded_decode
=
padded_decode
)
ids
,
_
=
greedy_obj
.
generate
(
ids
,
_
=
greedy_obj
.
generate
(
initial_ids
=
tf
.
constant
([
9
,
1
]),
initial_cache
=
self
.
cache
)
initial_ids
=
tf
.
constant
([
9
,
1
]),
initial_cache
=
self
.
cache
)
self
.
assertAllEqual
(
[[
9
,
1
,
2
,
2
,
2
],
[
1
,
1
,
1
,
2
,
2
]]
,
ids
)
self
.
assertAllEqual
(
greedy_expected
,
ids
)
@
parameterized
.
named_parameters
([
@
parameterized
.
named_parameters
([
(
'padded_decode_true'
,
True
),
(
'padded_decode_true'
,
True
),
...
@@ -72,12 +74,104 @@ class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -72,12 +74,104 @@ class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
vocab_size
=
3
,
vocab_size
=
3
,
max_decode_length
=
4
,
max_decode_length
=
4
,
eos_id
=
10
,
eos_id
=
10
,
sample_temperature
=
tf
.
constant
(
0.1
),
sample_temperature
=
tf
.
constant
(
1.0
),
top_k
=
tf
.
constant
(
3
),
top_k
=
tf
.
constant
(
3
),
padded_decode
=
padded_decode
)
padded_decode
=
padded_decode
,
enable_greedy
=
False
)
tf
.
random
.
set_seed
(
1
)
ids
,
_
=
top_k_obj
.
generate
(
ids
,
_
=
top_k_obj
.
generate
(
initial_ids
=
tf
.
constant
([
9
,
1
]),
initial_cache
=
self
.
cache
)
initial_ids
=
tf
.
constant
([
9
,
1
]),
initial_cache
=
self
.
cache
)
self
.
assertAllEqual
([
2
,
5
],
ids
.
shape
)
top_k_expected
=
tf
.
constant
([[
9
,
1
,
0
,
2
,
2
],
[
1
,
0
,
1
,
1
,
0
]])
self
.
assertAllEqual
(
top_k_expected
,
ids
)
@
parameterized
.
named_parameters
([
(
'padded_decode_true'
,
True
),
(
'padded_decode_false'
,
False
),
])
def
test_topp
(
self
,
padded_decode
):
top_p_obj
=
sampling_module
.
SamplingModule
(
length_normalization_fn
=
length_norm
,
dtype
=
tf
.
float32
,
symbols_to_logits_fn
=
self
.
_get_test_symbols_to_logits_fn
(),
vocab_size
=
3
,
max_decode_length
=
4
,
eos_id
=
10
,
sample_temperature
=
tf
.
constant
(
1.0
),
top_p
=
tf
.
constant
(
0.9
),
padded_decode
=
padded_decode
,
enable_greedy
=
False
)
tf
.
random
.
set_seed
(
1
)
ids
,
_
=
top_p_obj
.
generate
(
initial_ids
=
tf
.
constant
([
9
,
1
]),
initial_cache
=
self
.
cache
)
top_p_expected
=
tf
.
constant
([[
9
,
1
,
0
,
2
,
2
],
[
1
,
0
,
1
,
2
,
0
]])
self
.
assertAllEqual
(
top_p_expected
,
ids
)
@
parameterized
.
named_parameters
([
(
'padded_decode_true'
,
True
),
(
'padded_decode_false'
,
False
),
])
def
test_sampling_equivalent_greedy
(
self
,
padded_decode
):
# Ensure that p=0.0 with no sample temperature is same as greedy.
top_p_obj
=
sampling_module
.
SamplingModule
(
length_normalization_fn
=
length_norm
,
dtype
=
tf
.
float32
,
symbols_to_logits_fn
=
self
.
_get_test_symbols_to_logits_fn
(),
vocab_size
=
3
,
max_decode_length
=
4
,
eos_id
=
10
,
sample_temperature
=
0.0
,
top_p
=
tf
.
constant
(
0.0
),
padded_decode
=
padded_decode
,
enable_greedy
=
False
)
ids
,
_
=
top_p_obj
.
generate
(
initial_ids
=
tf
.
constant
([
9
,
1
]),
initial_cache
=
self
.
cache
)
self
.
assertAllEqual
(
greedy_expected
,
ids
)
# Ensure that k=1 with no sample temperature is same as greedy.
top_k_obj
=
sampling_module
.
SamplingModule
(
length_normalization_fn
=
length_norm
,
dtype
=
tf
.
float32
,
symbols_to_logits_fn
=
self
.
_get_test_symbols_to_logits_fn
(),
vocab_size
=
3
,
max_decode_length
=
4
,
eos_id
=
10
,
sample_temperature
=
0.0
,
top_k
=
tf
.
constant
(
1
),
padded_decode
=
padded_decode
,
enable_greedy
=
False
)
ids
,
_
=
top_k_obj
.
generate
(
initial_ids
=
tf
.
constant
([
9
,
1
]),
initial_cache
=
self
.
cache
)
# Ensure that low sample temperature results in Sharp Distribution (greedy).
low_temperature_obj
=
sampling_module
.
SamplingModule
(
length_normalization_fn
=
length_norm
,
dtype
=
tf
.
float32
,
symbols_to_logits_fn
=
self
.
_get_test_symbols_to_logits_fn
(),
vocab_size
=
3
,
max_decode_length
=
4
,
eos_id
=
10
,
sample_temperature
=
0.0001
,
padded_decode
=
padded_decode
)
ids
,
_
=
low_temperature_obj
.
generate
(
initial_ids
=
tf
.
constant
([
9
,
1
]),
initial_cache
=
self
.
cache
)
self
.
assertAllEqual
(
greedy_expected
,
ids
)
# Ensure that high sample temperature results in Flat Distribution (random).
high_temperature_obj
=
sampling_module
.
SamplingModule
(
length_normalization_fn
=
length_norm
,
dtype
=
tf
.
float32
,
symbols_to_logits_fn
=
self
.
_get_test_symbols_to_logits_fn
(),
vocab_size
=
3
,
max_decode_length
=
4
,
eos_id
=
10
,
sample_temperature
=
10.0
,
padded_decode
=
padded_decode
,
enable_greedy
=
False
)
tf
.
random
.
set_seed
(
1
)
ids
,
_
=
high_temperature_obj
.
generate
(
initial_ids
=
tf
.
constant
([
9
,
1
]),
initial_cache
=
self
.
cache
)
expected
=
tf
.
constant
([[
9
,
0
,
0
,
2
,
2
],
[
1
,
0
,
0
,
0
,
0
]])
self
.
assertAllEqual
(
expected
,
ids
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment