chenpangpang / transformers / Commits / 3e9d0f7f

Unverified commit 3e9d0f7f, authored Mar 12, 2022 by Omar Sanseviero, committed by GitHub on Mar 12, 2022

Change unpacking of TF Bart inputs (#16094)

parent 580dd87c
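This commit switches the TF Bart modules from the dictionary-based `input_processing` helper to the `@unpack_inputs` decorator, so each `call` body works with its named arguments directly instead of indexing an `inputs` dict. As a rough mental model (a hedged sketch only, not the actual `unpack_inputs` implementation from `transformers`), such a decorator can normalize a Keras-style dict call into plain keyword arguments before the wrapped method runs:

import functools

def unpack_inputs_sketch(func):
    """Illustrative only: spread a Keras-style input dict into keyword args."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # If the first positional argument is a dict bundling the inputs,
        # merge it into the keyword arguments and drop it from *args.
        if args and isinstance(args[0], dict):
            kwargs = {**args[0], **kwargs}
            args = args[1:]
        return func(self, *args, **kwargs)

    return wrapper

class ToyLayer:
    @unpack_inputs_sketch
    def call(self, input_ids=None, attention_mask=None, training=False):
        # The body sees plain arguments; no inputs["..."] indexing is needed.
        return input_ids, attention_mask, training

layer = ToyLayer()
assert layer.call({"input_ids": [1, 2], "training": True}) == ([1, 2], None, True)
assert layer.call(input_ids=[3]) == ([3], None, False)

The real decorator in `transformers` does more than this sketch; the diff below only relies on the decorated `call` receiving ordinary keyword arguments.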
Showing 1 changed file with 124 additions and 222 deletions

src/transformers/models/bart/modeling_tf_bart.py (+124, -222)
@@ -42,8 +42,8 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -660,6 +660,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -708,80 +709,67 @@ class TFBartEncoder(tf.keras.layers.Layer):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_shape)
-        hidden_states = inputs["inputs_embeds"] + embed_pos
+        hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)

         # check attention mask and invert
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(inputs["attention_mask"])
+            attention_mask = _expand_mask(attention_mask)
         else:
             attention_mask = None

-        encoder_states = () if inputs["output_hidden_states"] else None
-        all_attentions = () if inputs["output_attentions"] else None
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None

         # check if head_mask has a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        if inputs["head_mask"] is not None and tf.executing_eagerly():
+        if head_mask is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["head_mask"])[0],
+                shape_list(head_mask)[0],
                 len(self.layers),
-                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
             )

         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):

-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):  # skip the layer
+            if training and (dropout_probability < self.layerdrop):  # skip the layer
                 continue

             hidden_states, attn = encoder_layer(
                 hidden_states,
                 attention_mask,
-                inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
+                head_mask[idx] if head_mask is not None else None,
             )

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_attentions += (attn,)

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return TFBaseModelOutput(
             last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
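The head_mask check above is wrapped in `tf.executing_eagerly()` because, as the comment in the hunk notes, the `tf.debugging` asserts are not XLA-compatible and must be skipped outside eager mode. A minimal hedged illustration of that guard on its own (not the module code):

import tensorflow as tf

def check_head_mask(head_mask, num_layers):
    # Run the shape assert only in eager mode; inside tf.function/XLA the
    # guard is False and the check is skipped, mirroring the pattern above.
    if head_mask is not None and tf.executing_eagerly():
        tf.debugging.assert_equal(
            tf.shape(head_mask)[0],
            num_layers,
            message=f"The head_mask should be specified for {num_layers} layers.",
        )

check_head_mask(tf.ones((6, 16)), 6)  # passes eagerly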
@@ -822,6 +810,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -899,45 +888,25 @@ class TFBartDecoder(tf.keras.layers.Layer):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            head_mask=head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            inputs_embeds=inputs_embeds,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

-        past_key_values_length = (
-            shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
-        )
+        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        hidden_states = inputs["inputs_embeds"]
+        hidden_states = inputs_embeds

         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         if input_shape[-1] > 1:
@@ -947,72 +916,68 @@ class TFBartDecoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )

-        if inputs["attention_mask"] is not None:
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                inputs["attention_mask"], tgt_len=input_shape[-1]
-            )
+        if attention_mask is not None:
+            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])

-        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])

         hidden_states = self.layernorm_embedding(hidden_states + positions)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)

         # decoder layers
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        all_self_attns = () if inputs["output_attentions"] else None
-        all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
-        present_key_values = () if inputs["use_cache"] else None
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
+        present_key_values = () if use_cache else None

         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in ["head_mask", "cross_attn_head_mask"]:
-            if inputs[attn_mask] is not None and tf.executing_eagerly():
+        for attn_mask in [head_mask, cross_attn_head_mask]:
+            if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
-                    shape_list(inputs[attn_mask])[0],
+                    shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
+                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )

         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):
+            if training and (dropout_probability < self.layerdrop):
                 continue

-            past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
+            past_key_value = past_key_values[idx] if past_key_values is not None else None

             hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                encoder_hidden_states=inputs["encoder_hidden_states"],
-                encoder_attention_mask=inputs["encoder_attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-                cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
-                if inputs["cross_attn_head_mask"] is not None
-                else None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                 past_key_value=past_key_value,
             )

-            if inputs["use_cache"]:
+            if use_cache:
                 present_key_values += (present_key_value,)

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_self_attns += (layer_self_attn,)

-                if inputs["encoder_hidden_states"] is not None:
+                if encoder_hidden_states is not None:
                     all_cross_attns += (layer_cross_attn,)

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
             return TFBaseModelOutputWithPastAndCrossAttentions(
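Both the encoder and decoder loops keep the LayerDrop behaviour: during training, each layer is skipped with probability `self.layerdrop`, sampled via `random.uniform(0, 1)` (see https://arxiv.org/abs/1909.11556). A small standalone sketch of the idea, not tied to the BART modules:

import random

def run_layers(hidden, layers, layerdrop=0.1, training=True):
    # Randomly skip whole layers during training (LayerDrop); at inference
    # time every layer runs and the output is deterministic.
    for layer in layers:
        if training and random.uniform(0, 1) < layerdrop:
            continue
        hidden = layer(hidden)
    return hidden

print(run_layers(1.0, [lambda x: x + 1] * 4, layerdrop=0.5, training=True))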
@@ -1062,6 +1027,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1082,82 +1048,59 @@ class TFBartMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
-            inputs["use_cache"] = False
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            use_cache = False

-        inputs["output_hidden_states"] = (
-            inputs["output_hidden_states"]
-            if inputs["output_hidden_states"] is not None
-            else self.config.output_hidden_states
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )

-        if inputs["decoder_input_ids"] is None and inputs["input_ids"] is not None:
-            inputs["decoder_input_ids"] = shift_tokens_right(
-                inputs["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
+        if decoder_input_ids is None and input_ids is not None:
+            decoder_input_ids = shift_tokens_right(
+                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
             )

-        if inputs["encoder_outputs"] is None:
-            inputs["encoder_outputs"] = self.encoder(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                head_mask=inputs["head_mask"],
-                inputs_embeds=inputs["inputs_embeds"],
-                output_attentions=inputs["output_attentions"],
-                output_hidden_states=inputs["output_hidden_states"],
-                return_dict=inputs["return_dict"],
-                training=inputs["training"],
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
             )
         # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
-        elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
-            inputs["encoder_outputs"] = TFBaseModelOutput(
-                last_hidden_state=inputs["encoder_outputs"][0],
-                hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
-                attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
+        elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
+            encoder_outputs = TFBaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )
         # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
-        elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
+        elif not return_dict and not isinstance(encoder_outputs, tuple):
+            encoder_outputs = encoder_outputs.to_tuple()

         decoder_outputs = self.decoder(
-            inputs["decoder_input_ids"],
-            attention_mask=inputs["decoder_attention_mask"],
-            encoder_hidden_states=inputs["encoder_outputs"][0],
-            encoder_attention_mask=inputs["attention_mask"],
-            head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

-        if not inputs["return_dict"]:
-            return decoder_outputs + inputs["encoder_outputs"]
+        if not return_dict:
+            return decoder_outputs + encoder_outputs

         return TFSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
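When `decoder_input_ids` is not supplied, `TFBartMainLayer` derives it from `input_ids` with `shift_tokens_right` (and the LM head further down does the same from `labels`). A hedged sketch of what that shift amounts to for a BART-style decoder; the real helper lives elsewhere in `modeling_tf_bart.py` and is untouched by this commit:

import tensorflow as tf

def shift_tokens_right_sketch(input_ids, pad_token_id, decoder_start_token_id):
    # Prepend the decoder start token and drop the last position, so the
    # decoder predicts token t from the tokens before t (teacher forcing).
    # The real helper also replaces -100 label positions with pad_token_id.
    batch_size = tf.shape(input_ids)[0]
    start_tokens = tf.fill([batch_size, 1], tf.constant(decoder_start_token_id, dtype=input_ids.dtype))
    shifted = tf.concat([start_tokens, input_ids[:, :-1]], axis=-1)
    return tf.where(shifted == -100, tf.constant(pad_token_id, dtype=shifted.dtype), shifted)

ids = tf.constant([[5, 6, 7, 1]])
print(shift_tokens_right_sketch(ids, pad_token_id=1, decoder_start_token_id=2).numpy())  # [[2 5 6 7]]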
@@ -1165,9 +1108,9 @@ class TFBartMainLayer(tf.keras.layers.Layer):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
-            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
-            encoder_attentions=inputs["encoder_outputs"].attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
         )
@@ -1197,6 +1140,7 @@ class TFBartModel(TFBartPretrainedModel):
         output_type=TFSeq2SeqModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1217,9 +1161,8 @@ class TFBartModel(TFBartPretrainedModel):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -1236,26 +1179,6 @@ class TFBartModel(TFBartPretrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        outputs = self.model(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            encoder_outputs=inputs["encoder_outputs"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
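None of this changes the public model API: a TF Bart model is still called with keyword arguments or with the tokenizer's output dict, exactly as before the refactor. A short usage sketch, assuming the `facebook/bart-base` checkpoint can be downloaded:

from transformers import BartTokenizer, TFBartModel

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFBartModel.from_pretrained("facebook/bart-base")

batch = tokenizer("Hello, BART!", return_tensors="tf")
outputs = model(batch)  # dict-style call still works
outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])  # keyword call too
print(outputs.last_hidden_state.shape)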
@@ -1322,6 +1245,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     @add_end_docstrings(BART_GENERATION_EXAMPLE)
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1352,17 +1276,28 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
         Returns:
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        if labels is not None:
+            labels = tf.where(
+                labels == self.config.pad_token_id,
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
+                labels,
+            )
+            use_cache = False
+            if decoder_input_ids is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
             head_mask=head_mask,
             decoder_head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1370,46 +1305,13 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
         )
-
-        if inputs["labels"] is not None:
-            inputs["labels"] = tf.where(
-                inputs["labels"] == self.config.pad_token_id,
-                tf.cast(tf.fill(shape_list(inputs["labels"]), -100), inputs["labels"].dtype),
-                inputs["labels"],
-            )
-            inputs["use_cache"] = False
-            if inputs["decoder_input_ids"] is None:
-                inputs["decoder_input_ids"] = shift_tokens_right(
-                    inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
-                )
-
-        outputs = self.model(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            encoder_outputs=inputs["encoder_outputs"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         lm_logits = self.model.shared(outputs[0], mode="linear")
         lm_logits = lm_logits + self.final_logits_bias
-        masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
+        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (lm_logits,) + outputs[1:]
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
         return TFSeq2SeqLMOutput(
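In the conditional-generation head, label positions equal to `pad_token_id` are rewritten to -100 before the loss is computed; -100 is the value the Hugging Face loss helpers ignore, so padding never contributes to the cross-entropy. The masking step in isolation, as a hedged standalone sketch of the `tf.where` call seen in the diff:

import tensorflow as tf

def mask_pad_labels(labels, pad_token_id):
    # Replace padding positions with -100 so they are excluded from the loss.
    return tf.where(
        labels == pad_token_id,
        tf.cast(tf.fill(tf.shape(labels), -100), labels.dtype),
        labels,
    )

labels = tf.constant([[10, 11, 1, 1]])  # 1 = pad
print(mask_pad_labels(labels, pad_token_id=1).numpy())  # [[  10   11 -100 -100]]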