Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3dc82427
Unverified
Commit
3dc82427
authored
Mar 27, 2022
by
Shamima
Committed by
GitHub
Mar 27, 2022
Browse files
TF: removed inputs_processing and replaced with decorator in lxmert (#16414)
parent
b320d87e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
122 deletions
+63
-122
src/transformers/models/lxmert/modeling_tf_lxmert.py
src/transformers/models/lxmert/modeling_tf_lxmert.py
+63
-122
No files found.
src/transformers/models/lxmert/modeling_tf_lxmert.py
View file @
3dc82427
...
...
@@ -23,7 +23,7 @@ from typing import Dict, Optional, Tuple
import
tensorflow
as
tf
from
...activations_tf
import
get_tf_activation
from
...modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
,
input_processing
,
keras_serializable
,
shape_list
from
...modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
,
keras_serializable
,
shape_list
,
unpack_inputs
from
...utils
import
(
ModelOutput
,
add_code_sample_docstrings
,
...
...
@@ -671,6 +671,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
def
_prune_heads
(
self
,
heads_to_prune
):
raise
NotImplementedError
@
unpack_inputs
def
call
(
self
,
input_ids
=
None
,
...
...
@@ -686,51 +687,33 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
training
=
False
,
**
kwargs
,
):
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
input_ids
=
input_ids
,
visual_feats
=
visual_feats
,
visual_pos
=
visual_pos
,
attention_mask
=
attention_mask
,
visual_attention_mask
=
visual_attention_mask
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
if
inputs
[
"
input_ids
"
]
is
not
None
and
inputs
[
"
inputs_embeds
"
]
is
not
None
:
if
input_ids
is
not
None
and
inputs_embeds
is
not
None
:
raise
ValueError
(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif
inputs
[
"
input_ids
"
]
is
not
None
:
input_shape
=
shape_list
(
inputs
[
"
input_ids
"
]
)
elif
inputs
[
"
inputs_embeds
"
]
is
not
None
:
input_shape
=
shape_list
(
inputs
[
"inputs
_embeds
"
]
)[:
-
1
]
elif
input_ids
is
not
None
:
input_shape
=
shape_list
(
input_ids
)
elif
inputs_embeds
is
not
None
:
input_shape
=
shape_list
(
inputs_embeds
)[:
-
1
]
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
if
inputs
[
"
visual_pos
"
]
is
None
or
inputs
[
"
visual_feats
"
]
is
None
:
if
visual_pos
is
None
or
visual_feats
is
None
:
raise
ValueError
(
"visual_feats and visual_pos cannot be `None` in LXMERT's `call` method."
)
if
inputs
[
"
attention_mask
"
]
is
None
:
inputs
[
"
attention_mask
"
]
=
tf
.
fill
(
input_shape
,
1
)
if
attention_mask
is
None
:
attention_mask
=
tf
.
fill
(
input_shape
,
1
)
if
inputs
[
"
token_type_ids
"
]
is
None
:
inputs
[
"
token_type_ids
"
]
=
tf
.
fill
(
input_shape
,
0
)
if
token_type_ids
is
None
:
token_type_ids
=
tf
.
fill
(
input_shape
,
0
)
# Positional Word Embeddings
embedding_output
=
self
.
embeddings
(
inputs
[
"input_ids"
],
inputs
[
"token_type_ids"
],
inputs
[
"inputs_embeds"
],
training
=
inputs
[
"training"
]
)
embedding_output
=
self
.
embeddings
(
input_ids
,
token_type_ids
,
inputs_embeds
,
training
)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask
=
tf
.
reshape
(
inputs
[
"
attention_mask
"
]
,
(
input_shape
[
0
],
1
,
1
,
input_shape
[
1
]))
extended_attention_mask
=
tf
.
reshape
(
attention_mask
,
(
input_shape
[
0
],
1
,
1
,
input_shape
[
1
]))
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
...
...
@@ -743,13 +726,9 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
ten_thousand_cst
=
tf
.
constant
(
-
10000.0
,
dtype
=
embedding_output
.
dtype
)
extended_attention_mask
=
tf
.
multiply
(
tf
.
subtract
(
one_cst
,
extended_attention_mask
),
ten_thousand_cst
)
if
inputs
[
"visual_attention_mask"
]
is
not
None
:
extended_visual_attention_mask
=
tf
.
reshape
(
inputs
[
"visual_attention_mask"
],
(
input_shape
[
0
],
1
,
1
,
input_shape
[
1
])
)
extended_visual_attention_mask
=
tf
.
expand_dims
(
tf
.
expand_dims
(
inputs
[
"visual_attention_mask"
],
axis
=
1
),
axis
=
1
)
if
visual_attention_mask
is
not
None
:
extended_visual_attention_mask
=
tf
.
reshape
(
visual_attention_mask
,
(
input_shape
[
0
],
1
,
1
,
input_shape
[
1
]))
extended_visual_attention_mask
=
tf
.
expand_dims
(
tf
.
expand_dims
(
visual_attention_mask
,
axis
=
1
),
axis
=
1
)
extended_visual_attention_mask
=
tf
.
cast
(
extended_visual_attention_mask
,
dtype
=
embedding_output
.
dtype
)
extended_visual_attention_mask
=
tf
.
multiply
(
...
...
@@ -762,18 +741,18 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
encoder_outputs
=
self
.
encoder
(
embedding_output
,
extended_attention_mask
,
inputs
[
"
visual_feats
"
]
,
inputs
[
"
visual_pos
"
]
,
visual_feats
,
visual_pos
,
extended_visual_attention_mask
,
output_attentions
=
inputs
[
"output_attentions"
]
,
training
=
inputs
[
"training"
]
,
output_attentions
,
training
,
)
visual_encoder_outputs
,
lang_encoder_outputs
=
encoder_outputs
[:
2
]
vision_hidden_states
=
visual_encoder_outputs
[
0
]
language_hidden_states
=
lang_encoder_outputs
[
0
]
all_attentions
=
()
if
inputs
[
"
output_attentions
"
]
:
if
output_attentions
:
language_attentions
=
lang_encoder_outputs
[
1
]
vision_attentions
=
visual_encoder_outputs
[
1
]
cross_encoder_attentions
=
encoder_outputs
[
2
]
...
...
@@ -783,24 +762,24 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
cross_encoder_attentions
,
)
hidden_states
=
(
language_hidden_states
,
vision_hidden_states
)
if
inputs
[
"
output_hidden_states
"
]
else
()
hidden_states
=
(
language_hidden_states
,
vision_hidden_states
)
if
output_hidden_states
else
()
visual_output
=
vision_hidden_states
[
-
1
]
lang_output
=
language_hidden_states
[
-
1
]
pooled_output
=
self
.
pooler
(
lang_output
)
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
return
(
lang_output
,
visual_output
,
pooled_output
)
+
hidden_states
+
all_attentions
return
TFLxmertModelOutput
(
pooled_output
=
pooled_output
,
language_output
=
lang_output
,
vision_output
=
visual_output
,
language_hidden_states
=
language_hidden_states
if
inputs
[
"
output_hidden_states
"
]
else
None
,
vision_hidden_states
=
vision_hidden_states
if
inputs
[
"
output_hidden_states
"
]
else
None
,
language_attentions
=
language_attentions
if
inputs
[
"
output_attentions
"
]
else
None
,
vision_attentions
=
vision_attentions
if
inputs
[
"
output_attentions
"
]
else
None
,
cross_encoder_attentions
=
cross_encoder_attentions
if
inputs
[
"
output_attentions
"
]
else
None
,
language_hidden_states
=
language_hidden_states
if
output_hidden_states
else
None
,
vision_hidden_states
=
vision_hidden_states
if
output_hidden_states
else
None
,
language_attentions
=
language_attentions
if
output_attentions
else
None
,
vision_attentions
=
vision_attentions
if
output_attentions
else
None
,
cross_encoder_attentions
=
cross_encoder_attentions
if
output_attentions
else
None
,
)
...
...
@@ -946,6 +925,7 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
super
().
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
lxmert
=
TFLxmertMainLayer
(
config
,
name
=
"lxmert"
)
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
LXMERT_INPUTS_DOCSTRING
)
@
add_code_sample_docstrings
(
processor_class
=
_TOKENIZER_FOR_DOC
,
...
...
@@ -968,34 +948,18 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
training
=
False
,
**
kwargs
,
):
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
input_ids
=
input_ids
,
visual_feats
=
visual_feats
,
visual_pos
=
visual_pos
,
attention_mask
=
attention_mask
,
visual_attention_mask
=
visual_attention_mask
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
outputs
=
self
.
lxmert
(
input_ids
=
inputs
[
"input_ids"
]
,
visual_feats
=
inputs
[
"visual_feats"
]
,
visual_pos
=
inputs
[
"visual_pos"
]
,
attention_mask
=
inputs
[
"attention_mask"
]
,
visual_attention_mask
=
inputs
[
"visual_attention_mask"
]
,
token_type_ids
=
inputs
[
"token_type_ids"
]
,
inputs_embeds
=
inputs
[
"inputs_embeds"
]
,
output_attentions
=
inputs
[
"output_attentions"
]
,
output_hidden_states
=
inputs
[
"output_hidden_states"
]
,
return_dict
=
inputs
[
"return_dict"
]
,
training
=
inputs
[
"training"
]
,
input_ids
,
visual_feats
,
visual_pos
,
attention_mask
,
visual_attention_mask
,
token_type_ids
,
inputs_embeds
,
output_attentions
,
output_hidden_states
,
return_dict
,
training
,
)
return
outputs
...
...
@@ -1298,6 +1262,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
warnings
.
warn
(
"The method get_prefix_bias_name is deprecated. Please use `get_bias` instead."
,
FutureWarning
)
return
self
.
name
+
"/"
+
self
.
cls
.
name
+
"/"
+
self
.
cls
.
predictions
.
name
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
LXMERT_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
TFLxmertForPreTrainingOutput
,
config_class
=
_CONFIG_FOR_DOC
)
def
call
(
...
...
@@ -1339,38 +1304,19 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
Returns:
"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
input_ids
=
input_ids
,
visual_feats
=
visual_feats
,
visual_pos
=
visual_pos
,
attention_mask
=
attention_mask
,
visual_attention_mask
=
visual_attention_mask
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
masked_lm_labels
=
masked_lm_labels
,
obj_labels
=
obj_labels
,
matched_label
=
matched_label
,
ans
=
ans
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
lxmert_output
=
self
.
lxmert
(
input_ids
=
inputs
[
"input_ids"
]
,
visual_feats
=
inputs
[
"visual_feats"
]
,
visual_pos
=
inputs
[
"visual_pos"
]
,
attention_mask
=
inputs
[
"attention_mask"
]
,
visual_attention_mask
=
inputs
[
"visual_attention_mask"
]
,
token_type_ids
=
inputs
[
"token_type_ids"
]
,
inputs_embeds
=
inputs
[
"inputs_embeds"
]
,
output_attentions
=
inputs
[
"output_attentions"
]
,
output_hidden_states
=
inputs
[
"output_hidden_states"
]
,
return_dict
=
inputs
[
"return_dict"
]
,
training
=
inputs
[
"training"
]
,
input_ids
,
visual_feats
,
visual_pos
,
attention_mask
,
visual_attention_mask
,
token_type_ids
,
inputs_embeds
,
output_attentions
,
output_hidden_states
,
return_dict
,
training
,
)
lang_output
,
visual_output
,
pooled_output
=
(
...
...
@@ -1386,34 +1332,29 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
total_loss
=
(
None
if
(
inputs
[
"masked_lm_labels"
]
is
None
and
inputs
[
"matched_label"
]
is
None
and
inputs
[
"obj_labels"
]
is
None
and
inputs
[
"ans"
]
is
None
)
if
(
masked_lm_labels
is
None
and
matched_label
is
None
and
obj_labels
is
None
and
ans
is
None
)
else
tf
.
constant
(
0.0
)
)
losses
=
()
if
inputs
[
"
masked_lm_labels
"
]
is
not
None
and
self
.
task_mask_lm
:
if
masked_lm_labels
is
not
None
and
self
.
task_mask_lm
:
masked_lm_loss
=
self
.
loss_fcts
[
"ce"
](
tf
.
reshape
(
inputs
[
"
masked_lm_labels
"
]
,
[
-
1
]),
tf
.
reshape
(
masked_lm_labels
,
[
-
1
]),
tf
.
reshape
(
lang_prediction_scores
,
[
-
1
,
self
.
config
.
vocab_size
]),
)
total_loss
+=
masked_lm_loss
losses
+=
(
masked_lm_loss
,)
if
inputs
[
"
matched_label
"
]
is
not
None
and
self
.
task_matched
:
if
matched_label
is
not
None
and
self
.
task_matched
:
matched_loss
=
self
.
loss_fcts
[
"ce"
](
tf
.
reshape
(
inputs
[
"
matched_label
"
]
,
[
-
1
]),
tf
.
reshape
(
matched_label
,
[
-
1
]),
tf
.
reshape
(
cross_relationship_score
,
[
-
1
,
2
]),
)
total_loss
+=
matched_loss
losses
+=
(
matched_loss
,)
if
inputs
[
"
obj_labels
"
]
is
not
None
and
self
.
task_obj_predict
:
if
obj_labels
is
not
None
and
self
.
task_obj_predict
:
total_visn_loss
=
0.0
visn_prediction_scores_dict
=
self
.
obj_predict_head
(
visual_output
)
for
key
,
key_info
in
self
.
visual_losses
.
items
():
label
,
mask_conf
=
inputs
[
"
obj_labels
"
]
[
key
]
label
,
mask_conf
=
obj_labels
[
key
]
output_dim
=
key_info
[
"num"
]
loss_fct_name
=
key_info
[
"loss"
]
label_shape
=
key_info
[
"shape"
]
...
...
@@ -1431,7 +1372,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
total_visn_loss
+=
visn_loss
losses
+=
(
visn_loss
,)
total_loss
+=
total_visn_loss
if
inputs
[
"
ans
"
]
is
not
None
and
self
.
task_qa
:
if
ans
is
not
None
and
self
.
task_qa
:
answer_loss
=
self
.
loss_fcts
[
"ce"
](
tf
.
reshape
(
ans
,
[
-
1
]),
tf
.
reshape
(
answer_score
,
[
-
1
,
self
.
num_qa_labels
])
)
...
...
@@ -1444,7 +1385,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
losses
+=
(
answer_loss
,)
# return total_loss, tf.stack(losses)[tf.new_axis, ...], answer_score.detach()
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
output
=
(
lang_prediction_scores
,
cross_relationship_score
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment