Unverified commit 3e9d0f7f, authored Mar 12, 2022 by Omar Sanseviero, committed by GitHub on Mar 12, 2022

Change unpacking of TF Bart inputs (#16094)

parent 580dd87c
Showing 1 changed file with 124 additions and 222 deletions

src/transformers/models/bart/modeling_tf_bart.py (+124, -222)
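Every hunk below applies the same refactor: previously each `call` built an `inputs` dictionary through `input_processing(func=self.call, config=self.config, ...)` and then indexed into it (`inputs["input_ids"]`, `inputs["training"]`, ...); after this commit the method is decorated with `@unpack_inputs` and the body reads the plain keyword arguments directly. A minimal sketch of the shape of that change, using a toy decorator; `unpack_inputs_sketch` and `TinyEncoder` are illustrative names, not the real `unpack_inputs` implementation in `modeling_tf_utils.py`:

# Illustrative only: mimics the shape of the @unpack_inputs refactor, not the real decorator.
import functools
from types import SimpleNamespace


def unpack_inputs_sketch(func):
    """Toy stand-in for `@unpack_inputs`: resolve config-backed defaults once,
    before the wrapped `call` runs, so the body can use plain keyword arguments."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # In the old code this resolution happened inside every `call` via
        # `inputs = input_processing(...)`; the decorator centralizes it.
        if kwargs.get("output_hidden_states") is None:
            kwargs["output_hidden_states"] = self.config.output_hidden_states
        if kwargs.get("return_dict") is None:
            kwargs["return_dict"] = self.config.return_dict
        return func(self, *args, **kwargs)

    return wrapper


class TinyEncoder:
    def __init__(self, config):
        self.config = config

    @unpack_inputs_sketch
    def call(self, input_ids=None, output_hidden_states=None, return_dict=None, training=False):
        # New style: read `return_dict` directly instead of `inputs["return_dict"]`.
        return {"input_ids": input_ids, "return_dict": return_dict,
                "output_hidden_states": output_hidden_states, "training": training}


enc = TinyEncoder(SimpleNamespace(output_hidden_states=False, return_dict=True))
print(enc.call(input_ids=[1, 2, 3]))  # return_dict is filled in by the decorator

The effect in the real code is the same: argument resolution happens once, in one place, and each model's `call` body drops the `inputs = input_processing(...)` boilerplate.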
...
@@ -42,8 +42,8 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
...
@@ -660,6 +660,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
...
@@ -708,80 +709,67 @@ class TFBartEncoder(tf.keras.layers.Layer):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_shape)
-        hidden_states = inputs["inputs_embeds"] + embed_pos
+        hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)

         # check attention mask and invert
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(inputs["attention_mask"])
+            attention_mask = _expand_mask(attention_mask)
         else:
             attention_mask = None

-        encoder_states = () if inputs["output_hidden_states"] else None
-        all_attentions = () if inputs["output_attentions"] else None
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None

         # check if head_mask has a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        if inputs["head_mask"] is not None and tf.executing_eagerly():
+        if head_mask is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["head_mask"])[0],
+                shape_list(head_mask)[0],
                 len(self.layers),
-                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
             )

         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):

-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):  # skip the layer
+            if training and (dropout_probability < self.layerdrop):  # skip the layer
                 continue

             hidden_states, attn = encoder_layer(
                 hidden_states,
                 attention_mask,
-                inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
+                head_mask[idx] if head_mask is not None else None,
             )

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_attentions += (attn,)

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return TFBaseModelOutput(
             last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
...
@@ -822,6 +810,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
...
@@ -899,45 +888,25 @@ class TFBartDecoder(tf.keras.layers.Layer):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            head_mask=head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            inputs_embeds=inputs_embeds,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

-        past_key_values_length = (
-            shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
-        )
+        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        hidden_states = inputs["inputs_embeds"]
+        hidden_states = inputs_embeds

         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         if input_shape[-1] > 1:
...
@@ -947,72 +916,68 @@ class TFBartDecoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )

-        if inputs["attention_mask"] is not None:
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                inputs["attention_mask"], tgt_len=input_shape[-1]
-            )
+        if attention_mask is not None:
+            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])

-        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])

         hidden_states = self.layernorm_embedding(hidden_states + positions)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)

         # decoder layers
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        all_self_attns = () if inputs["output_attentions"] else None
-        all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
-        present_key_values = () if inputs["use_cache"] else None
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
+        present_key_values = () if use_cache else None

         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in ["head_mask", "cross_attn_head_mask"]:
-            if inputs[attn_mask] is not None and tf.executing_eagerly():
+        for attn_mask in [head_mask, cross_attn_head_mask]:
+            if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
-                    shape_list(inputs[attn_mask])[0],
+                    shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
+                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )

         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             dropout_probability = random.uniform(0, 1)

-            if inputs["training"] and (dropout_probability < self.layerdrop):
+            if training and (dropout_probability < self.layerdrop):
                 continue

-            past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
+            past_key_value = past_key_values[idx] if past_key_values is not None else None

             hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                encoder_hidden_states=inputs["encoder_hidden_states"],
-                encoder_attention_mask=inputs["encoder_attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-                cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
-                if inputs["cross_attn_head_mask"] is not None else None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                 past_key_value=past_key_value,
             )

-            if inputs["use_cache"]:
+            if use_cache:
                 present_key_values += (present_key_value,)

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_self_attns += (layer_self_attn,)

-                if inputs["encoder_hidden_states"] is not None:
+                if encoder_hidden_states is not None:
                     all_cross_attns += (layer_cross_attn,)

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
             return TFBaseModelOutputWithPastAndCrossAttentions(
...
@@ -1062,6 +1027,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
...
@@ -1082,82 +1048,59 @@ class TFBartMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
-            inputs["use_cache"] = False
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            use_cache = False

-        inputs["output_hidden_states"] = (
-            inputs["output_hidden_states"]
-            if inputs["output_hidden_states"] is not None
-            else self.config.output_hidden_states
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )

-        if inputs["decoder_input_ids"] is None and inputs["input_ids"] is not None:
-            inputs["decoder_input_ids"] = shift_tokens_right(
-                inputs["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
+        if decoder_input_ids is None and input_ids is not None:
+            decoder_input_ids = shift_tokens_right(
+                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
             )

-        if inputs["encoder_outputs"] is None:
-            inputs["encoder_outputs"] = self.encoder(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                head_mask=inputs["head_mask"],
-                inputs_embeds=inputs["inputs_embeds"],
-                output_attentions=inputs["output_attentions"],
-                output_hidden_states=inputs["output_hidden_states"],
-                return_dict=inputs["return_dict"],
-                training=inputs["training"],
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
             )
         # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
-        elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
-            inputs["encoder_outputs"] = TFBaseModelOutput(
-                last_hidden_state=inputs["encoder_outputs"][0],
-                hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
-                attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
+        elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
+            encoder_outputs = TFBaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )
         # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
-        elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
+        elif not return_dict and not isinstance(encoder_outputs, tuple):
+            encoder_outputs = encoder_outputs.to_tuple()

         decoder_outputs = self.decoder(
-            inputs["decoder_input_ids"],
-            attention_mask=inputs["decoder_attention_mask"],
-            encoder_hidden_states=inputs["encoder_outputs"][0],
-            encoder_attention_mask=inputs["attention_mask"],
-            head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

-        if not inputs["return_dict"]:
-            return decoder_outputs + inputs["encoder_outputs"]
+        if not return_dict:
+            return decoder_outputs + encoder_outputs

         return TFSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
...
@@ -1197,6 +1140,7 @@ class TFBartModel(TFBartPretrainedModel):
         output_type=TFSeq2SeqModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
...
@@ -1217,9 +1161,8 @@ class TFBartModel(TFBartPretrainedModel):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
...
@@ -1236,26 +1179,6 @@ class TFBartModel(TFBartPretrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        outputs = self.model(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            encoder_outputs=inputs["encoder_outputs"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
...
@@ -1322,6 +1245,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     @add_end_docstrings(BART_GENERATION_EXAMPLE)
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
...
@@ -1352,17 +1276,28 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
         Returns:

         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        if labels is not None:
+            labels = tf.where(
+                labels == self.config.pad_token_id,
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
+                labels,
+            )
+            use_cache = False
+            if decoder_input_ids is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
             head_mask=head_mask,
             decoder_head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
...
@@ -1370,46 +1305,13 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
         )
-
-        if inputs["labels"] is not None:
-            inputs["labels"] = tf.where(
-                inputs["labels"] == self.config.pad_token_id,
-                tf.cast(tf.fill(shape_list(inputs["labels"]), -100), inputs["labels"].dtype),
-                inputs["labels"],
-            )
-            inputs["use_cache"] = False
-            if inputs["decoder_input_ids"] is None:
-                inputs["decoder_input_ids"] = shift_tokens_right(
-                    inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
-                )
-
-        outputs = self.model(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            encoder_outputs=inputs["encoder_outputs"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         lm_logits = self.model.shared(outputs[0], mode="linear")
         lm_logits = lm_logits + self.final_logits_bias
-        masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
+        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (lm_logits,) + outputs[1:]
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
         return TFSeq2SeqLMOutput(
...
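One behavior the refactor leaves untouched in TFBartForConditionalGeneration.call: labels equal to pad_token_id are rewritten to -100 before hf_compute_loss, so padding positions are ignored by the loss, and shift_tokens_right derives decoder_input_ids from the labels when none are supplied. A minimal sketch of the masking step, assuming eager TensorFlow (the helper name mask_pad_labels and the pad id value are illustrative, not from the source):

import tensorflow as tf


def mask_pad_labels(labels: tf.Tensor, pad_token_id: int) -> tf.Tensor:
    # Replace padding positions with -100 so the loss skips them, mirroring
    # the tf.where(...) call in TFBartForConditionalGeneration.call above.
    return tf.where(
        labels == pad_token_id,
        tf.cast(tf.fill(tf.shape(labels), -100), labels.dtype),
        labels,
    )


labels = tf.constant([[5, 17, 1, 1]])           # assume pad_token_id == 1
print(mask_pad_labels(labels, pad_token_id=1))  # -> [[5, 17, -100, -100]]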