Unverified commit 99c8226b, authored Apr 22, 2022 by Joao Gante, committed by GitHub on Apr 22, 2022

TF: XLA repetition penalty (#16879)

parent ec81c11a
Changes: 2 changed files with 43 additions and 18 deletions (+43 -18)

src/transformers/generation_tf_logits_process.py   +23 -12
tests/generation/test_generation_tf_logits_process.py   +20 -6
src/transformers/generation_tf_logits_process.py
@@ -241,18 +241,29 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
         self.penalty = penalty
 
-    def _create_score_penalties(self, input_ids, logits):
-        # create logit penalties for already seen input_ids
-        token_penalties = np.ones(logits.shape)
-        prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
-        for i, prev_input_id in enumerate(prev_input_ids):
-            logit_penalized = logits[i].numpy()[prev_input_id]
-            logit_penalties = np.zeros(logit_penalized.shape)
-            # if previous logit score is < 0 then multiply repetition penalty else divide
-            logit_penalties[logit_penalized < 0] = self.penalty
-            logit_penalties[logit_penalized > 0] = 1 / self.penalty
-            np.put(token_penalties[i], prev_input_id, logit_penalties)
-        return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
+    def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
+        # We want to populate the penalties in the positions of `input_ids`. Since XLA can't handle shapes unknown
+        # before runtime, `tf.unique` can't be used. Therefore, we may have redundant updates, when a given row has
+        # the same token multiple times.
+
+        # Gathers the penalties to apply
+        logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
+        logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties)
+        logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties)
+
+        # Scatters the penalties
+        token_penalties = tf.ones(logits.shape)
+        indexable_prev_input_ids = tf.concat(
+            (
+                tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1),
+                tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
+            ),
+            axis=1,
+        )
+        token_penalties = tf.tensor_scatter_nd_update(
+            token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1])
+        )
+        return token_penalties
 
     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
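For illustration (not part of the commit), here is a minimal standalone sketch of the gather/scatter pattern the new _create_score_penalties uses, on a made-up 2x5 logits tensor; the variable names and values below are hypothetical:

import tensorflow as tf

penalty = 2.0
input_ids = tf.constant([[0, 1], [3, 3]])  # hypothetical already-seen tokens; note the repeated token in row 1
logits = tf.constant([[0.8, -0.5, 0.1, 0.2, 0.3],
                      [0.4, 0.1, 0.2, -0.3, 0.0]])

# gather the scores at the seen positions and turn them into multiplicative penalties
gathered = tf.gather(logits, input_ids, axis=1, batch_dims=1)   # shape (2, 2)
penalties = tf.where(gathered > 0, 1 / penalty, gathered)
penalties = tf.where(penalties < 0, penalty, penalties)

# build (batch, token) index pairs; a repeated token just produces a redundant update with the same value
batch_idx = tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1])
indices = tf.stack([batch_idx, tf.reshape(input_ids, [-1])], axis=1)
token_penalties = tf.tensor_scatter_nd_update(
    tf.ones(logits.shape), indices, tf.reshape(penalties, [-1])
)
print(token_penalties)  # ones everywhere except 1/penalty or penalty at the seen token positions

Because no tf.unique call is needed, every shape is known at trace time and the function can be compiled with XLA.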
tests/generation/test_generation_tf_logits_process.py
@@ -101,7 +101,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
         self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
         self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
 
-    def test_repetition_penalty_dist_process(self):
+    def _get_repetition_penalty_inputs(self):
         vocab_size = 10
         cur_len = 2
 
@@ -114,17 +114,31 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores = tf.where(mask, -1 / vocab_size, scores)
         mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
         scores = tf.where(mask, 4 / vocab_size, scores)
+        return vocab_size, cur_len, input_ids, scores
 
-        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
-
-        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
-
-        # check that values were correctly changed
+    def _check_repetition_penalty_outputs(self, scores, vocab_size):
+        # check that values were correctly changed (negative scores for used tokens should increase, others
+        # should decrease)
         self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
         self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
         self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change
 
         self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
         self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
         self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change
 
+    def test_repetition_penalty_dist_process(self):
+        vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
+
+        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
+        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
+        self._check_repetition_penalty_outputs(scores, vocab_size)
+
+    def test_repetition_penalty_dist_process_xla(self):
+        vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
+
+        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
+        rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)  # added line wrt non-XLA test
+        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
+        self._check_repetition_penalty_outputs(scores, vocab_size)
+
     def test_top_k_dist_warper(self):
         input_ids = None
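As a hedged usage sketch (not from the commit), the XLA path added by the test above can be exercised end to end roughly as follows; it assumes the module path used at this commit (transformers.generation_tf_logits_process) and uses toy inputs of my own choosing:

import tensorflow as tf
from transformers.generation_tf_logits_process import TFRepetitionPenaltyLogitsProcessor

vocab_size, cur_len = 10, 2
input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)   # hypothetical already-generated tokens
scores = tf.fill((2, vocab_size), 1 / vocab_size)           # dummy uniform logits

rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
xla_proc = tf.function(rep_penalty_proc, jit_compile=True)  # XLA-compiled variant, as in the new test

eager_out = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
xla_out = xla_proc(input_ids, tf.identity(scores), cur_len)
tf.debugging.assert_near(eager_out, xla_out)                # both execution paths should agree

The new tests check each path against hand-computed expected values; this sketch only checks that the eager and XLA executions produce the same penalized scores.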