Commit d7f7f29f (unverified)
Authored Apr 12, 2022 by Joao Gante; committed by GitHub on Apr 12, 2022

TF: remove set_tensor_by_indices_to_value (#16729)

Parent: a315988b
Showing 4 changed files with 15 additions and 27 deletions (+15 -27)

src/transformers/generation_tf_logits_process.py        +3  -8
src/transformers/generation_tf_utils.py                 +8  -9
src/transformers/tf_utils.py                            +0  -5
tests/generation/test_generation_tf_logits_process.py   +4  -5
src/transformers/generation_tf_logits_process.py
@@ -19,7 +19,6 @@ from typing import List
 import numpy as np
 import tensorflow as tf
 
-from .tf_utils import set_tensor_by_indices_to_value
 from .utils import add_start_docstrings
 from .utils.logging import get_logger
@@ -221,7 +220,7 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
         # generate is not XLA - compileable anyways
         if cur_len < self.min_length:
             eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
-            scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf"))
+            scores = tf.where(eos_token_id_mask, float("-inf"), scores)
         return scores
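The change above follows the pattern used throughout this commit: build a boolean mask and hand it to tf.where(mask, fill_value, tensor) instead of the removed helper. A minimal standalone sketch of the EOS-masking step, with illustrative sizes and token ids that are not taken from the diff:

import tensorflow as tf

# Illustrative values; the real processor reads these from the generation config.
batch_size, vocab_size, eos_token_id = 2, 5, 3
scores = tf.zeros((batch_size, vocab_size))  # stand-in for next-token logits

# True only in the eos_token_id column, broadcast to the full score matrix.
eos_token_id_mask = tf.broadcast_to(tf.range(vocab_size) == eos_token_id, scores.shape)

# New form: tf.where(mask, value, tensor) overwrites masked positions with -inf,
# so EOS cannot be selected before min_length is reached.
masked_scores = tf.where(eos_token_id_mask, float("-inf"), scores)
print(masked_scores.numpy())  # column 3 is -inf in every row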
@@ -339,9 +338,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
                 [True if token in banned_tokens_slice else False for token in range(vocab_size)]
             )
-        scores = set_tensor_by_indices_to_value(
-            scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
-        )
+        scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
         return scores
@@ -397,9 +394,7 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
                 [True if token in banned_tokens_slice else False for token in range(vocab_size)]
             )
-        scores = set_tensor_by_indices_to_value(
-            scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
-        )
+        scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
         return scores
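Both processors above build the banned-token mask as a Python list of boolean rows and apply it in a single call. A small self-contained sketch of that pattern, assuming a toy vocab_size and hand-picked banned ids rather than anything computed by the real processors:

import tensorflow as tf

vocab_size = 6
scores = tf.random.normal((2, vocab_size))  # toy next-token logits for 2 hypotheses
banned_tokens = [[1, 4], [0]]               # illustrative per-hypothesis banned ids

# One boolean row per hypothesis, True where the token is banned.
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
    banned_tokens_indices_mask.append(
        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
    )

# Single tf.where call replaces the old helper: mask first, fill value second, tensor last.
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)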
src/transformers/generation_tf_utils.py
@@ -34,7 +34,7 @@ from .generation_tf_logits_process import (
     TFTopKLogitsWarper,
     TFTopPLogitsWarper,
 )
-from .tf_utils import set_tensor_by_indices_to_value, shape_list
+from .tf_utils import shape_list
 from .utils import ModelOutput, logging
@@ -952,8 +952,7 @@ class TFGenerationMixin:
                 [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
             )
             eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
-            scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
+            scores = tf.where(eos_token_indices_mask, -float("inf"), scores)
 
         if no_repeat_ngram_size > 0:
             # calculate a list of banned tokens to prevent repetitively generating the same ngrams
@@ -969,8 +968,8 @@ class TFGenerationMixin:
                     [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                 )
-            scores = set_tensor_by_indices_to_value(
-                scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
+            scores = tf.where(
+                tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores
             )
 
         if bad_words_ids is not None:
@@ -983,8 +982,8 @@ class TFGenerationMixin:
                     [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                 )
-            scores = set_tensor_by_indices_to_value(
-                scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
+            scores = tf.where(
+                tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores
             )
 
         assert shape_list(scores) == [batch_size * num_beams, vocab_size]
@@ -2950,7 +2949,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
         top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1])  # Safety check
         # Remove all tokens with a probability less than the last token of the top-k
         indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
-        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
+        logits = tf.where(indices_to_remove, filter_value, logits)
     if top_p < 1.0:
         sorted_indices = tf.argsort(logits, direction="DESCENDING")
         sorted_logits = tf.gather(
@@ -2979,7 +2978,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
         )
         # scatter sorted tensors to original indexing
         indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
-        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
+        logits = tf.where(indices_to_remove, filter_value, logits)
     return logits
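The two edits in tf_top_k_top_p_filtering apply the same rewrite to an indices_to_remove mask. A small sketch of the top-k branch with made-up logits (not the function's real inputs):

import tensorflow as tf

logits = tf.constant([[1.0, 5.0, 2.0, 4.0, 3.0]])
top_k = 2
filter_value = -float("Inf")

# True wherever a logit is below the k-th largest value in its row.
kth_largest = tf.math.top_k(logits, k=top_k)[0][..., -1, None]
indices_to_remove = logits < kth_largest

# tf.where(mask, fill, original) replaces set_tensor_by_indices_to_value(original, mask, fill).
filtered = tf.where(indices_to_remove, filter_value, logits)
print(filtered.numpy())  # only 5.0 and 4.0 survive; the rest become -inf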
src/transformers/tf_utils.py
@@ -23,11 +23,6 @@ from .utils import logging
 logger = logging.get_logger(__name__)
 
 
-def set_tensor_by_indices_to_value(tensor: tf.Tensor, indices: tf.Tensor, value: Union[tf.Tensor, int, float]):
-    # create value_tensor since tensor value assignment is not possible in TF
-    return tf.where(indices, value, tensor)
-
-
 def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
     """
     Deal with dynamic shape in tensorflow cleanly.
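The deleted helper was a thin wrapper whose body was exactly tf.where with the arguments reordered, so rewriting the call sites is behavior-preserving. A quick sanity check one could run (the tensors below are arbitrary test values, not from the repository):

import tensorflow as tf

def set_tensor_by_indices_to_value(tensor, indices, value):
    # Re-created here only for the comparison; this is the helper the commit removes.
    return tf.where(indices, value, tensor)

tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])
indices = tf.constant([[True, False], [False, True]])
value = -float("inf")

old = set_tensor_by_indices_to_value(tensor, indices, value)
new = tf.where(indices, value, tensor)  # the direct call used throughout this diff
tf.debugging.assert_equal(old, new)     # passes: both produce the same tensor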
tests/generation/test_generation_tf_logits_process.py
@@ -37,7 +37,6 @@ if is_tf_available():
         TFTopKLogitsWarper,
         TFTopPLogitsWarper,
     )
-    from transformers.tf_utils import set_tensor_by_indices_to_value
 
     from ..test_modeling_tf_common import ids_tensor
@@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
 
         mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
-        scores = set_tensor_by_indices_to_value(scores, mask, -1 / vocab_size)
+        scores = tf.where(mask, -1 / vocab_size, scores)
         mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
-        scores = set_tensor_by_indices_to_value(scores, mask, 4 / vocab_size)
+        scores = tf.where(mask, 4 / vocab_size, scores)
 
         rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
@@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
 
         # remove inf
-        scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9)
-        scores_comp = set_tensor_by_indices_to_value(scores_comp, tf.math.is_inf(scores_comp), -1e9)
+        scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
+        scores_comp = tf.where(tf.math.is_inf(scores_comp), -1e9, scores_comp)
 
         # scores should be equal
         tf.debugging.assert_near(scores, scores_comp, atol=1e-3)
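In the test above, infinities are replaced with -1e9 before the comparison, since the difference between two -inf entries is NaN and would make the tolerance check fail even when the finite values agree. A minimal sketch of that comparison idiom with made-up tensors:

import tensorflow as tf

scores = tf.constant([[0.1, -float("inf"), 0.3]])
scores_comp = tf.constant([[0.1, -float("inf"), 0.3001]])

# Neutralize -inf so the element-wise tolerance check is well defined.
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
scores_comp = tf.where(tf.math.is_inf(scores_comp), -1e9, scores_comp)

tf.debugging.assert_near(scores, scores_comp, atol=1e-3)  # passes within tolerance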