Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
99c8226b
Unverified
Commit
99c8226b
authored
Apr 22, 2022
by
Joao Gante
Committed by
GitHub
Apr 22, 2022
Browse files
TF: XLA repetition penalty (#16879)
parent
ec81c11a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
18 deletions
+43
-18
src/transformers/generation_tf_logits_process.py
src/transformers/generation_tf_logits_process.py
+23
-12
tests/generation/test_generation_tf_logits_process.py
tests/generation/test_generation_tf_logits_process.py
+20
-6
No files found.
src/transformers/generation_tf_logits_process.py
View file @
99c8226b
...
@@ -241,18 +241,29 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
...
@@ -241,18 +241,29 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
self
.
penalty
=
penalty
self
.
penalty
=
penalty
def
_create_score_penalties
(
self
,
input_ids
,
logits
):
def
_create_score_penalties
(
self
,
input_ids
:
tf
.
Tensor
,
logits
:
tf
.
Tensor
)
->
tf
.
Tensor
:
# create logit penalties for already seen input_ids
# We want to populate the penalties in the positions of `input_ids`. Since XLA can't handle shapes unknown
token_penalties
=
np
.
ones
(
logits
.
shape
)
# before runtime, `tf.unique` can't be used. Therefore, we may have redundant updates, when a given row has
prev_input_ids
=
[
np
.
unique
(
input_id
)
for
input_id
in
input_ids
.
numpy
()]
# the same token multiple times.
for
i
,
prev_input_id
in
enumerate
(
prev_input_ids
):
logit_penalized
=
logits
[
i
].
numpy
()[
prev_input_id
]
# Gathers the penalties to apply
logit_penalties
=
np
.
zeros
(
logit_penalized
.
shape
)
logit_penalties
=
tf
.
gather
(
logits
,
input_ids
,
axis
=
1
,
batch_dims
=
1
)
# if previous logit score is < 0 then multiply repetition penalty else divide
logit_penalties
=
tf
.
where
(
logit_penalties
>
0
,
1
/
self
.
penalty
,
logit_penalties
)
logit_penalties
[
logit_penalized
<
0
]
=
self
.
penalty
logit_penalties
=
tf
.
where
(
logit_penalties
<
0
,
self
.
penalty
,
logit_penalties
)
logit_penalties
[
logit_penalized
>
0
]
=
1
/
self
.
penalty
np
.
put
(
token_penalties
[
i
],
prev_input_id
,
logit_penalties
)
# Scatters the penalties
return
tf
.
convert_to_tensor
(
token_penalties
,
dtype
=
tf
.
float32
)
token_penalties
=
tf
.
ones
(
logits
.
shape
)
indexable_prev_input_ids
=
tf
.
concat
(
(
tf
.
expand_dims
(
tf
.
repeat
(
tf
.
range
(
input_ids
.
shape
[
0
]),
input_ids
.
shape
[
1
]),
axis
=-
1
),
tf
.
expand_dims
(
tf
.
reshape
(
input_ids
,
[
-
1
]),
axis
=-
1
),
),
axis
=
1
,
)
token_penalties
=
tf
.
tensor_scatter_nd_update
(
token_penalties
,
indices
=
indexable_prev_input_ids
,
updates
=
tf
.
reshape
(
logit_penalties
,
[
-
1
])
)
return
token_penalties
def
__call__
(
self
,
input_ids
:
tf
.
Tensor
,
scores
:
tf
.
Tensor
,
cur_len
:
int
)
->
tf
.
Tensor
:
def
__call__
(
self
,
input_ids
:
tf
.
Tensor
,
scores
:
tf
.
Tensor
,
cur_len
:
int
)
->
tf
.
Tensor
:
score_penalties
=
self
.
_create_score_penalties
(
input_ids
[:,
:
cur_len
],
scores
)
score_penalties
=
self
.
_create_score_penalties
(
input_ids
[:,
:
cur_len
],
scores
)
...
...
tests/generation/test_generation_tf_logits_process.py
View file @
99c8226b
...
@@ -101,7 +101,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
...
@@ -101,7 +101,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
self
.
assertGreater
(
tf
.
math
.
reduce_max
(
probs
[
1
,
:]),
tf
.
math
.
reduce_max
(
warped_prob_smooth
[
1
,
:]))
self
.
assertGreater
(
tf
.
math
.
reduce_max
(
probs
[
1
,
:]),
tf
.
math
.
reduce_max
(
warped_prob_smooth
[
1
,
:]))
self
.
assertLess
(
tf
.
math
.
reduce_min
(
probs
[
1
,
:]),
tf
.
math
.
reduce_min
(
warped_prob_smooth
[
1
,
:]))
self
.
assertLess
(
tf
.
math
.
reduce_min
(
probs
[
1
,
:]),
tf
.
math
.
reduce_min
(
warped_prob_smooth
[
1
,
:]))
def
tes
t_repetition_penalty_
dist_proces
s
(
self
):
def
_ge
t_repetition_penalty_
input
s
(
self
):
vocab_size
=
10
vocab_size
=
10
cur_len
=
2
cur_len
=
2
...
@@ -114,17 +114,31 @@ class TFLogitsProcessorTest(unittest.TestCase):
...
@@ -114,17 +114,31 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores
=
tf
.
where
(
mask
,
-
1
/
vocab_size
,
scores
)
scores
=
tf
.
where
(
mask
,
-
1
/
vocab_size
,
scores
)
mask
=
tf
.
cast
(
tf
.
constant
([
10
*
[
0
],
5
*
[
0
]
+
[
1
]
+
4
*
[
0
]]),
tf
.
bool
)
mask
=
tf
.
cast
(
tf
.
constant
([
10
*
[
0
],
5
*
[
0
]
+
[
1
]
+
4
*
[
0
]]),
tf
.
bool
)
scores
=
tf
.
where
(
mask
,
4
/
vocab_size
,
scores
)
scores
=
tf
.
where
(
mask
,
4
/
vocab_size
,
scores
)
return
vocab_size
,
cur_len
,
input_ids
,
scores
rep_penalty_proc
=
TFRepetitionPenaltyLogitsProcessor
(
penalty
=
2.0
)
def
_check_repetition_penalty_outputs
(
self
,
scores
,
vocab_size
):
# check that values were correctly changed (negative scores for used tokens should increase, others
scores
=
rep_penalty_proc
(
input_ids
,
tf
.
identity
(
scores
),
cur_len
)
# should decrease)
# check that values were correctly changed
self
.
assertAlmostEqual
(
scores
[
0
,
0
].
numpy
(),
-
(
1
/
vocab_size
)
*
2
)
self
.
assertAlmostEqual
(
scores
[
0
,
0
].
numpy
(),
-
(
1
/
vocab_size
)
*
2
)
self
.
assertAlmostEqual
(
scores
[
0
,
1
].
numpy
(),
(
1
/
vocab_size
)
/
2
)
self
.
assertAlmostEqual
(
scores
[
0
,
1
].
numpy
(),
(
1
/
vocab_size
)
/
2
)
self
.
assertAlmostEqual
(
scores
[
0
,
2
].
numpy
(),
(
1
/
vocab_size
))
# unused tokens should see no change
self
.
assertAlmostEqual
(
scores
[
1
,
0
].
numpy
(),
(
1
/
vocab_size
)
/
2
)
self
.
assertAlmostEqual
(
scores
[
1
,
0
].
numpy
(),
(
1
/
vocab_size
)
/
2
)
self
.
assertAlmostEqual
(
scores
[
1
,
5
].
numpy
(),
(
4
/
vocab_size
)
/
2
)
self
.
assertAlmostEqual
(
scores
[
1
,
5
].
numpy
(),
(
4
/
vocab_size
)
/
2
)
self
.
assertAlmostEqual
(
scores
[
0
,
2
].
numpy
(),
(
1
/
vocab_size
))
# unused tokens should see no change
def
test_repetition_penalty_dist_process
(
self
):
vocab_size
,
cur_len
,
input_ids
,
scores
=
self
.
_get_repetition_penalty_inputs
()
rep_penalty_proc
=
TFRepetitionPenaltyLogitsProcessor
(
penalty
=
2.0
)
scores
=
rep_penalty_proc
(
input_ids
,
tf
.
identity
(
scores
),
cur_len
)
self
.
_check_repetition_penalty_outputs
(
scores
,
vocab_size
)
def
test_repetition_penalty_dist_process_xla
(
self
):
vocab_size
,
cur_len
,
input_ids
,
scores
=
self
.
_get_repetition_penalty_inputs
()
rep_penalty_proc
=
TFRepetitionPenaltyLogitsProcessor
(
penalty
=
2.0
)
rep_penalty_proc
=
tf
.
function
(
rep_penalty_proc
,
jit_compile
=
True
)
# added line wrt non-XLA test
scores
=
rep_penalty_proc
(
input_ids
,
tf
.
identity
(
scores
),
cur_len
)
self
.
_check_repetition_penalty_outputs
(
scores
,
vocab_size
)
def
test_top_k_dist_warper
(
self
):
def
test_top_k_dist_warper
(
self
):
input_ids
=
None
input_ids
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment