chenpangpang / transformers · Commits · fb0ae129

Commit fb0ae129 (unverified), authored Apr 29, 2022 by Joao Gante, committed by GitHub on Apr 29, 2022
Parent: 57e6464a

TF: XLA bad words logits processor and list of processors (#16974)
Showing 3 changed files with 116 additions and 75 deletions (+116 -75):

  src/transformers/generation_tf_logits_process.py        +80  -55
  src/transformers/generation_tf_utils.py                   +4   -4
  tests/generation/test_generation_tf_logits_process.py    +32  -16
src/transformers/generation_tf_logits_process.py  (+80 -55)
@@ -14,7 +14,7 @@
 # limitations under the License.
 import inspect
-from typing import List
+from typing import List, Tuple

 import numpy as np
 import tensorflow as tf
@@ -38,7 +38,10 @@ TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
             [What are input IDs?](../glossary#input-ids)
         scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
             Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
-            search or log softmax for each vocabulary token when using beam search
+            search or log softmax for each vocabulary token when using beam search.
+        cur_len (`int`):
+            The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
+            is the maximum length generate can produce, and we need to know which of its tokens are valid.
         kwargs:
             Additional logits processor specific kwargs.
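A note on the new `cur_len` argument: in the TF implementation, `input_ids` is pre-allocated to the maximum length `generate` can produce, so a processor must slice out the valid prefix itself. A minimal eager sketch with toy values (not code from the repo):

import tensorflow as tf

# input_ids is allocated up-front to the maximum generate length;
# positions >= cur_len are padding that has not been generated yet
input_ids = tf.constant([[5, 3, 9, 0, 0, 0, 0, 0]])  # (batch_size=1, max_length=8)
cur_len = 3  # only the first 3 tokens are valid

valid_prefix = input_ids[:, :cur_len]  # what a processor should actually inspect
print(valid_prefix.numpy())  # [[5 3 9]]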
@@ -51,7 +54,7 @@ class TFLogitsProcessor:
     """Abstract base class for all logit processors that can be applied during generation."""

     @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         """TF method for processing logits."""
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -62,7 +65,7 @@ class TFLogitsWarper:
     """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

     @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         """TF method for warping logits."""
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -77,18 +80,18 @@ class TFLogitsProcessorList(list):
     """

     @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
         for processor in self:
             function_args = inspect.signature(processor.__call__).parameters
-            if len(function_args) > 2:
+            if len(function_args) > 3:
                 if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                     raise ValueError(
                         f"Make sure that all the required parameters: {list(function_args.keys())} for "
                         f"{processor.__class__} are passed to the logits processor."
                     )
-                scores = processor(input_ids, scores, **kwargs)
+                scores = processor(input_ids, scores, cur_len, **kwargs)
             else:
-                scores = processor(input_ids, scores)
+                scores = processor(input_ids, scores, cur_len)
         return scores
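To make the dispatch above concrete: `inspect.signature` on a bound `__call__` excludes `self`, so a processor with the standard `(input_ids, scores, cur_len)` interface reports 3 parameters and takes the `else` branch, while anything with extra parameters must receive them through `**kwargs`. A small sketch with a hypothetical processor (`_ToyProcessor` is not part of the library):

import inspect

class _ToyProcessor:  # hypothetical stand-in for a TFLogitsProcessor subclass
    def __call__(self, input_ids, scores, cur_len):
        return scores

function_args = inspect.signature(_ToyProcessor().__call__).parameters
print(len(function_args))  # 3 -> called as processor(input_ids, scores, cur_len)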
@@ -107,7 +110,7 @@ class TFTemperatureLogitsWarper(TFLogitsWarper):
         self.temperature = temperature

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         scores = scores / self.temperature
         return scores
@@ -133,7 +136,7 @@ class TFTopKLogitsWarper(TFLogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1])  # Safety check
         # Boolean mask containing all tokens with a probability less than the last token of the top-k
         indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
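The `indices_to_remove` line relies on `tf.math.top_k` returning values sorted in descending order, so `[..., -1:]` is the k-th largest score per row, kept with a trailing axis so the `<` comparison broadcasts. A toy eager sketch (the filter step mirrors, but is not copied from, the warper body):

import tensorflow as tf

scores = tf.constant([[1.0, 5.0, 3.0, 2.0]])
top_k = 2
kth_value = tf.math.top_k(scores, k=top_k)[0][..., -1:]  # [[3.0]]
indices_to_remove = scores < kth_value                   # [[True, False, False, True]]
print(tf.where(indices_to_remove, -float("inf"), scores).numpy())  # [[-inf 5. 3. -inf]]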
@@ -163,7 +166,7 @@ class TFTopPLogitsWarper(TFLogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])

         mask_scores = tf.fill(scores.shape, self.filter_value)
@@ -305,58 +308,75 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
                 f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
             )

-        self.bad_words_ids = bad_words_ids
-
-    def calc_banned_bad_words_ids(self, prev_input_ids):
-        banned_tokens = []
-
-        def _tokens_match(prev_tokens, tokens):
-            if len(tokens) == 0:
-                # if bad word tokens is just one token always ban it
-                return True
-            if len(tokens) > len(prev_tokens):
-                # if bad word tokens are longer than prev tokens they can't be equal
-                return False
-
-            if prev_tokens[-len(tokens) :] == tokens:
-                # if tokens match
-                return True
-            else:
-                return False
-
-        for prev_input_ids_slice in prev_input_ids:
-            banned_tokens_slice = []
-            for banned_token_seq in self.bad_words_ids:
-                assert (
-                    len(banned_token_seq) > 0
-                ), f"Banned words token sequences {self.bad_words_ids} cannot have an empty list"
-
-                if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
-                    # if tokens do not match continue
-                    continue
-
-                banned_tokens_slice.append(banned_token_seq[-1])
-
-            banned_tokens.append(banned_tokens_slice)
-
-        return banned_tokens
+        # stores the information about bad words in three tensors:
+        # 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons
+        self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
+        # 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
+        bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
+        if any([word_len == 0 for word_len in bad_word_seqs_len]):
+            raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
+        self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
+        # 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
+        self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
+
+    def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
+        def _tokens_match(bad_word_seq_number):
+            def _len_one():
+                # If the bad sequence only has one token, always mask it
+                return tf.cond(
+                    tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
+                    lambda: tf.ones((), dtype=tf.bool),
+                    _len_greater_than_cur_len,
+                )
+
+            def _len_greater_than_cur_len():
+                # Otherwise, if the bad sequence is longer than the current length they can't ever match
+                return tf.cond(
+                    tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], row_input_ids.shape[0]),
+                    lambda: tf.zeros((), dtype=tf.bool),
+                    _match_found,
+                )
+
+            def _match_found():
+                # Finally, runs the actual comparison. Can only be called if the previous comparisons do not yield
+                # an answer (otherwise we get indexing exceptions)
+                compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
+                return tf.cond(
+                    tf.math.reduce_all(
+                        tf.math.equal(
+                            row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
+                        )
+                    ),
+                    lambda: tf.ones((), dtype=tf.bool),
+                    lambda: tf.zeros((), dtype=tf.bool),
+                )
+
+            match = _len_one()
+            return match
+
+        # Compares the current row against all bad word sequences, obtaining a mask with the matches.
+        match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
+        row_banned_tokens = self.seq_forbidden_tokens[match_mask]
+        return row_banned_tokens

     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
-
-        vocab_size = scores.shape[-1]
-
-        # calculate a list of banned tokens according to bad words
-        banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])
-
-        banned_tokens_indices_mask = []
-        for banned_tokens_slice in banned_tokens:
-            banned_tokens_indices_mask.append(
-                [True if token in banned_tokens_slice else False for token in range(vocab_size)]
-            )
-
-        scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
-
+        # We want to mask some banned tokens, at a score level. Since the banned tokens depend on the previous
+        # `input_ids`, they may have a different length for each row, and they may even be empty for some rows.
+        # To remain simple and XLA-compatible, we work on a per-row fashion.
+        # TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
+        # a frequent choke point. (make `cur_len` a tensor?)
+        def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
+            row_input_ids, row_score = row_inputs
+            banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
+            banned_tokens_mask = tf.scatter_nd(
+                indices=tf.expand_dims(banned_tokens, axis=-1),
+                updates=tf.ones_like(banned_tokens, dtype=tf.bool),
+                shape=row_score.shape,
+            )
+            row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
+            return row_score
+
+        scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
         return scores
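The key XLA-friendly trick in the new `__call__` is replacing the Python list comprehension over the vocabulary with `tf.scatter_nd`, which turns a variable-length vector of banned token ids into a fixed-shape boolean mask. A minimal eager sketch with toy values:

import tensorflow as tf

vocab_size = 6
banned_tokens = tf.constant([1, 4])  # token ids to forbid for one row
banned_tokens_mask = tf.scatter_nd(
    indices=tf.expand_dims(banned_tokens, axis=-1),      # shape (2, 1)
    updates=tf.ones_like(banned_tokens, dtype=tf.bool),  # one True per banned id
    shape=(vocab_size,),
)
row_score = tf.where(banned_tokens_mask, -float("inf"), tf.zeros((vocab_size,)))
print(row_score.numpy())  # [  0. -inf   0.   0. -inf   0.]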
@@ -401,6 +421,11 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        # TODO (joao): enable XLA on this logits processor. See discussion and attempts in
+        # https://github.com/huggingface/transformers/pull/16974
+        if not tf.executing_eagerly():
+            raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")
+
         batch_size, vocab_size = scores.shape
         banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
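The new guard is worth spelling out: `tf.executing_eagerly()` returns `False` while `tf.function` traces, so an XLA-compiled caller fails at trace time with a clear error instead of silently misbehaving. A standalone sketch of the behavior (toy function, not from the repo):

import tensorflow as tf

def processor_step(scores):
    if not tf.executing_eagerly():  # False during tf.function tracing
        raise NotImplementedError("only implemented for eager execution.")
    return scores

processor_step(tf.zeros((2, 4)))  # fine in eager mode

compiled = tf.function(processor_step, jit_compile=True)
try:
    compiled(tf.zeros((2, 4)))  # the Python guard runs while tracing
except NotImplementedError as e:
    print(e)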
src/transformers/generation_tf_utils.py  (+4 -4)
@@ -2030,7 +2030,7 @@ class TFGenerationMixin:
             if not use_xla:
                 input_ids = tf.reshape(generated.concat(), (-1, batch_size))
                 input_ids = tf.transpose(input_ids[: current_pos[0]])
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=current_pos[0])
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0])

             # argmax
             next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2301,8 +2301,8 @@ class TFGenerationMixin:
             if not use_xla:
                 input_ids = tf.reshape(generated.concat(), (-1, batch_size))
                 input_ids = tf.transpose(input_ids[:cur_len])
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
-            next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
+            next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)

             # sample
             if seed is not None:
@@ -2726,7 +2726,7 @@ class TFGenerationMixin:
             # add new logprobs to existing running logprobs scores.
             log_probs = tf.nn.log_softmax(logits)
             log_probs = logits_processor(
-                flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len=cur_len
+                flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
             )
             log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
             log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
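All three hunks in this file make the same change: `cur_len` is now passed positionally, matching the updated `TFLogitsProcessorList.__call__(input_ids, scores, cur_len, **kwargs)` signature. For the non-XLA branch above, a toy sketch of how `generated.concat()` plus reshape/transpose recovers a `(batch_size, cur_len)` tensor, assuming (as the diff suggests, but does not show) that `generated` is a `tf.TensorArray` holding one token per row per position:

import tensorflow as tf

batch_size = 2
generated = tf.TensorArray(dtype=tf.int32, size=3)
generated = generated.write(0, [7, 8])  # tokens at position 0 for both rows
generated = generated.write(1, [5, 6])  # tokens at position 1
generated = generated.write(2, [0, 0])  # not generated yet
current_pos = tf.constant([2])

input_ids = tf.reshape(generated.concat(), (-1, batch_size))  # (max_len, batch)
input_ids = tf.transpose(input_ids[: current_pos[0]])         # (batch, cur_len)
print(input_ids.numpy())  # [[7 5] [8 6]]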
tests/generation/test_generation_tf_logits_process.py  (+32 -16)
@@ -75,6 +75,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
     @parameterized.expand([(False,), (True,)])
     def test_temperature_dist_warper(self, use_xla):
         input_ids = None
+        cur_len = None
         length = 20

         scores = self._get_uniform_logits(batch_size=2, length=length)
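The pattern added across this test file: every warper/processor test is parameterized over eager and XLA execution, and in the XLA case the callable is wrapped in `tf.function(..., jit_compile=True)`. A condensed, self-contained version of that pattern (using the library's `TFTemperatureLogitsWarper` and its module path at this commit; toy shapes):

import tensorflow as tf
from transformers.generation_tf_logits_process import TFTemperatureLogitsWarper

def run(use_xla: bool):
    input_ids, cur_len = None, None  # unused by this warper, but part of the shared signature
    scores = tf.random.normal((2, 10))
    warper = TFTemperatureLogitsWarper(temperature=0.5)
    if use_xla:
        warper = tf.function(warper, jit_compile=True)
    return warper(input_ids, tf.identity(scores), cur_len)

print(run(use_xla=True).shape)  # (2, 10)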
@@ -94,8 +95,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
             temp_dist_warper_sharper = tf.function(temp_dist_warper_sharper, jit_compile=True)
             temp_dist_warper_smoother = tf.function(temp_dist_warper_smoother, jit_compile=True)

-        warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
-        warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
+        warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores), cur_len), axis=-1)
+        warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores), cur_len), axis=-1)

         # uniform distribution stays uniform
         tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
@@ -142,6 +143,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
     @parameterized.expand([(False,), (True,)])
     def test_top_k_dist_warper(self, use_xla):
         input_ids = None
+        cur_len = None
         vocab_size = 10
         batch_size = 2
@@ -153,7 +155,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
         if use_xla:
             top_k_warp = tf.function(top_k_warp, jit_compile=True)
-        scores = top_k_warp(input_ids, ramp_logits)
+        scores = top_k_warp(input_ids, ramp_logits, cur_len)

         # check that correct tokens are filtered
         self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
@@ -167,12 +169,12 @@ class TFLogitsProcessorTest(unittest.TestCase):
         if use_xla:
             top_k_warp_safety_check = tf.function(top_k_warp_safety_check, jit_compile=True)

-        scores = top_k_warp_safety_check(input_ids, logits)
+        scores = top_k_warp_safety_check(input_ids, logits, cur_len)
         # uniform dist is not changed
         self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])

         ramp_logits = np.broadcast_to(np.arange(length, dtype=np.float32), (batch_size, length)).copy()
-        scores = top_k_warp_safety_check(input_ids, ramp_logits)
+        scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len)
         # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
         self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
@@ -180,6 +182,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
     @parameterized.expand([(False,), (True,)])
     def test_top_p_dist_warper(self, use_xla):
         input_ids = None
+        cur_len = None
         vocab_size = 10
         batch_size = 2
@@ -189,7 +192,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
         top_p_warp = TFTopPLogitsWarper(0.7)
         if use_xla:
             top_p_warp = tf.function(top_p_warp, jit_compile=True)
-        filtered_dist = tf.exp(top_p_warp(input_ids, dist))
+        filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))

         # dist should be filtered to keep min num values so that sum is >= 0.7
         # exp (-inf) => 0
@@ -208,7 +211,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
         top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
         if use_xla:
             top_p_warp = tf.function(top_p_warp, jit_compile=True)
-        filtered_dist = top_p_warp(input_ids, ramp_logits)
+        filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len)

         # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
         # 2.
@@ -242,7 +245,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
             tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]]
         )

-    def test_no_bad_words_dist_processor(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_no_bad_words_dist_processor(self, use_xla):
         vocab_size = 5
         batch_size = 2
         eos_token_id = 4
@@ -255,6 +259,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores = self._get_uniform_logits(batch_size, vocab_size)

         no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
+        if use_xla:
+            no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)

         filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)
@@ -322,7 +328,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores = logits_processor(input_ids, scores, cur_len)
         self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))

-    def test_processor_list(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_processor_list(self, use_xla):
+        # TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it gets compatible with XLA
         batch_size = 4
         cur_len = 10
         vocab_size = 15
@@ -341,16 +349,24 @@ class TFLogitsProcessorTest(unittest.TestCase):
         rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
         top_k_warp = TFTopKLogitsWarper(3)
         top_p_warp = TFTopPLogitsWarper(0.8)
-        no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
+        # no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
         no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
+        if use_xla:
+            min_dist_proc = tf.function(min_dist_proc, jit_compile=True)
+            temp_dist_warp = tf.function(temp_dist_warp, jit_compile=True)
+            rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)
+            top_k_warp = tf.function(top_k_warp, jit_compile=True)
+            top_p_warp = tf.function(top_p_warp, jit_compile=True)
+            # no_repeat_proc = tf.function(no_repeat_proc, jit_compile=True)
+            no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)

         # no processor list
         scores = min_dist_proc(input_ids, scores, cur_len)
-        scores = temp_dist_warp(input_ids, scores)
+        scores = temp_dist_warp(input_ids, scores, cur_len)
         scores = rep_penalty_proc(input_ids, scores, cur_len)
-        scores = top_k_warp(input_ids, scores)
+        scores = top_k_warp(input_ids, scores, cur_len)
-        scores = top_p_warp(input_ids, scores)
+        scores = top_p_warp(input_ids, scores, cur_len)
-        scores = no_repeat_proc(input_ids, scores, cur_len)
+        # scores = no_repeat_proc(input_ids, scores, cur_len)
         scores = no_bad_words_dist_proc(input_ids, scores, cur_len)

         # with processor list
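The point of this test is that applying the processors one by one and through a `TFLogitsProcessorList` must produce identical scores (after mapping `-inf` to a finite value for comparison, as the diff below does). A cut-down sketch of that equivalence using two of the real classes, under this commit's module layout:

import tensorflow as tf
from transformers.generation_tf_logits_process import TFLogitsProcessorList, TFTopKLogitsWarper

scores = tf.random.normal((4, 15))
cur_len = 10
top_k_warp = TFTopKLogitsWarper(3)

direct = top_k_warp(None, tf.identity(scores), cur_len)
listed = TFLogitsProcessorList([top_k_warp])(None, tf.identity(scores), cur_len)

# replace -inf with a large negative number before comparing, as the test does
tf.debugging.assert_near(
    tf.where(tf.math.is_inf(direct), -1e9, direct),
    tf.where(tf.math.is_inf(listed), -1e9, listed),
    atol=1e-3,
)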
@@ -361,11 +377,11 @@ class TFLogitsProcessorTest(unittest.TestCase):
                 rep_penalty_proc,
                 top_k_warp,
                 top_p_warp,
-                no_repeat_proc,
+                # no_repeat_proc,
                 no_bad_words_dist_proc,
             ]
         )
-        scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
+        scores_comp = processor(input_ids, scores_comp, cur_len)

         # remove inf
         scores = tf.where(tf.math.is_inf(scores), -1e9, scores)