Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9fef6683
Unverified
Commit
9fef6683
authored
Mar 21, 2022
by
Joao Gante
Committed by
GitHub
Mar 21, 2022
Browse files
TF - update (vision_)encoder_decoder past variable (#16260)
parent
f9387c94
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
8 deletions
+4
-8
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
...ers/models/encoder_decoder/modeling_tf_encoder_decoder.py
+2
-4
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
...ion_encoder_decoder/modeling_tf_vision_encoder_decoder.py
+2
-4
No files found.
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
View file @
9fef6683
...
...
@@ -647,19 +647,17 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
# The starting index of the remaining elements in `decoder_outputs`
start_index
=
sum
([
1
if
x
is
not
None
else
0
for
x
in
(
loss
,
logits
,
past_key_values
)])
past
=
(
encoder_outputs
[
0
],
past_key_values
)
if
past_key_values
else
None
if
not
decoder_inputs
[
"return_dict"
]:
if
not
isinstance
(
encoder_outputs
,
tuple
):
encoder_outputs
=
encoder_outputs
.
to_tuple
()
output
=
(
loss
,
logits
,
past
)
+
decoder_outputs
[
start_index
:]
+
encoder_outputs
output
=
(
loss
,
logits
,
past
_key_values
)
+
decoder_outputs
[
start_index
:]
+
encoder_outputs
output
=
tuple
([
x
for
x
in
output
if
x
is
not
None
])
return
output
return
TFSeq2SeqLMOutput
(
loss
=
loss
,
logits
=
decoder_outputs
.
logits
,
past_key_values
=
past
,
past_key_values
=
past
_key_values
,
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_attentions
=
decoder_outputs
.
attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
...
...
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
View file @
9fef6683
...
...
@@ -678,19 +678,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
# The starting index of the remaining elements in `decoder_outputs`
start_index
=
sum
([
1
if
x
is
not
None
else
0
for
x
in
(
loss
,
logits
,
past_key_values
)])
past
=
(
encoder_outputs
[
0
],
past_key_values
)
if
past_key_values
else
None
if
not
decoder_inputs
[
"return_dict"
]:
if
not
isinstance
(
encoder_outputs
,
tuple
):
encoder_outputs
=
encoder_outputs
.
to_tuple
()
output
=
(
loss
,
logits
,
past
)
+
decoder_outputs
[
start_index
:]
+
encoder_outputs
output
=
(
loss
,
logits
,
past
_key_values
)
+
decoder_outputs
[
start_index
:]
+
encoder_outputs
output
=
tuple
([
x
for
x
in
output
if
x
is
not
None
])
return
output
return
TFSeq2SeqLMOutput
(
loss
=
loss
,
logits
=
decoder_outputs
.
logits
,
past_key_values
=
past
,
past_key_values
=
past
_key_values
,
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_attentions
=
decoder_outputs
.
attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment