chenpangpang / transformers · Commit f0982682 (Unverified)
Authored Jul 04, 2022 by Joao Gante, committed via GitHub on Jul 04, 2022 (parent: e3139ad3)

TF: T5 can now handle a padded past (i.e. XLA generation) (#17969)

* get the right slicing index for position_bias
Showing 2 changed files with 17 additions and 11 deletions:

    src/transformers/models/t5/modeling_tf_t5.py   (+13, -3)
    tests/models/t5/test_modeling_tf_t5.py         (+4, -8)
src/transformers/models/t5/modeling_tf_t5.py
@@ -23,6 +23,7 @@ from typing import Optional, Tuple, Union
 import numpy as np
 import tensorflow as tf
+from tensorflow.compiler.tf2xla.python.xla import dynamic_slice
 
 from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import (
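The only change in this hunk is the new import: XLA's dynamic_slice primitive. Unlike Python slicing with a tensor-valued start, dynamic_slice takes a runtime start index but static slice sizes, so the result keeps a shape XLA can compile. A minimal sketch of those semantics, with toy shapes invented here (not part of the commit):

    import tensorflow as tf
    from tensorflow.compiler.tf2xla.python.xla import dynamic_slice

    # the op targets XLA, so run it inside a jit-compiled function
    @tf.function(jit_compile=True)
    def take_rows(x, start):
        # `start` is only known at runtime, but the (2, 4) output shape is static
        return dynamic_slice(x, (start, 0), (2, 4))

    x = tf.reshape(tf.range(12), (3, 4))
    print(take_rows(x, tf.constant(1)).numpy())  # rows 1 and 2 of x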
@@ -384,10 +385,19 @@ class TFT5Attention(tf.keras.layers.Layer):
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)
 
-            # if key and values are already calculated
-            # we want only the last query position bias
+            # if key and values are already calculated we want only the last query position bias
             if past_key_value is not None:
-                position_bias = position_bias[:, :, -seq_length:, :]
+                if not self.has_relative_attention_bias:
+                    position_bias = position_bias[:, :, -seq_length:, :]
+                else:
+                    # we might have a padded past structure, in which case we want to fetch the position bias slice
+                    # right after the most recently filled past index
+                    most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0))
+                    position_bias = dynamic_slice(
+                        position_bias,
+                        (0, 0, most_recently_filled_past_index + 1, 0),
+                        (1, self.n_heads, seq_length, real_seq_length),
+                    )
 
             if mask is not None:
                 position_bias = tf.cast(position_bias, dtype=mask.dtype)
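Why the new branch: with XLA generation the past key/value cache is pre-allocated and zero-padded to the maximum length, so the old `-seq_length:` slice would land in the padding instead of just after the last real entry. The reduce_max(where(...)) expression recovers that last filled index from the cached key tensor (past_key_value[0]). A standalone sketch of the trick, with toy shapes invented here:

    import tensorflow as tf

    # toy padded cache: (batch, heads, max_length, head_dim) = (1, 2, 8, 4)
    filled = tf.ones((1, 2, 3, 4))    # pretend steps 0..2 already wrote real keys
    padding = tf.zeros((1, 2, 5, 4))  # slots 3..7 are still empty
    past_key = tf.concat([filled, padding], axis=2)

    # same expression as in TFT5Attention: largest index whose cache row is non-zero
    most_recently_filled_past_index = tf.reduce_max(tf.where(past_key[0, 0, :, 0] != 0.0))
    print(int(most_recently_filled_past_index))  # 2, so the current query's bias row is at index 3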
tests/models/t5/test_modeling_tf_t5.py
@@ -590,21 +590,17 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
         ]
         input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
 
-        # xla_generate = tf.function(model.generate, jit_compile=True)
-        xla_generate = tf.function(model.generate)
+        xla_generate = tf.function(model.generate, jit_compile=True)
 
-        # TODO (joao): there is something not quite right with XLA T5 -- as we increase `max_length` the two outputs
-        # drift appart, where the XLA version clearly degrades its quality. XLA-related variables look fine (they are
-        # being padded and filled in the right places). This also happens in other generation modes. Investigate.
-        output_ids = model.generate(input_ids, num_beams=2, max_length=9)
-        output_ids_xla = xla_generate(input_ids, num_beams=2, max_length=9)
+        output_ids = model.generate(input_ids, num_beams=2)
+        output_ids_xla = xla_generate(input_ids, num_beams=2)
 
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
 
         expected_output_string = [
             "Aujourd'hui est une belle journée.",
-            "J'ai quatre chats,",
+            "J'ai quatre chats, trois chiens, deux oiseaux et un cheval.",
         ]
 
         self.assertListEqual(expected_output_string, output_strings)
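For reference, the end-to-end pattern this test now exercises (compiling generate with XLA and calling it on a padded batch) looks roughly like the sketch below; the checkpoint and prompt are illustrative stand-ins, not taken from the test:

    import tensorflow as tf
    from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

    # checkpoint chosen for illustration; the test uses a T5 translation setup
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

    # compile generation with XLA; padded batches hit the padded-past path fixed above
    xla_generate = tf.function(model.generate, jit_compile=True)

    input_ids = tokenizer(
        ["translate English to French: I have four cats"],
        return_tensors="tf",
        padding=True,
    ).input_ids
    output_ids = xla_generate(input_ids, num_beams=2)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))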