chenpangpang / transformers / Commits / 0d7322c1
Unverified commit 0d7322c1, authored Mar 15, 2022 by Kamal Raj, committed by GitHub on Mar 15, 2022
TF clearer model variable naming: pegasus (#16152)
parent cd4c5c90
Showing 1 changed file with 122 additions and 220 deletions
src/transformers/models/pegasus/modeling_tf_pegasus.py (+122, -220)
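This commit replaces the boilerplate input_processing(...) call at the top of each call method with the @unpack_inputs decorator, so the method bodies refer to input_ids, attention_mask, training, and so on directly instead of going through an inputs dict. A minimal sketch of that shape, using a hypothetical stand-in decorator and toy layer (not the real transformers implementation), is:

import functools

# Hypothetical stand-in for transformers' @unpack_inputs, only to illustrate the
# refactor applied in this diff: the decorator resolves the call arguments once,
# and the decorated body then uses plain local variables instead of the
# `inputs` dict that input_processing used to build.
def unpack_inputs_sketch(func):
    @functools.wraps(func)
    def wrapper(self, **kwargs):
        defaults = {"input_ids": None, "attention_mask": None, "training": False}
        defaults.update(kwargs)        # merge caller-supplied arguments
        return func(self, **defaults)  # pass them through as named arguments
    return wrapper


class ToyEncoder:
    @unpack_inputs_sketch
    def call(self, input_ids=None, attention_mask=None, training=False):
        # Before this commit the body read inputs["input_ids"], inputs["training"], ...;
        # after it, the unpacked variables are used directly.
        if attention_mask is None:
            return input_ids
        return [tok for tok, keep in zip(input_ids, attention_mask) if keep]


print(ToyEncoder().call(input_ids=[5, 7, 9], attention_mask=[1, 1, 0]))  # -> [5, 7]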
@@ -43,8 +43,8 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -575,7 +575,7 @@ PEGASUS_GENERATION_EXAMPLE = r"""
     >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="tf")

     >>> # Generate Summary
-    >>> summary_ids = model.generate(inputs["input_ids"])
+    >>> summary_ids = model.generate(input_ids)
     >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
     ```
 """
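For context, the updated docstring line calls model.generate(input_ids), which presumes input_ids has already been pulled out of the tokenizer output; the hunk does not show the surrounding lines of the example. A runnable version of the summarization snippet, assuming the google/pegasus-xsum checkpoint used elsewhere in the PEGASUS docs:

from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration

# Checkpoint name is an assumption; the hunk only shows the middle of the docstring example.
model = TFPegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")

ARTICLE_TO_SUMMARIZE = "PG&E stated it scheduled the blackouts in response to forecasts for high winds."

inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="tf")

# Generate Summary
summary_ids = model.generate(inputs["input_ids"])
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))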
@@ -693,6 +693,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -747,81 +748,68 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
                 Whether or not to use the model in training mode (some modules like dropout modules have different
                 behaviors between training and evaluation).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_shape)
-        hidden_states = inputs["inputs_embeds"] + embed_pos
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.dropout(hidden_states, training=training)

         # check attention mask and invert
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(inputs["attention_mask"])
+            attention_mask = _expand_mask(attention_mask)
         else:
             attention_mask = None

-        encoder_states = () if inputs["output_hidden_states"] else None
-        all_attentions = () if inputs["output_attentions"] else None
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None

         # check if head_mask has a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        if inputs["head_mask"] is not None and tf.executing_eagerly():
+        if head_mask is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["head_mask"])[0],
+                shape_list(head_mask)[0],
                 len(self.layers),
-                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
             )

         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):

-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):  # skip the layer
+            if training and (dropout_probability < self.layerdrop):  # skip the layer
                 continue

             hidden_states, attn = encoder_layer(
                 hidden_states,
                 attention_mask,
-                inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
+                head_mask[idx] if head_mask is not None else None,
             )

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_attentions += (attn,)

         hidden_states = self.layer_norm(hidden_states)

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return TFBaseModelOutput(
             last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
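The LayerDrop branch kept by this hunk (if training and (dropout_probability < self.layerdrop): continue) skips whole encoder layers at random during training only. A standalone sketch of that logic with a hypothetical drop rate and layer count:

import random

layerdrop = 0.1   # hypothetical; the model reads this from its config
num_layers = 12

def layers_run(training):
    # Mirrors the loop above: draw a uniform sample per layer and skip the
    # layer during training when the sample falls below the LayerDrop rate.
    ran = 0
    for _ in range(num_layers):
        dropout_probability = random.uniform(0, 1)
        if training and (dropout_probability < layerdrop):
            continue  # skip the layer
        ran += 1
    return ran

print(layers_run(training=True))   # usually fewer than 12 layers during training
print(layers_run(training=False))  # always 12 at inference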
@@ -862,6 +850,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -945,45 +934,25 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
                 Whether or not to use the model in training mode (some modules like dropout modules have different
                 behaviors between training and evaluation).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            head_mask=head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            inputs_embeds=inputs_embeds,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

-        past_key_values_length = (
-            shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
-        )
+        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        hidden_states = inputs["inputs_embeds"]
+        hidden_states = inputs_embeds

         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         if input_shape[-1] > 1:
@@ -993,72 +962,68 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )

-        if inputs["attention_mask"] is not None:
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                inputs["attention_mask"], tgt_len=input_shape[-1]
-            )
+        if attention_mask is not None:
+            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])

-        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])

-        hidden_states = self.dropout(hidden_states + positions, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states + positions, training=training)

         # decoder layers
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        all_self_attns = () if inputs["output_attentions"] else None
-        all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
-        present_key_values = () if inputs["use_cache"] else None
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
+        present_key_values = () if use_cache else None

         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in ["head_mask", "cross_attn_head_mask"]:
-            if inputs[attn_mask] is not None and tf.executing_eagerly():
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
+            if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
-                    shape_list(inputs[attn_mask])[0],
+                    shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
+                    message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )

         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             dropout_probability = random.uniform(0, 1)

-            if inputs["training"] and (dropout_probability < self.layerdrop):
+            if training and (dropout_probability < self.layerdrop):
                 continue

-            past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
+            past_key_value = past_key_values[idx] if past_key_values is not None else None

             hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                encoder_hidden_states=inputs["encoder_hidden_states"],
-                encoder_attention_mask=inputs["encoder_attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-                cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
-                if inputs["cross_attn_head_mask"] is not None
-                else None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                 past_key_value=past_key_value,
             )

-            if inputs["use_cache"]:
+            if use_cache:
                 present_key_values += (present_key_value,)

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_self_attns += (layer_self_attn,)

-                if inputs["encoder_hidden_states"] is not None:
+                if encoder_hidden_states is not None:
                     all_cross_attns += (layer_cross_attn,)

         hidden_states = self.layer_norm(hidden_states)

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
             return TFBaseModelOutputWithPastAndCrossAttentions(
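The decoder hunk above also rewrites the head-mask sanity check so the loop carries (name, tensor) pairs: the readable name feeds the error message while the local variable feeds the shape check. A standalone sketch of that loop with hypothetical mask values and a plain assert standing in for tf.debugging.assert_equal:

num_layers = 12
head_mask = [1.0] * num_layers          # hypothetical per-layer mask
cross_attn_head_mask = None             # not provided, so its check is skipped

for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
    if attn_mask is not None:
        assert len(attn_mask) == num_layers, (
            f"The {attn_mask_name} should be specified for {num_layers} layers, but it is for {len(attn_mask)}."
        )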
@@ -1105,6 +1070,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1125,77 +1091,54 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
-            inputs["use_cache"] = False
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            use_cache = False

-        inputs["output_hidden_states"] = (
-            inputs["output_hidden_states"]
-            if inputs["output_hidden_states"] is not None
-            else self.config.output_hidden_states
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )

-        if inputs["encoder_outputs"] is None:
-            inputs["encoder_outputs"] = self.encoder(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                head_mask=inputs["head_mask"],
-                inputs_embeds=inputs["inputs_embeds"],
-                output_attentions=inputs["output_attentions"],
-                output_hidden_states=inputs["output_hidden_states"],
-                return_dict=inputs["return_dict"],
-                training=inputs["training"],
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
             )
         # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
-        elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
-            inputs["encoder_outputs"] = TFBaseModelOutput(
-                last_hidden_state=inputs["encoder_outputs"][0],
-                hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
-                attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
+        elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
+            encoder_outputs = TFBaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )
         # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
-        elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
+        elif not return_dict and not isinstance(encoder_outputs, tuple):
+            encoder_outputs = encoder_outputs.to_tuple()

         decoder_outputs = self.decoder(
-            inputs["decoder_input_ids"],
-            attention_mask=inputs["decoder_attention_mask"],
-            encoder_hidden_states=inputs["encoder_outputs"][0],
-            encoder_attention_mask=inputs["attention_mask"],
-            head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

-        if not inputs["return_dict"]:
-            return decoder_outputs + inputs["encoder_outputs"]
+        if not return_dict:
+            return decoder_outputs + encoder_outputs

         return TFSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
@@ -1203,9 +1146,9 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
-            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
-            encoder_attentions=inputs["encoder_outputs"].attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
         )
@@ -1225,6 +1168,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
     def get_decoder(self):
         return self.model.decoder

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1252,9 +1196,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -1271,26 +1214,6 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        outputs = self.model(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            encoder_outputs=inputs["encoder_outputs"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
@@ -1353,6 +1276,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
     def set_bias(self, value):
         self.final_logits_bias = value["final_logits_bias"]

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
@@ -1386,17 +1310,28 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
         Returns:

         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        if labels is not None:
+            labels = tf.where(
+                labels == self.config.pad_token_id,
+                tf.fill(shape_list(labels), -100),
+                labels,
+            )
+            use_cache = False
+            if decoder_input_ids is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
             head_mask=head_mask,
             decoder_head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
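The labels branch added in this hunk masks pad positions to -100 so hf_compute_loss ignores them, and derives decoder_input_ids from the labels via shift_tokens_right when none were passed. A small standalone illustration of just the masking step, with hypothetical token ids and pad id:

import tensorflow as tf

pad_token_id = 0  # hypothetical; the model reads self.config.pad_token_id
labels = tf.constant([[5, 7, 9, 0, 0]])

masked_labels = tf.where(
    labels == pad_token_id,
    tf.fill(tf.shape(labels), -100),  # the model uses shape_list(labels) here
    labels,
)
print(masked_labels.numpy())  # [[   5    7    9 -100 -100]]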
@@ -1404,46 +1339,13 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
         )
-
-        if inputs["labels"] is not None:
-            inputs["labels"] = tf.where(
-                inputs["labels"] == self.config.pad_token_id,
-                tf.fill(shape_list(inputs["labels"]), -100),
-                inputs["labels"],
-            )
-            inputs["use_cache"] = False
-            if inputs["decoder_input_ids"] is None:
-                inputs["decoder_input_ids"] = shift_tokens_right(
-                    inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
-                )
-
-        outputs = self.model(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            encoder_outputs=inputs["encoder_outputs"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         lm_logits = self.model.shared(outputs[0], mode="linear")
         lm_logits = lm_logits + self.final_logits_bias
-        masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
+        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (lm_logits,) + outputs[1:]
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
         return TFSeq2SeqLMOutput(
...