Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d7f7f29f
Unverified
Commit
d7f7f29f
authored
Apr 12, 2022
by
Joao Gante
Committed by
GitHub
Apr 12, 2022
Browse files
TF: remove set_tensor_by_indices_to_value (#16729)
parent
a315988b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
27 deletions
+15
-27
src/transformers/generation_tf_logits_process.py
src/transformers/generation_tf_logits_process.py
+3
-8
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+8
-9
src/transformers/tf_utils.py
src/transformers/tf_utils.py
+0
-5
tests/generation/test_generation_tf_logits_process.py
tests/generation/test_generation_tf_logits_process.py
+4
-5
No files found.
src/transformers/generation_tf_logits_process.py
View file @
d7f7f29f
...
...
@@ -19,7 +19,6 @@ from typing import List
import
numpy
as
np
import
tensorflow
as
tf
from
.tf_utils
import
set_tensor_by_indices_to_value
from
.utils
import
add_start_docstrings
from
.utils.logging
import
get_logger
...
...
@@ -221,7 +220,7 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
# generate is not XLA - compileable anyways
if
cur_len
<
self
.
min_length
:
eos_token_id_mask
=
tf
.
broadcast_to
(
tf
.
range
(
scores
.
shape
[
-
1
])
==
self
.
eos_token_id
,
scores
.
shape
)
scores
=
set_tensor_by_indices_to_value
(
scores
,
eos_token_id_mask
,
float
(
"-inf"
))
scores
=
tf
.
where
(
eos_token_id_mask
,
float
(
"-inf"
)
,
scores
)
return
scores
...
...
@@ -339,9 +338,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
)
scores
=
set_tensor_by_indices_to_value
(
scores
,
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
)
scores
=
tf
.
where
(
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
),
scores
)
return
scores
...
...
@@ -397,9 +394,7 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
)
scores
=
set_tensor_by_indices_to_value
(
scores
,
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
)
scores
=
tf
.
where
(
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
),
scores
)
return
scores
...
...
src/transformers/generation_tf_utils.py
View file @
d7f7f29f
...
...
@@ -34,7 +34,7 @@ from .generation_tf_logits_process import (
TFTopKLogitsWarper
,
TFTopPLogitsWarper
,
)
from
.tf_utils
import
set_tensor_by_indices_to_value
,
shape_list
from
.tf_utils
import
shape_list
from
.utils
import
ModelOutput
,
logging
...
...
@@ -952,8 +952,7 @@ class TFGenerationMixin:
[
True
if
token
==
eos_token_id
else
False
for
token
in
range
(
vocab_size
)],
dtype
=
tf
.
bool
)
eos_token_indices_mask
=
tf
.
broadcast_to
(
is_token_logit_eos_token
,
[
num_batch_hypotheses
,
vocab_size
])
scores
=
set_tensor_by_indices_to_value
(
scores
,
eos_token_indices_mask
,
-
float
(
"inf"
))
scores
=
tf
.
where
(
eos_token_indices_mask
,
-
float
(
"inf"
),
scores
)
if
no_repeat_ngram_size
>
0
:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
...
...
@@ -969,8 +968,8 @@ class TFGenerationMixin:
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
)
scores
=
set_tensor_by_indices_to_valu
e
(
scores
,
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
scores
=
tf
.
wher
e
(
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
,
scores
)
if
bad_words_ids
is
not
None
:
...
...
@@ -983,8 +982,8 @@ class TFGenerationMixin:
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
)
scores
=
set_tensor_by_indices_to_valu
e
(
scores
,
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
scores
=
tf
.
wher
e
(
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
,
scores
)
assert
shape_list
(
scores
)
==
[
batch_size
*
num_beams
,
vocab_size
]
...
...
@@ -2950,7 +2949,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
top_k
=
min
(
max
(
top_k
,
min_tokens_to_keep
),
logits_shape
[
-
1
])
# Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove
=
logits
<
tf
.
math
.
top_k
(
logits
,
k
=
top_k
)[
0
][...,
-
1
,
None
]
logits
=
set_tensor_by_indices_to_value
(
logits
,
indices_to_remove
,
filter_value
)
logits
=
tf
.
where
(
indices_to_remove
,
filter_value
,
logits
)
if
top_p
<
1.0
:
sorted_indices
=
tf
.
argsort
(
logits
,
direction
=
"DESCENDING"
)
sorted_logits
=
tf
.
gather
(
...
...
@@ -2979,7 +2978,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
)
# scatter sorted tensors to original indexing
indices_to_remove
=
scatter_values_on_batch_indices
(
sorted_indices_to_remove
,
sorted_indices
)
logits
=
set_tensor_by_indices_to_value
(
logits
,
indices_to_remove
,
filter_value
)
logits
=
tf
.
where
(
indices_to_remove
,
filter_value
,
logits
)
return
logits
...
...
src/transformers/tf_utils.py
View file @
d7f7f29f
...
...
@@ -23,11 +23,6 @@ from .utils import logging
logger
=
logging
.
get_logger
(
__name__
)
def
set_tensor_by_indices_to_value
(
tensor
:
tf
.
Tensor
,
indices
:
tf
.
Tensor
,
value
:
Union
[
tf
.
Tensor
,
int
,
float
]):
# create value_tensor since tensor value assignment is not possible in TF
return
tf
.
where
(
indices
,
value
,
tensor
)
def
shape_list
(
tensor
:
Union
[
tf
.
Tensor
,
np
.
ndarray
])
->
List
[
int
]:
"""
Deal with dynamic shape in tensorflow cleanly.
...
...
tests/generation/test_generation_tf_logits_process.py
View file @
d7f7f29f
...
...
@@ -37,7 +37,6 @@ if is_tf_available():
TFTopKLogitsWarper
,
TFTopPLogitsWarper
,
)
from
transformers.tf_utils
import
set_tensor_by_indices_to_value
from
..test_modeling_tf_common
import
ids_tensor
...
...
@@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores
=
self
.
_get_uniform_logits
(
batch_size
=
2
,
length
=
vocab_size
)
mask
=
tf
.
cast
(
tf
.
constant
([[
1
]
+
9
*
[
0
],
10
*
[
0
]]),
tf
.
bool
)
scores
=
set_tensor_by_indices_to_value
(
scores
,
mask
,
-
1
/
vocab_size
)
scores
=
tf
.
where
(
mask
,
-
1
/
vocab_size
,
scores
)
mask
=
tf
.
cast
(
tf
.
constant
([
10
*
[
0
],
5
*
[
0
]
+
[
1
]
+
4
*
[
0
]]),
tf
.
bool
)
scores
=
set_tensor_by_indices_to_value
(
scores
,
mask
,
4
/
vocab_size
)
scores
=
tf
.
where
(
mask
,
4
/
vocab_size
,
scores
)
rep_penalty_proc
=
TFRepetitionPenaltyLogitsProcessor
(
penalty
=
2.0
)
...
...
@@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores_comp
=
processor
(
input_ids
,
scores_comp
,
cur_len
=
cur_len
)
# remove inf
scores
=
set_tensor_by_indices_to_value
(
scores
,
tf
.
math
.
is_inf
(
scores
),
-
1e9
)
scores_comp
=
set_tensor_by_indices_to_value
(
scores_comp
,
tf
.
math
.
is_inf
(
scores_comp
),
-
1e9
)
scores
=
tf
.
where
(
tf
.
math
.
is_inf
(
scores
),
-
1e9
,
scores
)
scores_comp
=
tf
.
where
(
tf
.
math
.
is_inf
(
scores_comp
),
-
1e9
,
scores_comp
)
# scores should be equal
tf
.
debugging
.
assert_near
(
scores
,
scores_comp
,
atol
=
1e-3
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment