chenpangpang / transformers · Commits

Commit 9331b379 (Unverified)
Authored Apr 25, 2022 by Joao Gante; committed by GitHub on Apr 25, 2022
TF: XLA Logits Warpers (#16899)

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Parent: 809dac48

Showing 1 changed file with 26 additions and 21 deletions:

tests/generation/test_generation_tf_logits_process.py (+26 -21)
@@ -72,7 +72,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores_before_min_length = min_dist_processor(input_ids, scores, cur_len)
         self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy())

-    def test_temperature_dist_warper(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_temperature_dist_warper(self, use_xla):
         input_ids = None
         length = 20
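The new decorator comes from the `parameterized` package already used by the transformers test suite: one test body fans out into a non-XLA and an XLA case. A minimal, self-contained sketch of the same pattern (the test class and callable below are illustrative, not from this diff):

import unittest

import tensorflow as tf
from parameterized import parameterized


class ToyWarperTest(unittest.TestCase):
    # runs twice: once with use_xla=False, once with use_xla=True
    @parameterized.expand([(False,), (True,)])
    def test_doubling(self, use_xla):
        double = lambda x: x * 2
        if use_xla:
            # same trick as the commit: compile the callable with XLA
            double = tf.function(double, jit_compile=True)
        self.assertEqual(int(double(tf.constant(3))), 6)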
@@ -89,6 +90,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
         temp_dist_warper_sharper = TFTemperatureLogitsWarper(temperature=0.5)
         temp_dist_warper_smoother = TFTemperatureLogitsWarper(temperature=1.3)

+        if use_xla:
+            temp_dist_warper_sharper = tf.function(temp_dist_warper_sharper, jit_compile=True)
+            temp_dist_warper_smoother = tf.function(temp_dist_warper_smoother, jit_compile=True)
         warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
         warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
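This is the pattern the whole commit repeats: a TF logits warper is just a callable, so wrapping it in `tf.function(..., jit_compile=True)` routes the same call through XLA, and the test then checks that eager and compiled paths agree. A standalone sketch under that assumption (the toy warper is illustrative; only the `tf.function`/`jit_compile` usage comes from the diff):

import tensorflow as tf


def temperature_warp(scores, temperature=0.5):
    # temperature < 1 sharpens the distribution, > 1 smooths it
    return scores / temperature


# eager call and XLA-compiled call should produce the same values
xla_warp = tf.function(temperature_warp, jit_compile=True)
scores = tf.constant([[0.1, 0.2, 0.7]])
print(temperature_warp(scores).numpy())
print(xla_warp(scores).numpy())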
@@ -105,7 +109,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
         self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
         self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))

-    def _get_repetition_penalty_inputs(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_repetition_penalty_dist_process(self, use_xla):
         vocab_size = 10
         cur_len = 2
@@ -118,9 +123,12 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores = tf.where(mask, -1 / vocab_size, scores)
         mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
         scores = tf.where(mask, 4 / vocab_size, scores)
-        return vocab_size, cur_len, input_ids, scores

-    def _check_repetition_penalty_outputs(self, scores, vocab_size):
+        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
+        if use_xla:
+            rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)
+        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
+
         # check that values were correctly changed (negative scores for used tokens should increase, others
         # should decrease)
         self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
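The assertion just above encodes the repetition-penalty rule the processor implements: for a token already present in `input_ids`, a negative score is multiplied by the penalty and a positive score is divided by it, so used tokens always become less likely. Worked through with the test's numbers (penalty 2.0, vocab_size 10); a sketch of the rule, not the library code:

penalty, vocab_size = 2.0, 10


def apply_penalty(score, penalty):
    # negative scores move further below zero, positive scores shrink toward zero
    return score * penalty if score < 0 else score / penalty


assert apply_penalty(-1 / vocab_size, penalty) == -(1 / vocab_size) * 2  # -0.1 -> -0.2
assert apply_penalty(4 / vocab_size, penalty) == (4 / vocab_size) / 2   # 0.4 -> 0.2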
@@ -131,29 +139,19 @@ class TFLogitsProcessorTest(unittest.TestCase):
         self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
         self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change

-    def test_repetition_penalty_dist_process(self):
-        vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
-
-        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
-        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
-        self._check_repetition_penalty_outputs(scores, vocab_size)
-
-    def test_repetition_penalty_dist_process_xla(self):
-        vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
-
-        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
-        rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)  # added line wrt non-XLA test
-        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
-        self._check_repetition_penalty_outputs(scores, vocab_size)
-
-    def test_top_k_dist_warper(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_top_k_dist_warper(self, use_xla):
         input_ids = None
         vocab_size = 10
         batch_size = 2

         # create ramp distribution
-        ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
+        ramp_logits = np.broadcast_to(np.arange(vocab_size, dtype=np.float32), (batch_size, vocab_size)).copy()
         ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size

         top_k_warp = TFTopKLogitsWarper(3)
+        if use_xla:
+            top_k_warp = tf.function(top_k_warp, jit_compile=True)
         scores = top_k_warp(input_ids, ramp_logits)
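Note the second change in this hunk besides the XLA wrapping: the ramp logits switch from `np.arange(vocab_size)[None, :]` to `np.arange(vocab_size, dtype=np.float32)`. `np.arange` defaults to an integer dtype, and the warpers expect float logits; presumably the integer input no longer passes once the call goes through `tf.function`. A quick check that the new expression keeps the shape while fixing the dtype:

import numpy as np

batch_size, vocab_size = 2, 10
old = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
new = np.broadcast_to(np.arange(vocab_size, dtype=np.float32), (batch_size, vocab_size)).copy()
print(old.dtype, new.dtype)    # e.g. int64 float32 (arange's integer default is platform-dependent)
print(old.shape == new.shape)  # True: broadcast_to lifts the 1-D array to (2, 10) either way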
@@ -166,18 +164,21 @@ class TFLogitsProcessorTest(unittest.TestCase):
         logits = self._get_uniform_logits(batch_size=batch_size, length=length)
         top_k_warp_safety_check = TFTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
+        if use_xla:
+            top_k_warp_safety_check = tf.function(top_k_warp_safety_check, jit_compile=True)
         scores = top_k_warp_safety_check(input_ids, logits)

         # uniform dist is not changed
         self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])

-        ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
+        ramp_logits = np.broadcast_to(np.arange(length, dtype=np.float32), (batch_size, length)).copy()
         scores = top_k_warp_safety_check(input_ids, ramp_logits)

         # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
         self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])

-    def test_top_p_dist_warper(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_top_p_dist_warper(self, use_xla):
         input_ids = None
         vocab_size = 10
         batch_size = 2
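A sanity check of the `[2, 2]` assertion above: with `top_k=1` but `min_tokens_to_keep=3`, the effective cutoff is `max(top_k, min_tokens_to_keep) = 3` kept tokens per row, so `length - 3` scores are set to the 0.0 filter value. The `[2, 2]` result implies `length` is 5 here; that value sits outside the scraped hunk, so it is inferred, not quoted:

# illustrative arithmetic only; length = 5 is inferred from the assertion, not shown in the hunk
top_k, min_tokens_to_keep, length, batch_size = 1, 3, 5, 2
k_eff = max(top_k, min_tokens_to_keep)
print([length - k_eff] * batch_size)  # [2, 2], matching the assertListEqual above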
@@ -186,6 +187,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
         dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))

         top_p_warp = TFTopPLogitsWarper(0.7)
+        if use_xla:
+            top_p_warp = tf.function(top_p_warp, jit_compile=True)
         filtered_dist = tf.exp(top_p_warp(input_ids, dist))

         # dist should be filtered to keep min num values so that sum is >= 0.7
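The comment's claim can be checked by hand: nucleus (top-p) filtering keeps the smallest set of highest-probability tokens whose mass reaches `top_p`. For the first row, 0.5 + 0.3 = 0.8 >= 0.7, so two tokens survive; for the second, 0.3 + 0.3 = 0.6 < 0.7 and adding 0.25 reaches 0.85, so three survive. A small numpy sketch of that selection rule (a sketch of the idea, not the `TFTopPLogitsWarper` implementation):

import numpy as np


def top_p_mask(probs, top_p=0.7):
    # keep the minimal set of largest probabilities whose cumulative sum reaches top_p
    order = np.argsort(probs)[::-1]
    cum = np.cumsum(probs[order])
    keep = np.zeros_like(probs, dtype=bool)
    keep[order[: np.searchsorted(cum, top_p) + 1]] = True
    return keep


print(top_p_mask(np.array([0.3, 0.1, 0.1, 0.5])))    # keeps 0.5 and 0.3
print(top_p_mask(np.array([0.15, 0.3, 0.3, 0.25])))  # keeps 0.3, 0.3 and 0.25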
@@ -203,6 +206,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
         # make sure at least 2 tokens are kept
         top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
+        if use_xla:
+            top_p_warp = tf.function(top_p_warp, jit_compile=True)
         filtered_dist = top_p_warp(input_ids, ramp_logits)

         # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
...