Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d97d06d0
Unverified
Commit
d97d06d0
authored
Dec 28, 2020
by
Julien Plu
Committed by
GitHub
Dec 28, 2020
Browse files
Fix TF T5 (#9301)
* Fix T5 * Fix test * Fix test
parent
83fdd252
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
5 deletions
+8
-5
src/transformers/models/t5/modeling_tf_t5.py
src/transformers/models/t5/modeling_tf_t5.py
+8
-5
No files found.
src/transformers/models/t5/modeling_tf_t5.py
View file @
d97d06d0
...
@@ -268,9 +268,9 @@ class TFT5Attention(tf.keras.layers.Layer):
...
@@ -268,9 +268,9 @@ class TFT5Attention(tf.keras.layers.Layer):
),
"past_key_value should have 2 past states: keys and values. Got {} past states"
.
format
(
),
"past_key_value should have 2 past states: keys and values. Got {} past states"
.
format
(
len
(
past_key_value
)
len
(
past_key_value
)
)
)
real_seq_length
+=
past_key_value
[
0
]
.
shape
[
2
]
if
query_length
is
None
else
query_length
real_seq_length
+=
shape_list
(
past_key_value
[
0
]
)
[
2
]
if
query_length
is
None
else
query_length
key_length
=
real_seq_length
if
key_value_states
is
None
else
key_value_states
.
shape
[
1
]
key_length
=
real_seq_length
if
key_value_states
is
None
else
shape_list
(
key_value_states
)
[
1
]
def
shape
(
hidden_states
):
def
shape
(
hidden_states
):
""" projection """
""" projection """
...
@@ -1147,13 +1147,14 @@ class TFT5Model(TFT5PreTrainedModel):
...
@@ -1147,13 +1147,14 @@ class TFT5Model(TFT5PreTrainedModel):
training
=
inputs
[
"training"
],
training
=
inputs
[
"training"
],
)
)
past
=
(
inputs
[
"encoder_outputs"
],
decoder_outputs
[
1
])
if
inputs
[
"use_cache"
]
else
None
if
not
inputs
[
"return_dict"
]:
if
not
inputs
[
"return_dict"
]:
past
=
(
inputs
[
"encoder_outputs"
],
decoder_outputs
[
1
])
if
inputs
[
"use_cache"
]
else
None
if
past
is
not
None
:
if
past
is
not
None
:
decoder_outputs
=
decoder_outputs
[:
1
]
+
(
past
,)
+
decoder_outputs
[
2
:]
decoder_outputs
=
decoder_outputs
[:
1
]
+
(
past
,)
+
decoder_outputs
[
2
:]
return
decoder_outputs
+
inputs
[
"encoder_outputs"
]
return
decoder_outputs
+
inputs
[
"encoder_outputs"
]
past
=
(
inputs
[
"encoder_outputs"
].
to_tuple
(),
decoder_outputs
[
1
])
if
inputs
[
"use_cache"
]
else
None
return
TFSeq2SeqModelOutput
(
return
TFSeq2SeqModelOutput
(
last_hidden_state
=
decoder_outputs
.
last_hidden_state
,
last_hidden_state
=
decoder_outputs
.
last_hidden_state
,
past_key_values
=
past
,
past_key_values
=
past
,
...
@@ -1332,8 +1333,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
...
@@ -1332,8 +1333,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
loss
=
None
if
inputs
[
"labels"
]
is
None
else
self
.
compute_loss
(
inputs
[
"labels"
],
logits
)
loss
=
None
if
inputs
[
"labels"
]
is
None
else
self
.
compute_loss
(
inputs
[
"labels"
],
logits
)
past
=
(
inputs
[
"encoder_outputs"
],
decoder_outputs
[
1
])
if
inputs
[
"use_cache"
]
else
None
if
not
inputs
[
"return_dict"
]:
if
not
inputs
[
"return_dict"
]:
past
=
(
inputs
[
"encoder_outputs"
],
decoder_outputs
[
1
])
if
inputs
[
"use_cache"
]
else
None
if
past
is
not
None
:
if
past
is
not
None
:
decoder_outputs
=
decoder_outputs
[:
1
]
+
(
past
,)
+
decoder_outputs
[
2
:]
decoder_outputs
=
decoder_outputs
[:
1
]
+
(
past
,)
+
decoder_outputs
[
2
:]
output
=
(
logits
,)
+
decoder_outputs
[
1
:]
+
inputs
[
"encoder_outputs"
]
output
=
(
logits
,)
+
decoder_outputs
[
1
:]
+
inputs
[
"encoder_outputs"
]
...
@@ -1358,6 +1359,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
...
@@ -1358,6 +1359,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
attentions
=
attentions
,
attentions
=
attentions
,
)
)
past
=
(
inputs
[
"encoder_outputs"
].
to_tuple
(),
decoder_outputs
[
1
])
if
inputs
[
"use_cache"
]
else
None
return
TFSeq2SeqLMOutput
(
return
TFSeq2SeqLMOutput
(
loss
=
loss
,
loss
=
loss
,
logits
=
logits
,
logits
=
logits
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment