Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f0982682
Unverified
Commit
f0982682
authored
Jul 04, 2022
by
Joao Gante
Committed by
GitHub
Jul 04, 2022
Browse files
TF: T5 can now handle a padded past (i.e. XLA generation) (#17969)
* get the right slicing index for position_bias
parent
e3139ad3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
11 deletions
+17
-11
src/transformers/models/t5/modeling_tf_t5.py
src/transformers/models/t5/modeling_tf_t5.py
+13
-3
tests/models/t5/test_modeling_tf_t5.py
tests/models/t5/test_modeling_tf_t5.py
+4
-8
No files found.
src/transformers/models/t5/modeling_tf_t5.py
View file @
f0982682
...
...
@@ -23,6 +23,7 @@ from typing import Optional, Tuple, Union
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.compiler.tf2xla.python.xla
import
dynamic_slice
from
...activations_tf
import
get_tf_activation
from
...modeling_tf_outputs
import
(
...
...
@@ -384,10 +385,19 @@ class TFT5Attention(tf.keras.layers.Layer):
else
:
position_bias
=
self
.
compute_bias
(
real_seq_length
,
key_length
)
# if key and values are already calculated
# we want only the last query position bias
# if key and values are already calculated we want only the last query position bias
if
past_key_value
is
not
None
:
position_bias
=
position_bias
[:,
:,
-
seq_length
:,
:]
if
not
self
.
has_relative_attention_bias
:
position_bias
=
position_bias
[:,
:,
-
seq_length
:,
:]
else
:
# we might have a padded past structure, in which case we want to fetch the position bias slice
# right after the most recently filled past index
most_recently_filled_past_index
=
tf
.
reduce_max
(
tf
.
where
(
past_key_value
[
0
][
0
,
0
,
:,
0
]
!=
0.0
))
position_bias
=
dynamic_slice
(
position_bias
,
(
0
,
0
,
most_recently_filled_past_index
+
1
,
0
),
(
1
,
self
.
n_heads
,
seq_length
,
real_seq_length
),
)
if
mask
is
not
None
:
position_bias
=
tf
.
cast
(
position_bias
,
dtype
=
mask
.
dtype
)
...
...
tests/models/t5/test_modeling_tf_t5.py
View file @
f0982682
...
...
@@ -590,21 +590,17 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
]
input_ids
=
tokenizer
(
sentences
,
return_tensors
=
"tf"
,
padding
=
True
).
input_ids
# xla_generate = tf.function(model.generate, jit_compile=True)
xla_generate
=
tf
.
function
(
model
.
generate
)
xla_generate
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
# TODO (joao): there is something not quite right with XLA T5 -- as we increase `max_length` the two outputs
# drift apart, where the XLA version clearly degrades its quality. XLA-related variables look fine (they are
# being padded and filled in the right places). This also happens in other generation modes. Investigate.
output_ids
=
model
.
generate
(
input_ids
,
num_beams
=
2
,
max_length
=
9
)
output_ids_xla
=
xla_generate
(
input_ids
,
num_beams
=
2
,
max_length
=
9
)
output_ids
=
model
.
generate
(
input_ids
,
num_beams
=
2
)
output_ids_xla
=
xla_generate
(
input_ids
,
num_beams
=
2
)
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
output_strings_xla
=
tokenizer
.
batch_decode
(
output_ids_xla
,
skip_special_tokens
=
True
)
expected_output_string
=
[
"Aujourd'hui est une belle journée."
,
"J'ai quatre chats,"
,
"J'ai quatre chats, trois chiens, deux oiseaux et un cheval."
,
]
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment