Project: chenpangpang / transformers
Commit 015de6f0 (unverified), authored Mar 15, 2022 by Kamal Raj, committed by GitHub on Mar 15, 2022. Parent: a23a7c0c

TF clearer model variable naming: xlnet (#16150)

Showing 1 changed file with 108 additions and 257 deletions:

src/transformers/models/xlnet/modeling_tf_xlnet.py (+108, −257)
@@ -42,8 +42,8 @@ from ...modeling_tf_utils import (
     TFSharedEmbeddings,
     TFTokenClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -578,6 +578,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         return pos_emb

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
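The decorator added here replaces the per-method `input_processing` boilerplate removed in the hunks below. As a rough illustration of the idea only (a minimal sketch, not the real `unpack_inputs` from `modeling_tf_utils`, which also routes through the library's input normalization):

```python
import functools
import inspect


def unpack_inputs_sketch(func):
    # Minimal sketch: bind the call-time arguments against the signature and
    # fill in defaults, so the body of `call` can use plain local names
    # (`input_ids`, `attention_mask`, ...) instead of an `inputs` dict.
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()  # every declared parameter now carries a value
        return func(*bound.args, **bound.kwargs)

    return wrapper


class Demo:
    @unpack_inputs_sketch
    def call(self, input_ids=None, training=False):
        return input_ids, training


assert Demo().call(input_ids=[1, 2]) == ([1, 2], False)
```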
@@ -596,63 +597,34 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            mems=mems,
-            perm_mask=perm_mask,
-            target_mapping=target_mapping,
-            token_type_ids=token_type_ids,
-            input_mask=input_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_mems=use_mems,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if training and inputs["use_mems"] is None:
-            inputs["use_mems"] = self.use_mems_train
+        if training and use_mems is None:
+            use_mems = self.use_mems_train
         else:
-            inputs["use_mems"] = self.use_mems_eval
+            use_mems = self.use_mems_eval

         # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
         # but we want a unified interface in the library with the batch size on the first dimension
         # so we move here the first dimension (batch) to the end
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
-            qlen, bsz = shape_list(inputs["input_ids"])[:2]
-        elif inputs["inputs_embeds"] is not None:
-            inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
-            qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
+        elif input_ids is not None:
+            input_ids = tf.transpose(input_ids, perm=(1, 0))
+            qlen, bsz = shape_list(input_ids)[:2]
+        elif inputs_embeds is not None:
+            inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
+            qlen, bsz = shape_list(inputs_embeds)[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        inputs["token_type_ids"] = (
-            tf.transpose(inputs["token_type_ids"], perm=(1, 0)) if inputs["token_type_ids"] is not None else None
-        )
-        inputs["input_mask"] = (
-            tf.transpose(inputs["input_mask"], perm=(1, 0)) if inputs["input_mask"] is not None else None
-        )
-        inputs["attention_mask"] = (
-            tf.transpose(inputs["attention_mask"], perm=(1, 0)) if inputs["attention_mask"] is not None else None
-        )
-        inputs["perm_mask"] = (
-            tf.transpose(inputs["perm_mask"], perm=(1, 2, 0)) if inputs["perm_mask"] is not None else None
-        )
-        inputs["target_mapping"] = (
-            tf.transpose(inputs["target_mapping"], perm=(1, 2, 0)) if inputs["target_mapping"] is not None else None
-        )
+        token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
+        input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
+        attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
+        perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
+        target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None

-        mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
+        mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
         klen = mlen + qlen

         # Attention mask
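The transposes in this hunk are unchanged in substance; the refactor only drops the dict indirection. A toy check of the layout change, using plain `tf.Tensor.shape` in place of the internal `shape_list` helper:

```python
import tensorflow as tf

# The library API is batch-major [bsz, qlen]; the original XLNet code is
# time-major [qlen, bsz], so the first thing call() does is transpose.
input_ids = tf.constant([[11, 12, 13], [21, 22, 23]])  # [bsz=2, qlen=3]
input_ids = tf.transpose(input_ids, perm=(1, 0))       # -> [qlen=3, bsz=2]
qlen, bsz = input_ids.shape[0], input_ids.shape[1]
assert (qlen, bsz) == (3, 2)
```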
@@ -666,19 +638,19 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             raise ValueError(f"Unsupported attention type: {self.attn_type}")

         # data mask: input mask & perm mask
-        assert inputs["input_mask"] is None or inputs["attention_mask"] is None, (
+        assert input_mask is None or attention_mask is None, (
             "You can only use one of input_mask (uses 1 for padding) "
             "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
         )
-        if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
+        if input_mask is None and attention_mask is not None:
             one_cst = tf.constant(1.0)
-            inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
-        if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
-            data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
-        elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
-            data_mask = inputs["input_mask"][None]
-        elif inputs["input_mask"] is None and inputs["perm_mask"] is not None:
-            data_mask = inputs["perm_mask"]
+        if input_mask is None and attention_mask is not None:
+            data_mask = input_mask[None] + perm_mask
+        elif input_mask is not None and perm_mask is None:
+            data_mask = input_mask[None]
+        elif input_mask is None and perm_mask is not None:
+            data_mask = perm_mask
         else:
             data_mask = None
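The `1.0 - tf.cast(...)` line converts between the two padding conventions named in the assert. A minimal demonstration:

```python
import tensorflow as tf

# attention_mask uses 1 for real tokens / 0 for padding (BERT convention);
# input_mask uses the opposite (1 for padding), hence the subtraction.
attention_mask = tf.constant([[1, 1, 1, 0]])
one_cst = tf.constant(1.0)
input_mask = 1.0 - tf.cast(attention_mask, dtype=one_cst.dtype)
print(input_mask.numpy())  # [[0. 0. 0. 1.]]
```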
@@ -704,33 +676,33 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             non_tgt_mask = None

         # Word embeddings and prepare h & g hidden states
-        if inputs["inputs_embeds"] is not None:
-            word_emb_k = inputs["inputs_embeds"]
+        if inputs_embeds is not None:
+            word_emb_k = inputs_embeds
         else:
-            word_emb_k = self.word_embedding(inputs["input_ids"])
-        output_h = self.dropout(word_emb_k, training=inputs["training"])
-        if inputs["target_mapping"] is not None:
-            word_emb_q = tf.tile(self.mask_emb, [shape_list(inputs["target_mapping"])[0], bsz, 1])
+            word_emb_k = self.word_embedding(input_ids)
+        output_h = self.dropout(word_emb_k, training=training)
+        if target_mapping is not None:
+            word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
             # else:  # We removed the inp_q input which was same as target mapping
             #     inp_q_ext = inp_q[:, :, None]
             #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
-            output_g = self.dropout(word_emb_q, training=inputs["training"])
+            output_g = self.dropout(word_emb_q, training=training)
         else:
             output_g = None

         # Segment embedding
-        if inputs["token_type_ids"] is not None:
+        if token_type_ids is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             if mlen > 0:
-                mem_pad = tf.zeros([mlen, bsz], dtype=inputs["token_type_ids"].dtype)
-                cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
+                mem_pad = tf.zeros([mlen, bsz], dtype=token_type_ids.dtype)
+                cat_ids = tf.concat([mem_pad, token_type_ids], 0)
             else:
-                cat_ids = inputs["token_type_ids"]
+                cat_ids = token_type_ids

             # `1` indicates not in the same segment [qlen x klen x bsz]
             seg_mat = tf.cast(
-                tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])),
-                dtype=inputs["token_type_ids"].dtype,
+                tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
+                dtype=token_type_ids.dtype,
             )
             seg_mat = tf.one_hot(seg_mat, 2)
         else:
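The segment-matrix logic is unchanged apart from the naming. A toy run of the same computation, with `mlen == 0` and the batch axis folded away for readability:

```python
import tensorflow as tf

# seg_mat[i, j] is 1 when query position i and key position j come from
# different segments; tf.one_hot then turns that flag into a 2-way one-hot.
token_type_ids = tf.constant([0, 0, 1, 1])  # [qlen] toy segment ids
cat_ids = token_type_ids                    # mlen == 0, so no mem padding
seg_mat = tf.cast(
    tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
    dtype=token_type_ids.dtype,
)
seg_mat = tf.one_hot(seg_mat, 2)            # [qlen, klen, 2]
```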
@@ -738,29 +710,29 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         # Positional encoding
         pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
-        pos_emb = self.dropout(pos_emb, training=inputs["training"])
+        pos_emb = self.dropout(pos_emb, training=training)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
         # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.n_layer
+            head_mask = [None] * self.n_layer

         new_mems = ()
-        if inputs["mems"] is None:
-            inputs["mems"] = [None] * len(self.layer)
+        if mems is None:
+            mems = [None] * len(self.layer)

-        attentions = [] if inputs["output_attentions"] else None
-        hidden_states = [] if inputs["output_hidden_states"] else None
+        attentions = [] if output_attentions else None
+        hidden_states = [] if output_hidden_states else None
         for i, layer_module in enumerate(self.layer):
             # cache new mems
-            if inputs["use_mems"]:
-                new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
+            if use_mems:
+                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)

-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 hidden_states.append((output_h, output_g) if output_g is not None else output_h)

             outputs = layer_module(
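`cache_mem` itself is untouched by this commit; for context, a hedged sketch of what the per-layer memory update does. The real `TFXLNetMainLayer.cache_mem` additionally honors `reuse_len` and the configured `mem_len`, so the exact slicing below is an assumption:

```python
import tensorflow as tf


def cache_mem_sketch(curr_out, prev_mem, mem_len):
    # Append the new time-major hidden states to the previous memory and keep
    # the last `mem_len` steps, detached from the gradient graph.
    new_mem = curr_out if prev_mem is None else tf.concat([prev_mem, curr_out], 0)
    return tf.stop_gradient(new_mem[-mem_len:])


h = tf.zeros((3, 1, 8))      # [qlen, bsz, d_model]
mems = [None] * 2            # one slot per layer, as in the hunk above
mems = [cache_mem_sketch(h, m, mem_len=4) for m in mems]
```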
@@ -770,34 +742,34 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
                 attn_mask,
                 pos_emb,
                 seg_mat,
-                inputs["mems"][i],
-                inputs["target_mapping"],
-                inputs["head_mask"][i],
-                inputs["output_attentions"],
-                training=inputs["training"],
+                mems[i],
+                target_mapping,
+                head_mask[i],
+                output_attentions,
+                training=training,
             )
             output_h, output_g = outputs[:2]
-            if inputs["output_attentions"]:
+            if output_attentions:
                 attentions.append(outputs[2])

         # Add last hidden state
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             hidden_states.append((output_h, output_g) if output_g is not None else output_h)

-        output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"])
+        output = self.dropout(output_g if output_g is not None else output_h, training=training)

         # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
         output = tf.transpose(output, perm=(1, 0, 2))

-        if not inputs["use_mems"]:
+        if not use_mems:
             new_mems = None

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             if output_g is not None:
                 hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
             else:
                 hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)

-        if inputs["output_attentions"]:
-            if inputs["target_mapping"] is not None:
+        if output_attentions:
+            if target_mapping is not None:
                 # when target_mapping is provided, there are 2-tuple of attentions
                 attentions = tuple(
                     tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
@@ -805,7 +777,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             else:
                 attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)

         return TFXLNetModelOutput(
@@ -1154,6 +1126,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.transformer = TFXLNetMainLayer(config, name="transformer")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1179,9 +1152,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             mems=mems,
@@ -1196,23 +1167,6 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            mems=inputs["mems"],
-            perm_mask=inputs["perm_mask"],
-            target_mapping=inputs["target_mapping"],
-            token_type_ids=inputs["token_type_ids"],
-            input_mask=inputs["input_mask"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_mems=inputs["use_mems"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
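Because `@unpack_inputs` reproduces what `input_processing` did, callers see no behavior change. A usage check (assumes a `transformers` install at or after this commit, `sentencepiece`, and network access for the pretrained weights):

```python
from transformers import TFXLNetModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = TFXLNetModel.from_pretrained("xlnet-base-cased")

inputs = tokenizer("Hello, XLNet!", return_tensors="tf")
outputs = model(**inputs)  # the decorator normalizes the kwargs
print(outputs.last_hidden_state.shape)  # (1, seq_len, 768)
```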
@@ -1286,6 +1240,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         return inputs

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -1349,9 +1304,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         ... 0
         >>> ]  # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             mems=mems,
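The docstring example above is truncated by the diff view; for context, a hedged reconstruction of the kind of `perm_mask`/`target_mapping` setup such an example builds (toy ids, not the original snippet):

```python
import numpy as np
import tensorflow as tf

input_ids = tf.constant([[17, 11, 9, 6]])       # [bsz=1, seq_len=4], toy ids
perm_mask = np.zeros((1, 4, 4), dtype=np.float32)
perm_mask[:, :, -1] = 1.0                        # no token may attend to position 3
target_mapping = np.zeros((1, 1, 4), dtype=np.float32)
target_mapping[0, 0, -1] = 1.0                   # predict exactly position 3
# The LM head logits then have shape
# [target_mapping.shape[0], target_mapping.shape[1], vocab_size] == [1, 1, vocab]
```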
@@ -1365,34 +1318,16 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            mems=inputs["mems"],
-            perm_mask=inputs["perm_mask"],
-            target_mapping=inputs["target_mapping"],
-            token_type_ids=inputs["token_type_ids"],
-            input_mask=inputs["input_mask"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_mems=inputs["use_mems"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         hidden_state = transformer_outputs[0]
-        logits = self.lm_loss(hidden_state, training=inputs["training"])
+        logits = self.lm_loss(hidden_state, training=training)

         loss = None
-        if inputs["labels"] is not None:
-            loss = self.hf_compute_loss(inputs["labels"], logits)
+        if labels is not None:
+            loss = self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1432,6 +1367,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1464,9 +1400,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             mems=mems,
@@ -1480,34 +1414,16 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            mems=inputs["mems"],
-            perm_mask=inputs["perm_mask"],
-            target_mapping=inputs["target_mapping"],
-            token_type_ids=inputs["token_type_ids"],
-            input_mask=inputs["input_mask"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_mems=inputs["use_mems"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
-            training=inputs["training"],
         )
         output = transformer_outputs[0]

         output = self.sequence_summary(output)
         logits = self.logits_proj(output)

-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
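For orientation, a shape-level trace of the classification head this hunk rewires. `sequence_summary` (which for XLNet defaults to taking the last position) and `logits_proj` are replaced by stand-ins here, so this is a sketch, not the model's actual layers:

```python
import tensorflow as tf

bsz, seq_len, d_model, num_labels = 2, 7, 16, 3
hidden = tf.zeros((bsz, seq_len, d_model))            # transformer_outputs[0]
summary = hidden[:, -1]                               # [bsz, d_model], "last" summary
logits = tf.keras.layers.Dense(num_labels)(summary)   # stand-in for logits_proj
assert logits.shape == (bsz, num_labels)
```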
@@ -1558,6 +1474,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
         """
         return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1590,72 +1507,45 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
             where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            mems=mems,
-            perm_mask=perm_mask,
-            target_mapping=target_mapping,
-            token_type_ids=token_type_ids,
-            input_mask=input_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_mems=use_mems,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None:
-            num_choices = shape_list(inputs["input_ids"])[1]
-            seq_length = shape_list(inputs["input_ids"])[2]
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
         else:
-            num_choices = shape_list(inputs["inputs_embeds"])[1]
-            seq_length = shape_list(inputs["inputs_embeds"])[2]
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]

-        flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
-        flat_attention_mask = (
-            tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
-        )
-        flat_token_type_ids = (
-            tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
-        )
-        flat_input_mask = (
-            tf.reshape(inputs["input_mask"], (-1, seq_length)) if inputs["input_mask"] is not None else None
-        )
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
         flat_inputs_embeds = (
-            tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
-            if inputs["inputs_embeds"] is not None
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
             else None
         )
         transformer_outputs = self.transformer(
             flat_input_ids,
             flat_attention_mask,
-            inputs["mems"],
-            inputs["perm_mask"],
-            inputs["target_mapping"],
+            mems,
+            perm_mask,
+            target_mapping,
             flat_token_type_ids,
             flat_input_mask,
-            inputs["head_mask"],
+            head_mask,
             flat_inputs_embeds,
-            inputs["use_mems"],
-            inputs["output_attentions"],
-            inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            use_mems,
+            output_attentions,
+            output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
         output = transformer_outputs[0]
         logits = self.sequence_summary(output)
         logits = self.logits_proj(logits)
         reshaped_logits = tf.reshape(logits, (-1, num_choices))

-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (reshaped_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
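The flattening pattern this hunk rewrites is easier to see on a toy tensor: the choice axis is folded into the batch before the shared backbone runs, then the logits are unfolded. The head output below is a zero stand-in:

```python
import tensorflow as tf

input_ids = tf.zeros((2, 3, 5), dtype=tf.int32)            # [bsz, num_choices, seq_len]
num_choices, seq_length = input_ids.shape[1], input_ids.shape[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))   # [6, 5]
per_choice_logits = tf.zeros((6, 1))                       # stand-in head output
reshaped_logits = tf.reshape(per_choice_logits, (-1, num_choices))  # [2, 3]
```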
@@ -1706,6 +1596,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1737,9 +1628,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             mems=mems,
@@ -1753,31 +1642,13 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            mems=inputs["mems"],
-            perm_mask=inputs["perm_mask"],
-            target_mapping=inputs["target_mapping"],
-            token_type_ids=inputs["token_type_ids"],
-            input_mask=inputs["input_mask"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_mems=inputs["use_mems"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         output = transformer_outputs[0]
         logits = self.classifier(output)
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1812,6 +1683,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1849,9 +1721,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
             are not taken into account for computing the loss.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             mems=mems,
@@ -1865,26 +1735,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            start_positions=start_positions,
-            end_positions=end_positions,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            mems=inputs["mems"],
-            perm_mask=inputs["perm_mask"],
-            target_mapping=inputs["target_mapping"],
-            token_type_ids=inputs["token_type_ids"],
-            input_mask=inputs["input_mask"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_mems=inputs["use_mems"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = transformer_outputs[0]
@@ -1894,12 +1745,12 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
         end_logits = tf.squeeze(end_logits, axis=-1)

         loss = None
-        if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
-            labels = {"start_position": inputs["start_positions"]}
-            labels["end_position"] = inputs["end_positions"]
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
             loss = self.hf_compute_loss(labels, (start_logits, end_logits))

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (start_logits, end_logits) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
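A shape-level trace of the QA head around this final hunk; the logits tensor is a zero stand-in for the real `qa_outputs` projection, and the labels dict is keyed exactly as the new code shows:

```python
import tensorflow as tf

logits = tf.zeros((1, 8, 2))                     # stand-in qa_outputs output
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)  # [bsz, seq_len]
end_logits = tf.squeeze(end_logits, axis=-1)
labels = {"start_position": tf.constant([2]), "end_position": tf.constant([5])}
```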