Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
64582361
Unverified
Commit
64582361
authored
Mar 14, 2022
by
Kamal Raj
Committed by
GitHub
Mar 14, 2022
Browse files
TF Electra - clearer model variable naming (#16143)
parent
37793259
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
81 additions
and
216 deletions
+81
-216
src/transformers/models/electra/modeling_tf_electra.py
src/transformers/models/electra/modeling_tf_electra.py
+81
-216
No files found.
src/transformers/models/electra/modeling_tf_electra.py
View file @
64582361
...
...
@@ -50,8 +50,8 @@ from ...modeling_tf_utils import (
TFSequenceSummary
,
TFTokenClassificationLoss
,
get_initializer
,
input_processing
,
keras_serializable
,
unpack_inputs
,
)
from
...tf_utils
import
shape_list
from
...utils
import
logging
...
...
@@ -702,6 +702,7 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
return
head_mask
@
unpack_inputs
def
call
(
self
,
input_ids
:
Optional
[
TFModelInputType
]
=
None
,
...
...
@@ -720,77 +721,55 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
training
:
Optional
[
bool
]
=
False
,
**
kwargs
,
)
->
Union
[
TFBaseModelOutputWithPastAndCrossAttentions
,
Tuple
[
tf
.
Tensor
]]:
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
,
head_mask
=
head_mask
,
inputs_embeds
=
inputs_embeds
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_attention_mask
=
encoder_attention_mask
,
past_key_values
=
past_key_values
,
use_cache
=
use_cache
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
if
not
self
.
config
.
is_decoder
:
inputs
[
"
use_cache
"
]
=
False
use_cache
=
False
if
inputs
[
"
input_ids
"
]
is
not
None
and
inputs
[
"
inputs_embeds
"
]
is
not
None
:
if
input_ids
is
not
None
and
inputs_embeds
is
not
None
:
raise
ValueError
(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif
inputs
[
"
input_ids
"
]
is
not
None
:
input_shape
=
shape_list
(
inputs
[
"
input_ids
"
]
)
elif
inputs
[
"
inputs_embeds
"
]
is
not
None
:
input_shape
=
shape_list
(
inputs
[
"inputs
_embeds
"
]
)[:
-
1
]
elif
input_ids
is
not
None
:
input_shape
=
shape_list
(
input_ids
)
elif
inputs_embeds
is
not
None
:
input_shape
=
shape_list
(
inputs_embeds
)[:
-
1
]
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
batch_size
,
seq_length
=
input_shape
if
inputs
[
"
past_key_values
"
]
is
None
:
if
past_key_values
is
None
:
past_key_values_length
=
0
inputs
[
"
past_key_values
"
]
=
[
None
]
*
len
(
self
.
encoder
.
layer
)
past_key_values
=
[
None
]
*
len
(
self
.
encoder
.
layer
)
else
:
past_key_values_length
=
shape_list
(
inputs
[
"
past_key_values
"
]
[
0
][
0
])[
-
2
]
past_key_values_length
=
shape_list
(
past_key_values
[
0
][
0
])[
-
2
]
if
inputs
[
"
attention_mask
"
]
is
None
:
inputs
[
"
attention_mask
"
]
=
tf
.
fill
(
dims
=
(
batch_size
,
seq_length
+
past_key_values_length
),
value
=
1
)
if
attention_mask
is
None
:
attention_mask
=
tf
.
fill
(
dims
=
(
batch_size
,
seq_length
+
past_key_values_length
),
value
=
1
)
if
inputs
[
"
token_type_ids
"
]
is
None
:
inputs
[
"
token_type_ids
"
]
=
tf
.
fill
(
dims
=
input_shape
,
value
=
0
)
if
token_type_ids
is
None
:
token_type_ids
=
tf
.
fill
(
dims
=
input_shape
,
value
=
0
)
hidden_states
=
self
.
embeddings
(
input_ids
=
inputs
[
"
input_ids
"
]
,
position_ids
=
inputs
[
"
position_ids
"
]
,
token_type_ids
=
inputs
[
"
token_type_ids
"
]
,
inputs_embeds
=
inputs
[
"inputs
_embeds
"
]
,
input_ids
=
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
past_key_values_length
=
past_key_values_length
,
training
=
inputs
[
"
training
"
]
,
training
=
training
,
)
extended_attention_mask
=
self
.
get_extended_attention_mask
(
inputs
[
"
attention_mask
"
]
,
input_shape
,
hidden_states
.
dtype
,
past_key_values_length
attention_mask
,
input_shape
,
hidden_states
.
dtype
,
past_key_values_length
)
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
if
self
.
is_decoder
and
inputs
[
"
encoder_attention_mask
"
]
is
not
None
:
if
self
.
is_decoder
and
encoder_attention_mask
is
not
None
:
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
inputs
[
"encoder_attention_mask"
]
=
tf
.
cast
(
inputs
[
"encoder_attention_mask"
],
dtype
=
extended_attention_mask
.
dtype
)
num_dims_encoder_attention_mask
=
len
(
shape_list
(
inputs
[
"encoder_attention_mask"
]))
encoder_attention_mask
=
tf
.
cast
(
encoder_attention_mask
,
dtype
=
extended_attention_mask
.
dtype
)
num_dims_encoder_attention_mask
=
len
(
shape_list
(
encoder_attention_mask
))
if
num_dims_encoder_attention_mask
==
3
:
encoder_extended_attention_mask
=
inputs
[
"
encoder_attention_mask
"
]
[:,
None
,
:,
:]
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
:,
:]
if
num_dims_encoder_attention_mask
==
2
:
encoder_extended_attention_mask
=
inputs
[
"
encoder_attention_mask
"
]
[:,
None
,
None
,
:]
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
None
,
:]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
...
...
@@ -801,23 +780,23 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
else
:
encoder_extended_attention_mask
=
None
inputs
[
"
head_mask
"
]
=
self
.
get_head_mask
(
inputs
[
"
head_mask
"
]
)
head_mask
=
self
.
get_head_mask
(
head_mask
)
if
hasattr
(
self
,
"embeddings_project"
):
hidden_states
=
self
.
embeddings_project
(
hidden_states
,
training
=
inputs
[
"
training
"
]
)
hidden_states
=
self
.
embeddings_project
(
hidden_states
,
training
=
training
)
hidden_states
=
self
.
encoder
(
hidden_states
=
hidden_states
,
attention_mask
=
extended_attention_mask
,
head_mask
=
inputs
[
"
head_mask
"
]
,
encoder_hidden_states
=
inputs
[
"
encoder_hidden_states
"
]
,
head_mask
=
head_mask
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_attention_mask
=
encoder_extended_attention_mask
,
past_key_values
=
inputs
[
"
past_key_values
"
]
,
use_cache
=
inputs
[
"
use_cache
"
]
,
output_attentions
=
inputs
[
"
output_attentions
"
]
,
output_hidden_states
=
inputs
[
"
output_hidden_states
"
]
,
return_dict
=
inputs
[
"
return_dict
"
]
,
training
=
inputs
[
"
training
"
]
,
past_key_values
=
past_key_values
,
use_cache
=
use_cache
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
)
return
hidden_states
...
...
@@ -950,6 +929,7 @@ class TFElectraModel(TFElectraPreTrainedModel):
self
.
electra
=
TFElectraMainLayer
(
config
,
name
=
"electra"
)
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
ELECTRA_INPUTS_DOCSTRING
.
format
(
"batch_size, sequence_length"
))
@
add_code_sample_docstrings
(
processor_class
=
_TOKENIZER_FOR_DOC
,
...
...
@@ -995,40 +975,21 @@ class TFElectraModel(TFElectraPreTrainedModel):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`). Set to `False` during training, `True` during generation
"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
outputs
=
self
.
electra
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
,
head_mask
=
head_mask
,
inputs_embeds
=
inputs_embeds
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_attention_mask
=
encoder_attention_mask
,
past_key_values
=
past_key_values
,
use_cache
=
use_cache
,
inputs_embeds
=
inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
outputs
=
self
.
electra
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
position_ids
=
inputs
[
"position_ids"
],
head_mask
=
inputs
[
"head_mask"
],
encoder_hidden_states
=
inputs
[
"encoder_hidden_states"
],
encoder_attention_mask
=
inputs
[
"encoder_attention_mask"
],
past_key_values
=
inputs
[
"past_key_values"
],
use_cache
=
inputs
[
"use_cache"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
return
outputs
...
...
@@ -1067,6 +1028,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
self
.
electra
=
TFElectraMainLayer
(
config
,
name
=
"electra"
)
self
.
discriminator_predictions
=
TFElectraDiscriminatorPredictions
(
config
,
name
=
"discriminator_predictions"
)
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
ELECTRA_INPUTS_DOCSTRING
.
format
(
"batch_size, sequence_length"
))
@
replace_return_docstrings
(
output_type
=
TFElectraForPreTrainingOutput
,
config_class
=
_CONFIG_FOR_DOC
)
def
call
(
...
...
@@ -1098,9 +1060,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
>>> outputs = model(input_ids)
>>> scores = outputs[0]
```"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
discriminator_hidden_states
=
self
.
electra
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
...
...
@@ -1111,24 +1071,11 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
discriminator_hidden_states
=
self
.
electra
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
position_ids
=
inputs
[
"position_ids"
],
head_mask
=
inputs
[
"head_mask"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
discriminator_sequence_output
=
discriminator_hidden_states
[
0
]
logits
=
self
.
discriminator_predictions
(
discriminator_sequence_output
)
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
return
(
logits
,)
+
discriminator_hidden_states
[
1
:]
return
TFElectraForPreTrainingOutput
(
...
...
@@ -1212,6 +1159,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
warnings
.
warn
(
"The method get_prefix_bias_name is deprecated. Please use `get_bias` instead."
,
FutureWarning
)
return
self
.
name
+
"/"
+
self
.
generator_lm_head
.
name
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
ELECTRA_INPUTS_DOCSTRING
.
format
(
"batch_size, sequence_length"
))
@
add_code_sample_docstrings
(
processor_class
=
_TOKENIZER_FOR_DOC
,
...
...
@@ -1240,9 +1188,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
generator_hidden_states
=
self
.
electra
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
...
...
@@ -1252,28 +1198,14 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
labels
=
labels
,
training
=
training
,
kwargs_call
=
kwargs
,
)
generator_hidden_states
=
self
.
electra
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
position_ids
=
inputs
[
"position_ids"
],
head_mask
=
inputs
[
"head_mask"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
generator_sequence_output
=
generator_hidden_states
[
0
]
prediction_scores
=
self
.
generator_predictions
(
generator_sequence_output
,
training
=
inputs
[
"
training
"
]
)
prediction_scores
=
self
.
generator_lm_head
(
prediction_scores
,
training
=
inputs
[
"
training
"
]
)
loss
=
None
if
inputs
[
"
labels
"
]
is
None
else
self
.
hf_compute_loss
(
inputs
[
"
labels
"
]
,
prediction_scores
)
prediction_scores
=
self
.
generator_predictions
(
generator_sequence_output
,
training
=
training
)
prediction_scores
=
self
.
generator_lm_head
(
prediction_scores
,
training
=
training
)
loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
prediction_scores
)
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
output
=
(
prediction_scores
,)
+
generator_hidden_states
[
1
:]
return
((
loss
,)
+
output
)
if
loss
is
not
None
else
output
...
...
@@ -1337,6 +1269,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
self
.
electra
=
TFElectraMainLayer
(
config
,
name
=
"electra"
)
self
.
classifier
=
TFElectraClassificationHead
(
config
,
name
=
"classifier"
)
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
ELECTRA_INPUTS_DOCSTRING
.
format
(
"batch_size, sequence_length"
))
@
add_code_sample_docstrings
(
processor_class
=
_TOKENIZER_FOR_DOC
,
...
...
@@ -1365,9 +1298,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
outputs
=
self
.
electra
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
...
...
@@ -1377,26 +1308,12 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
labels
=
labels
,
training
=
training
,
kwargs_call
=
kwargs
,
)
outputs
=
self
.
electra
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
position_ids
=
inputs
[
"position_ids"
],
head_mask
=
inputs
[
"head_mask"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
logits
=
self
.
classifier
(
outputs
[
0
])
loss
=
None
if
inputs
[
"
labels
"
]
is
None
else
self
.
hf_compute_loss
(
inputs
[
"
labels
"
]
,
logits
)
loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
logits
)
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
output
=
(
logits
,)
+
outputs
[
1
:]
return
((
loss
,)
+
output
)
if
loss
is
not
None
else
output
...
...
@@ -1445,6 +1362,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
"""
return
{
"input_ids"
:
tf
.
constant
(
MULTIPLE_CHOICE_DUMMY_INPUTS
)}
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
ELECTRA_INPUTS_DOCSTRING
.
format
(
"batch_size, num_choices, sequence_length"
))
@
add_code_sample_docstrings
(
processor_class
=
_TOKENIZER_FOR_DOC
,
...
...
@@ -1472,43 +1390,21 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
,
head_mask
=
head_mask
,
inputs_embeds
=
inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
labels
=
labels
,
training
=
training
,
kwargs_call
=
kwargs
,
)
if
inputs
[
"
input_ids
"
]
is
not
None
:
num_choices
=
shape_list
(
inputs
[
"
input_ids
"
]
)[
1
]
seq_length
=
shape_list
(
inputs
[
"
input_ids
"
]
)[
2
]
if
input_ids
is
not
None
:
num_choices
=
shape_list
(
input_ids
)[
1
]
seq_length
=
shape_list
(
input_ids
)[
2
]
else
:
num_choices
=
shape_list
(
inputs
[
"inputs
_embeds
"
]
)[
1
]
seq_length
=
shape_list
(
inputs
[
"inputs
_embeds
"
]
)[
2
]
num_choices
=
shape_list
(
inputs_embeds
)[
1
]
seq_length
=
shape_list
(
inputs_embeds
)[
2
]
flat_input_ids
=
tf
.
reshape
(
inputs
[
"input_ids"
],
(
-
1
,
seq_length
))
if
inputs
[
"input_ids"
]
is
not
None
else
None
flat_attention_mask
=
(
tf
.
reshape
(
inputs
[
"attention_mask"
],
(
-
1
,
seq_length
))
if
inputs
[
"attention_mask"
]
is
not
None
else
None
)
flat_token_type_ids
=
(
tf
.
reshape
(
inputs
[
"token_type_ids"
],
(
-
1
,
seq_length
))
if
inputs
[
"token_type_ids"
]
is
not
None
else
None
)
flat_position_ids
=
(
tf
.
reshape
(
inputs
[
"position_ids"
],
(
-
1
,
seq_length
))
if
inputs
[
"position_ids"
]
is
not
None
else
None
)
flat_input_ids
=
tf
.
reshape
(
input_ids
,
(
-
1
,
seq_length
))
if
input_ids
is
not
None
else
None
flat_attention_mask
=
tf
.
reshape
(
attention_mask
,
(
-
1
,
seq_length
))
if
attention_mask
is
not
None
else
None
flat_token_type_ids
=
tf
.
reshape
(
token_type_ids
,
(
-
1
,
seq_length
))
if
token_type_ids
is
not
None
else
None
flat_position_ids
=
tf
.
reshape
(
position_ids
,
(
-
1
,
seq_length
))
if
position_ids
is
not
None
else
None
flat_inputs_embeds
=
(
tf
.
reshape
(
inputs
[
"inputs
_embeds
"
]
,
(
-
1
,
seq_length
,
shape_list
(
inputs
[
"inputs
_embeds
"
]
)[
3
]))
if
inputs
[
"
inputs_embeds
"
]
is
not
None
tf
.
reshape
(
inputs_embeds
,
(
-
1
,
seq_length
,
shape_list
(
inputs_embeds
)[
3
]))
if
inputs_embeds
is
not
None
else
None
)
outputs
=
self
.
electra
(
...
...
@@ -1516,19 +1412,19 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
attention_mask
=
flat_attention_mask
,
token_type_ids
=
flat_token_type_ids
,
position_ids
=
flat_position_ids
,
head_mask
=
inputs
[
"
head_mask
"
]
,
head_mask
=
head_mask
,
inputs_embeds
=
flat_inputs_embeds
,
output_attentions
=
inputs
[
"
output_attentions
"
]
,
output_hidden_states
=
inputs
[
"
output_hidden_states
"
]
,
return_dict
=
inputs
[
"
return_dict
"
]
,
training
=
inputs
[
"
training
"
]
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
)
logits
=
self
.
sequence_summary
(
outputs
[
0
])
logits
=
self
.
classifier
(
logits
)
reshaped_logits
=
tf
.
reshape
(
logits
,
(
-
1
,
num_choices
))
loss
=
None
if
inputs
[
"
labels
"
]
is
None
else
self
.
hf_compute_loss
(
inputs
[
"
labels
"
]
,
reshaped_logits
)
loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
reshaped_logits
)
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
output
=
(
reshaped_logits
,)
+
outputs
[
1
:]
return
((
loss
,)
+
output
)
if
loss
is
not
None
else
output
...
...
@@ -1584,6 +1480,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
config
.
num_labels
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"classifier"
)
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
ELECTRA_INPUTS_DOCSTRING
.
format
(
"batch_size, sequence_length"
))
@
add_code_sample_docstrings
(
processor_class
=
_TOKENIZER_FOR_DOC
,
...
...
@@ -1610,9 +1507,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
discriminator_hidden_states
=
self
.
electra
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
...
...
@@ -1622,28 +1517,14 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
labels
=
labels
,
training
=
training
,
kwargs_call
=
kwargs
,
)
discriminator_hidden_states
=
self
.
electra
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
position_ids
=
inputs
[
"position_ids"
],
head_mask
=
inputs
[
"head_mask"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
discriminator_sequence_output
=
discriminator_hidden_states
[
0
]
discriminator_sequence_output
=
self
.
dropout
(
discriminator_sequence_output
)
logits
=
self
.
classifier
(
discriminator_sequence_output
)
loss
=
None
if
inputs
[
"
labels
"
]
is
None
else
self
.
hf_compute_loss
(
inputs
[
"
labels
"
]
,
logits
)
loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
logits
)
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
output
=
(
logits
,)
+
discriminator_hidden_states
[
1
:]
return
((
loss
,)
+
output
)
if
loss
is
not
None
else
output
...
...
@@ -1680,6 +1561,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
config
.
num_labels
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"qa_outputs"
)
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
ELECTRA_INPUTS_DOCSTRING
.
format
(
"batch_size, sequence_length"
))
@
add_code_sample_docstrings
(
processor_class
=
_TOKENIZER_FOR_DOC
,
...
...
@@ -1713,9 +1595,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
discriminator_hidden_states
=
self
.
electra
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
...
...
@@ -1725,22 +1605,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
start_positions
=
start_positions
,
end_positions
=
end_positions
,
training
=
training
,
kwargs_call
=
kwargs
,
)
discriminator_hidden_states
=
self
.
electra
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
position_ids
=
inputs
[
"position_ids"
],
head_mask
=
inputs
[
"head_mask"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
discriminator_sequence_output
=
discriminator_hidden_states
[
0
]
logits
=
self
.
qa_outputs
(
discriminator_sequence_output
)
...
...
@@ -1749,12 +1614,12 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
end_logits
=
tf
.
squeeze
(
end_logits
,
axis
=-
1
)
loss
=
None
if
inputs
[
"
start_positions
"
]
is
not
None
and
inputs
[
"
end_positions
"
]
is
not
None
:
labels
=
{
"start_position"
:
inputs
[
"
start_positions
"
]
}
labels
[
"end_position"
]
=
inputs
[
"
end_positions
"
]
if
start_positions
is
not
None
and
end_positions
is
not
None
:
labels
=
{
"start_position"
:
start_positions
}
labels
[
"end_position"
]
=
end_positions
loss
=
self
.
hf_compute_loss
(
labels
,
(
start_logits
,
end_logits
))
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
output
=
(
start_logits
,
end_logits
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment