Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9331b379
Unverified
Commit
9331b379
authored
Apr 25, 2022
by
Joao Gante
Committed by
GitHub
Apr 25, 2022
Browse files
TF: XLA Logits Warpers (#16899)
Co-authored-by:
Matt
<
Rocketknight1@users.noreply.github.com
>
parent
809dac48
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
21 deletions
+26
-21
tests/generation/test_generation_tf_logits_process.py
tests/generation/test_generation_tf_logits_process.py
+26
-21
No files found.
tests/generation/test_generation_tf_logits_process.py
View file @
9331b379
...
...
@@ -72,7 +72,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores_before_min_length
=
min_dist_processor
(
input_ids
,
scores
,
cur_len
)
self
.
assertFalse
(
tf
.
math
.
reduce_any
(
tf
.
math
.
is_inf
(
scores_before_min_length
)).
numpy
())
def
test_temperature_dist_warper
(
self
):
@
parameterized
.
expand
([(
False
,),
(
True
,)])
def
test_temperature_dist_warper
(
self
,
use_xla
):
input_ids
=
None
length
=
20
...
...
@@ -89,6 +90,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
temp_dist_warper_sharper
=
TFTemperatureLogitsWarper
(
temperature
=
0.5
)
temp_dist_warper_smoother
=
TFTemperatureLogitsWarper
(
temperature
=
1.3
)
if
use_xla
:
temp_dist_warper_sharper
=
tf
.
function
(
temp_dist_warper_sharper
,
jit_compile
=
True
)
temp_dist_warper_smoother
=
tf
.
function
(
temp_dist_warper_smoother
,
jit_compile
=
True
)
warped_prob_sharp
=
tf
.
nn
.
softmax
(
temp_dist_warper_sharper
(
input_ids
,
tf
.
identity
(
scores
)),
axis
=-
1
)
warped_prob_smooth
=
tf
.
nn
.
softmax
(
temp_dist_warper_smoother
(
input_ids
,
tf
.
identity
(
scores
)),
axis
=-
1
)
...
...
@@ -105,7 +109,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
self
.
assertGreater
(
tf
.
math
.
reduce_max
(
probs
[
1
,
:]),
tf
.
math
.
reduce_max
(
warped_prob_smooth
[
1
,
:]))
self
.
assertLess
(
tf
.
math
.
reduce_min
(
probs
[
1
,
:]),
tf
.
math
.
reduce_min
(
warped_prob_smooth
[
1
,
:]))
def
_get_repetition_penalty_inputs
(
self
):
@
parameterized
.
expand
([(
False
,),
(
True
,)])
def
test_repetition_penalty_dist_process
(
self
,
use_xla
):
vocab_size
=
10
cur_len
=
2
...
...
@@ -118,9 +123,12 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores
=
tf
.
where
(
mask
,
-
1
/
vocab_size
,
scores
)
mask
=
tf
.
cast
(
tf
.
constant
([
10
*
[
0
],
5
*
[
0
]
+
[
1
]
+
4
*
[
0
]]),
tf
.
bool
)
scores
=
tf
.
where
(
mask
,
4
/
vocab_size
,
scores
)
return
vocab_size
,
cur_len
,
input_ids
,
scores
rep_penalty_proc
=
TFRepetitionPenaltyLogitsProcessor
(
penalty
=
2.0
)
if
use_xla
:
rep_penalty_proc
=
tf
.
function
(
rep_penalty_proc
,
jit_compile
=
True
)
scores
=
rep_penalty_proc
(
input_ids
,
tf
.
identity
(
scores
),
cur_len
)
def
_check_repetition_penalty_outputs
(
self
,
scores
,
vocab_size
):
# check that values were correctly changed (negative scores for used tokens should increase, others
# should decrease)
self
.
assertAlmostEqual
(
scores
[
0
,
0
].
numpy
(),
-
(
1
/
vocab_size
)
*
2
)
...
...
@@ -131,29 +139,19 @@ class TFLogitsProcessorTest(unittest.TestCase):
self
.
assertAlmostEqual
(
scores
[
1
,
5
].
numpy
(),
(
4
/
vocab_size
)
/
2
)
self
.
assertAlmostEqual
(
scores
[
0
,
2
].
numpy
(),
(
1
/
vocab_size
))
# unused tokens should see no change
def
test_repetition_penalty_dist_process
(
self
):
vocab_size
,
cur_len
,
input_ids
,
scores
=
self
.
_get_repetition_penalty_inputs
()
rep_penalty_proc
=
TFRepetitionPenaltyLogitsProcessor
(
penalty
=
2.0
)
scores
=
rep_penalty_proc
(
input_ids
,
tf
.
identity
(
scores
),
cur_len
)
self
.
_check_repetition_penalty_outputs
(
scores
,
vocab_size
)
def
test_repetition_penalty_dist_process_xla
(
self
):
vocab_size
,
cur_len
,
input_ids
,
scores
=
self
.
_get_repetition_penalty_inputs
()
rep_penalty_proc
=
TFRepetitionPenaltyLogitsProcessor
(
penalty
=
2.0
)
rep_penalty_proc
=
tf
.
function
(
rep_penalty_proc
,
jit_compile
=
True
)
# added line wrt non-XLA test
scores
=
rep_penalty_proc
(
input_ids
,
tf
.
identity
(
scores
),
cur_len
)
self
.
_check_repetition_penalty_outputs
(
scores
,
vocab_size
)
def
test_top_k_dist_warper
(
self
):
@
parameterized
.
expand
([(
False
,),
(
True
,)])
def
test_top_k_dist_warper
(
self
,
use_xla
):
input_ids
=
None
vocab_size
=
10
batch_size
=
2
# create ramp distribution
ramp_logits
=
np
.
broadcast_to
(
np
.
arange
(
vocab_size
)[
None
,
:]
,
(
batch_size
,
vocab_size
)).
copy
()
ramp_logits
=
np
.
broadcast_to
(
np
.
arange
(
vocab_size
,
dtype
=
np
.
float32
)
,
(
batch_size
,
vocab_size
)).
copy
()
ramp_logits
[
1
:,
:
vocab_size
//
2
]
=
ramp_logits
[
1
:,
:
vocab_size
//
2
]
+
vocab_size
top_k_warp
=
TFTopKLogitsWarper
(
3
)
if
use_xla
:
top_k_warp
=
tf
.
function
(
top_k_warp
,
jit_compile
=
True
)
scores
=
top_k_warp
(
input_ids
,
ramp_logits
)
...
...
@@ -166,18 +164,21 @@ class TFLogitsProcessorTest(unittest.TestCase):
logits
=
self
.
_get_uniform_logits
(
batch_size
=
batch_size
,
length
=
length
)
top_k_warp_safety_check
=
TFTopKLogitsWarper
(
top_k
=
1
,
filter_value
=
0.0
,
min_tokens_to_keep
=
3
)
if
use_xla
:
top_k_warp_safety_check
=
tf
.
function
(
top_k_warp_safety_check
,
jit_compile
=
True
)
scores
=
top_k_warp_safety_check
(
input_ids
,
logits
)
# uniform dist is not changed
self
.
assertListEqual
(
tf
.
math
.
reduce_sum
(
tf
.
where
(
scores
==
0.0
,
1
,
0
),
axis
=-
1
).
numpy
().
tolist
(),
[
0
,
0
])
ramp_logits
=
np
.
broadcast_to
(
np
.
arange
(
length
)[
None
,
:]
,
(
batch_size
,
length
)).
copy
()
ramp_logits
=
np
.
broadcast_to
(
np
.
arange
(
length
,
dtype
=
np
.
float32
)
,
(
batch_size
,
length
)).
copy
()
scores
=
top_k_warp_safety_check
(
input_ids
,
ramp_logits
)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self
.
assertListEqual
(
tf
.
math
.
reduce_sum
(
tf
.
where
(
scores
==
0.0
,
1
,
0
),
axis
=-
1
).
numpy
().
tolist
(),
[
2
,
2
])
def
test_top_p_dist_warper
(
self
):
@
parameterized
.
expand
([(
False
,),
(
True
,)])
def
test_top_p_dist_warper
(
self
,
use_xla
):
input_ids
=
None
vocab_size
=
10
batch_size
=
2
...
...
@@ -186,6 +187,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
dist
=
np
.
log
(
np
.
array
([[
0.3
,
0.1
,
0.1
,
0.5
],
[
0.15
,
0.3
,
0.3
,
0.25
]],
dtype
=
np
.
float32
))
top_p_warp
=
TFTopPLogitsWarper
(
0.7
)
if
use_xla
:
top_p_warp
=
tf
.
function
(
top_p_warp
,
jit_compile
=
True
)
filtered_dist
=
tf
.
exp
(
top_p_warp
(
input_ids
,
dist
))
# dist should be filtered to keep min num values so that sum is >= 0.7
...
...
@@ -203,6 +206,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
# make sure at least 2 tokens are kept
top_p_warp
=
TFTopPLogitsWarper
(
0.9
,
min_tokens_to_keep
=
2
,
filter_value
=
0.0
)
if
use_xla
:
top_p_warp
=
tf
.
function
(
top_p_warp
,
jit_compile
=
True
)
filtered_dist
=
top_p_warp
(
input_ids
,
ramp_logits
)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment