Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d7f7f29f
Unverified
Commit
d7f7f29f
authored
Apr 12, 2022
by
Joao Gante
Committed by
GitHub
Apr 12, 2022
Browse files
TF: remove set_tensor_by_indices_to_value (#16729)
parent
a315988b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
27 deletions
+15
-27
src/transformers/generation_tf_logits_process.py
src/transformers/generation_tf_logits_process.py
+3
-8
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+8
-9
src/transformers/tf_utils.py
src/transformers/tf_utils.py
+0
-5
tests/generation/test_generation_tf_logits_process.py
tests/generation/test_generation_tf_logits_process.py
+4
-5
No files found.
src/transformers/generation_tf_logits_process.py
View file @
d7f7f29f
...
@@ -19,7 +19,6 @@ from typing import List
...
@@ -19,7 +19,6 @@ from typing import List
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
.tf_utils
import
set_tensor_by_indices_to_value
from
.utils
import
add_start_docstrings
from
.utils
import
add_start_docstrings
from
.utils.logging
import
get_logger
from
.utils.logging
import
get_logger
...
@@ -221,7 +220,7 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
...
@@ -221,7 +220,7 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
# generate is not XLA - compileable anyways
# generate is not XLA - compileable anyways
if
cur_len
<
self
.
min_length
:
if
cur_len
<
self
.
min_length
:
eos_token_id_mask
=
tf
.
broadcast_to
(
tf
.
range
(
scores
.
shape
[
-
1
])
==
self
.
eos_token_id
,
scores
.
shape
)
eos_token_id_mask
=
tf
.
broadcast_to
(
tf
.
range
(
scores
.
shape
[
-
1
])
==
self
.
eos_token_id
,
scores
.
shape
)
scores
=
set_tensor_by_indices_to_value
(
scores
,
eos_token_id_mask
,
float
(
"-inf"
))
scores
=
tf
.
where
(
eos_token_id_mask
,
float
(
"-inf"
)
,
scores
)
return
scores
return
scores
...
@@ -339,9 +338,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
...
@@ -339,9 +338,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
)
)
scores
=
set_tensor_by_indices_to_value
(
scores
=
tf
.
where
(
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
),
scores
)
scores
,
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
)
return
scores
return
scores
...
@@ -397,9 +394,7 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
...
@@ -397,9 +394,7 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
)
)
scores
=
set_tensor_by_indices_to_value
(
scores
=
tf
.
where
(
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
),
scores
)
scores
,
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
)
return
scores
return
scores
...
...
src/transformers/generation_tf_utils.py
View file @
d7f7f29f
...
@@ -34,7 +34,7 @@ from .generation_tf_logits_process import (
...
@@ -34,7 +34,7 @@ from .generation_tf_logits_process import (
TFTopKLogitsWarper
,
TFTopKLogitsWarper
,
TFTopPLogitsWarper
,
TFTopPLogitsWarper
,
)
)
from
.tf_utils
import
set_tensor_by_indices_to_value
,
shape_list
from
.tf_utils
import
shape_list
from
.utils
import
ModelOutput
,
logging
from
.utils
import
ModelOutput
,
logging
...
@@ -952,8 +952,7 @@ class TFGenerationMixin:
...
@@ -952,8 +952,7 @@ class TFGenerationMixin:
[
True
if
token
==
eos_token_id
else
False
for
token
in
range
(
vocab_size
)],
dtype
=
tf
.
bool
[
True
if
token
==
eos_token_id
else
False
for
token
in
range
(
vocab_size
)],
dtype
=
tf
.
bool
)
)
eos_token_indices_mask
=
tf
.
broadcast_to
(
is_token_logit_eos_token
,
[
num_batch_hypotheses
,
vocab_size
])
eos_token_indices_mask
=
tf
.
broadcast_to
(
is_token_logit_eos_token
,
[
num_batch_hypotheses
,
vocab_size
])
scores
=
tf
.
where
(
eos_token_indices_mask
,
-
float
(
"inf"
),
scores
)
scores
=
set_tensor_by_indices_to_value
(
scores
,
eos_token_indices_mask
,
-
float
(
"inf"
))
if
no_repeat_ngram_size
>
0
:
if
no_repeat_ngram_size
>
0
:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
...
@@ -969,8 +968,8 @@ class TFGenerationMixin:
...
@@ -969,8 +968,8 @@ class TFGenerationMixin:
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
)
)
scores
=
set_tensor_by_indices_to_value
(
scores
=
tf
.
where
(
scores
,
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
,
scores
)
)
if
bad_words_ids
is
not
None
:
if
bad_words_ids
is
not
None
:
...
@@ -983,8 +982,8 @@ class TFGenerationMixin:
...
@@ -983,8 +982,8 @@ class TFGenerationMixin:
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
[
True
if
token
in
banned_tokens_slice
else
False
for
token
in
range
(
vocab_size
)]
)
)
scores
=
set_tensor_by_indices_to_value
(
scores
=
tf
.
where
(
scores
,
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
tf
.
convert_to_tensor
(
banned_tokens_indices_mask
,
dtype
=
tf
.
bool
),
-
float
(
"inf"
)
,
scores
)
)
assert
shape_list
(
scores
)
==
[
batch_size
*
num_beams
,
vocab_size
]
assert
shape_list
(
scores
)
==
[
batch_size
*
num_beams
,
vocab_size
]
...
@@ -2950,7 +2949,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
...
@@ -2950,7 +2949,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
top_k
=
min
(
max
(
top_k
,
min_tokens_to_keep
),
logits_shape
[
-
1
])
# Safety check
top_k
=
min
(
max
(
top_k
,
min_tokens_to_keep
),
logits_shape
[
-
1
])
# Safety check
# Remove all tokens with a probability less than the last token of the top-k
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove
=
logits
<
tf
.
math
.
top_k
(
logits
,
k
=
top_k
)[
0
][...,
-
1
,
None
]
indices_to_remove
=
logits
<
tf
.
math
.
top_k
(
logits
,
k
=
top_k
)[
0
][...,
-
1
,
None
]
logits
=
set_tensor_by_indices_to_value
(
logits
,
indices_to_remove
,
filter_value
)
logits
=
tf
.
where
(
indices_to_remove
,
filter_value
,
logits
)
if
top_p
<
1.0
:
if
top_p
<
1.0
:
sorted_indices
=
tf
.
argsort
(
logits
,
direction
=
"DESCENDING"
)
sorted_indices
=
tf
.
argsort
(
logits
,
direction
=
"DESCENDING"
)
sorted_logits
=
tf
.
gather
(
sorted_logits
=
tf
.
gather
(
...
@@ -2979,7 +2978,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
...
@@ -2979,7 +2978,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
)
)
# scatter sorted tensors to original indexing
# scatter sorted tensors to original indexing
indices_to_remove
=
scatter_values_on_batch_indices
(
sorted_indices_to_remove
,
sorted_indices
)
indices_to_remove
=
scatter_values_on_batch_indices
(
sorted_indices_to_remove
,
sorted_indices
)
logits
=
set_tensor_by_indices_to_value
(
logits
,
indices_to_remove
,
filter_value
)
logits
=
tf
.
where
(
indices_to_remove
,
filter_value
,
logits
)
return
logits
return
logits
...
...
src/transformers/tf_utils.py
View file @
d7f7f29f
...
@@ -23,11 +23,6 @@ from .utils import logging
...
@@ -23,11 +23,6 @@ from .utils import logging
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
def
set_tensor_by_indices_to_value
(
tensor
:
tf
.
Tensor
,
indices
:
tf
.
Tensor
,
value
:
Union
[
tf
.
Tensor
,
int
,
float
]):
# create value_tensor since tensor value assignment is not possible in TF
return
tf
.
where
(
indices
,
value
,
tensor
)
def
shape_list
(
tensor
:
Union
[
tf
.
Tensor
,
np
.
ndarray
])
->
List
[
int
]:
def
shape_list
(
tensor
:
Union
[
tf
.
Tensor
,
np
.
ndarray
])
->
List
[
int
]:
"""
"""
Deal with dynamic shape in tensorflow cleanly.
Deal with dynamic shape in tensorflow cleanly.
...
...
tests/generation/test_generation_tf_logits_process.py
View file @
d7f7f29f
...
@@ -37,7 +37,6 @@ if is_tf_available():
...
@@ -37,7 +37,6 @@ if is_tf_available():
TFTopKLogitsWarper
,
TFTopKLogitsWarper
,
TFTopPLogitsWarper
,
TFTopPLogitsWarper
,
)
)
from
transformers.tf_utils
import
set_tensor_by_indices_to_value
from
..test_modeling_tf_common
import
ids_tensor
from
..test_modeling_tf_common
import
ids_tensor
...
@@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
...
@@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores
=
self
.
_get_uniform_logits
(
batch_size
=
2
,
length
=
vocab_size
)
scores
=
self
.
_get_uniform_logits
(
batch_size
=
2
,
length
=
vocab_size
)
mask
=
tf
.
cast
(
tf
.
constant
([[
1
]
+
9
*
[
0
],
10
*
[
0
]]),
tf
.
bool
)
mask
=
tf
.
cast
(
tf
.
constant
([[
1
]
+
9
*
[
0
],
10
*
[
0
]]),
tf
.
bool
)
scores
=
set_tensor_by_indices_to_value
(
scores
,
mask
,
-
1
/
vocab_size
)
scores
=
tf
.
where
(
mask
,
-
1
/
vocab_size
,
scores
)
mask
=
tf
.
cast
(
tf
.
constant
([
10
*
[
0
],
5
*
[
0
]
+
[
1
]
+
4
*
[
0
]]),
tf
.
bool
)
mask
=
tf
.
cast
(
tf
.
constant
([
10
*
[
0
],
5
*
[
0
]
+
[
1
]
+
4
*
[
0
]]),
tf
.
bool
)
scores
=
set_tensor_by_indices_to_value
(
scores
,
mask
,
4
/
vocab_size
)
scores
=
tf
.
where
(
mask
,
4
/
vocab_size
,
scores
)
rep_penalty_proc
=
TFRepetitionPenaltyLogitsProcessor
(
penalty
=
2.0
)
rep_penalty_proc
=
TFRepetitionPenaltyLogitsProcessor
(
penalty
=
2.0
)
...
@@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
...
@@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores_comp
=
processor
(
input_ids
,
scores_comp
,
cur_len
=
cur_len
)
scores_comp
=
processor
(
input_ids
,
scores_comp
,
cur_len
=
cur_len
)
# remove inf
# remove inf
scores
=
set_tensor_by_indices_to_value
(
scores
,
tf
.
math
.
is_inf
(
scores
),
-
1e9
)
scores
=
tf
.
where
(
tf
.
math
.
is_inf
(
scores
),
-
1e9
,
scores
)
scores_comp
=
set_tensor_by_indices_to_value
(
scores_comp
,
tf
.
math
.
is_inf
(
scores_comp
),
-
1e9
)
scores_comp
=
tf
.
where
(
tf
.
math
.
is_inf
(
scores_comp
),
-
1e9
,
scores_comp
)
# scores should be equal
# scores should be equal
tf
.
debugging
.
assert_near
(
scores
,
scores_comp
,
atol
=
1e-3
)
tf
.
debugging
.
assert_near
(
scores
,
scores_comp
,
atol
=
1e-3
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment