Commit 015de6f0 (Unverified), authored Mar 15, 2022 by Kamal Raj, committed by GitHub on Mar 15, 2022

TF clearer model variable naming: xlnet (#16150)

parent a23a7c0c
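This commit drops the per-call `input_processing(...)` boilerplate in the TF XLNet models and applies the `@unpack_inputs` decorator instead, so each `call` body reads its arguments as plain local variables (`input_ids`, `use_mems`, ...) rather than through an `inputs` dictionary. As a rough illustration of the mechanism, here is a minimal sketch of an unpack-inputs-style decorator; it is not the actual Hugging Face implementation, which also handles config-driven defaults and boolean normalization:

```python
import functools
import inspect

def unpack_inputs_sketch(func):
    """Minimal sketch of an `unpack_inputs`-style decorator, for signatures
    without *args/**kwargs: bind every argument against the declared
    signature, fill in the defaults, and call the body with everything
    passed explicitly by name so it can use plain local variables."""
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()  # unspecified arguments become their declared defaults
        return func(**bound.arguments)

    return wrapper
```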
Changes: showing 1 changed file with 108 additions and 257 deletions.
src/transformers/models/xlnet/modeling_tf_xlnet.py  (+108 −257)
```diff
@@ -42,8 +42,8 @@ from ...modeling_tf_utils import (
     TFSharedEmbeddings,
     TFTokenClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
```
```diff
@@ -578,6 +578,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         return pos_emb

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
```
```diff
@@ -596,63 +597,34 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            mems=mems,
-            perm_mask=perm_mask,
-            target_mapping=target_mapping,
-            token_type_ids=token_type_ids,
-            input_mask=input_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_mems=use_mems,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if training and inputs["use_mems"] is None:
-            inputs["use_mems"] = self.use_mems_train
+        if training and use_mems is None:
+            use_mems = self.use_mems_train
         else:
-            inputs["use_mems"] = self.use_mems_eval
+            use_mems = self.use_mems_eval

         # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
         # but we want a unified interface in the library with the batch size on the first dimension
         # so we move here the first dimension (batch) to the end
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
-            qlen, bsz = shape_list(inputs["input_ids"])[:2]
-        elif inputs["inputs_embeds"] is not None:
-            inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
-            qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
+        elif input_ids is not None:
+            input_ids = tf.transpose(input_ids, perm=(1, 0))
+            qlen, bsz = shape_list(input_ids)[:2]
+        elif inputs_embeds is not None:
+            inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
+            qlen, bsz = shape_list(inputs_embeds)[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        inputs["token_type_ids"] = (
-            tf.transpose(inputs["token_type_ids"], perm=(1, 0)) if inputs["token_type_ids"] is not None else None
-        )
-        inputs["input_mask"] = (
-            tf.transpose(inputs["input_mask"], perm=(1, 0)) if inputs["input_mask"] is not None else None
-        )
-        inputs["attention_mask"] = (
-            tf.transpose(inputs["attention_mask"], perm=(1, 0)) if inputs["attention_mask"] is not None else None
-        )
-        inputs["perm_mask"] = (
-            tf.transpose(inputs["perm_mask"], perm=(1, 2, 0)) if inputs["perm_mask"] is not None else None
-        )
-        inputs["target_mapping"] = (
-            tf.transpose(inputs["target_mapping"], perm=(1, 2, 0)) if inputs["target_mapping"] is not None else None
-        )
+        token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
+        input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
+        attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
+        perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
+        target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None

-        mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
+        mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
         klen = mlen + qlen

         # Attention mask
```
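The hunk above only renames dict lookups to plain locals; the underlying convention is unchanged: XLNet works time-major internally, so batch-first inputs of shape `[bsz, len]` are transposed to `[len, bsz]` on entry. A standalone illustration of the two transposes used here:

```python
import tensorflow as tf

# Batch-first inputs, as the library-facing API expects: [bsz, len]
input_ids = tf.constant([[11, 12, 13], [21, 22, 23]])         # shape (2, 3)
inputs_embeds = tf.random.normal((2, 3, 8))                   # [bsz, len, dim]

# Internally XLNet works time-major: [len, bsz] / [len, bsz, dim]
input_ids_t = tf.transpose(input_ids, perm=(1, 0))            # shape (3, 2)
inputs_embeds_t = tf.transpose(inputs_embeds, perm=(1, 0, 2)) # shape (3, 2, 8)

qlen, bsz = input_ids_t.shape[:2]
print(qlen, bsz)  # 3 2
```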
```diff
@@ -666,19 +638,19 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             raise ValueError(f"Unsupported attention type: {self.attn_type}")

         # data mask: input mask & perm mask
-        assert inputs["input_mask"] is None or inputs["attention_mask"] is None, (
+        assert input_mask is None or attention_mask is None, (
             "You can only use one of input_mask (uses 1 for padding) "
             "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
         )
-        if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
+        if input_mask is None and attention_mask is not None:
             one_cst = tf.constant(1.0)
-            inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
-        if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
-            data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
-        elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
-            data_mask = inputs["input_mask"][None]
-        elif inputs["input_mask"] is None and inputs["perm_mask"] is not None:
-            data_mask = inputs["perm_mask"]
+            input_mask = 1.0 - tf.cast(attention_mask, dtype=one_cst.dtype)
+        if input_mask is not None and perm_mask is not None:
+            data_mask = input_mask[None] + perm_mask
+        elif input_mask is not None and perm_mask is None:
+            data_mask = input_mask[None]
+        elif input_mask is None and perm_mask is not None:
+            data_mask = perm_mask
         else:
             data_mask = None
```
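The mask handling above converts between the two padding conventions named in the assert: `input_mask` marks padding with 1 (the original XLNet convention), while `attention_mask` marks padding with 0 (the BERT convention), so one is simply `1.0 - cast(other)`. A quick check of that identity:

```python
import tensorflow as tf

# attention_mask: 1 = real token, 0 = padding (BERT convention)
attention_mask = tf.constant([[1, 1, 1, 0, 0]])

# input_mask: 1 = padding, 0 = real token (original XLNet convention)
input_mask = 1.0 - tf.cast(attention_mask, dtype=tf.float32)
print(input_mask.numpy())  # [[0. 0. 0. 1. 1.]]
```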
```diff
@@ -704,33 +676,33 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             non_tgt_mask = None

         # Word embeddings and prepare h & g hidden states
-        if inputs["inputs_embeds"] is not None:
-            word_emb_k = inputs["inputs_embeds"]
+        if inputs_embeds is not None:
+            word_emb_k = inputs_embeds
         else:
-            word_emb_k = self.word_embedding(inputs["input_ids"])
-        output_h = self.dropout(word_emb_k, training=inputs["training"])
-        if inputs["target_mapping"] is not None:
-            word_emb_q = tf.tile(self.mask_emb, [shape_list(inputs["target_mapping"])[0], bsz, 1])
+            word_emb_k = self.word_embedding(input_ids)
+        output_h = self.dropout(word_emb_k, training=training)
+        if target_mapping is not None:
+            word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
             # else:  # We removed the inp_q input which was same as target mapping
             #     inp_q_ext = inp_q[:, :, None]
             #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
-            output_g = self.dropout(word_emb_q, training=inputs["training"])
+            output_g = self.dropout(word_emb_q, training=training)
         else:
             output_g = None

         # Segment embedding
-        if inputs["token_type_ids"] is not None:
+        if token_type_ids is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             if mlen > 0:
-                mem_pad = tf.zeros([mlen, bsz], dtype=inputs["token_type_ids"].dtype)
-                cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
+                mem_pad = tf.zeros([mlen, bsz], dtype=token_type_ids.dtype)
+                cat_ids = tf.concat([mem_pad, token_type_ids], 0)
             else:
-                cat_ids = inputs["token_type_ids"]
+                cat_ids = token_type_ids

             # `1` indicates not in the same segment [qlen x klen x bsz]
             seg_mat = tf.cast(
-                tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])),
-                dtype=inputs["token_type_ids"].dtype,
+                tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
+                dtype=token_type_ids.dtype,
             )
             seg_mat = tf.one_hot(seg_mat, 2)
         else:
```
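XLNet does not embed absolute segment ids; it only encodes whether two positions belong to the same segment. The `seg_mat` construction above reduces to "different segment → 1", then one-hot encodes that bit. Replaying the same logic from the hunk on a toy time-major example:

```python
import tensorflow as tf

# token_type_ids, time-major [qlen, bsz]: one sequence of length 4, segments 0,0,1,1
token_type_ids = tf.constant([[0], [0], [1], [1]])
cat_ids = token_type_ids  # no memory here (mlen == 0), so nothing to pad

# 1 wherever the query and key positions are in *different* segments
seg_mat = tf.cast(
    tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
    dtype=token_type_ids.dtype,
)
seg_mat = tf.one_hot(seg_mat, 2)  # [qlen, klen, bsz, 2]
print(seg_mat.shape)  # (4, 4, 1, 2)
```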
```diff
@@ -738,29 +710,29 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         # Positional encoding
         pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
-        pos_emb = self.dropout(pos_emb, training=inputs["training"])
+        pos_emb = self.dropout(pos_emb, training=training)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
         # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.n_layer
+            head_mask = [None] * self.n_layer

         new_mems = ()
-        if inputs["mems"] is None:
-            inputs["mems"] = [None] * len(self.layer)
+        if mems is None:
+            mems = [None] * len(self.layer)

-        attentions = [] if inputs["output_attentions"] else None
-        hidden_states = [] if inputs["output_hidden_states"] else None
+        attentions = [] if output_attentions else None
+        hidden_states = [] if output_hidden_states else None
         for i, layer_module in enumerate(self.layer):
             # cache new mems
-            if inputs["use_mems"]:
-                new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
-            if inputs["output_hidden_states"]:
+            if use_mems:
+                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
+            if output_hidden_states:
                 hidden_states.append((output_h, output_g) if output_g is not None else output_h)

             outputs = layer_module(
```
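`self.cache_mem(output_h, mems[i])` caches each layer's hidden states as Transformer-XL style memory for the next segment. A simplified sketch of the idea, as a hypothetical standalone helper; the real method also respects `reuse_len` and the configured memory length:

```python
import tensorflow as tf

def cache_mem_sketch(curr_out, prev_mem, mem_len):
    """Simplified sketch of Transformer-XL style memory caching: append the
    current hidden states to the previous memory along the time axis, keep
    only the last `mem_len` steps, and detach from the gradient tape."""
    if prev_mem is None:
        new_mem = curr_out[-mem_len:]
    else:
        new_mem = tf.concat([prev_mem, curr_out], axis=0)[-mem_len:]
    return tf.stop_gradient(new_mem)

# [qlen, bsz, dim] hidden states for one layer
h = tf.random.normal((4, 2, 8))
mem = cache_mem_sketch(h, None, mem_len=3)
print(mem.shape)  # (3, 2, 8)
```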
```diff
@@ -770,34 +742,34 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
                 attn_mask,
                 pos_emb,
                 seg_mat,
-                inputs["mems"][i],
-                inputs["target_mapping"],
-                inputs["head_mask"][i],
-                inputs["output_attentions"],
-                training=inputs["training"],
+                mems[i],
+                target_mapping,
+                head_mask[i],
+                output_attentions,
+                training=training,
             )
             output_h, output_g = outputs[:2]
-            if inputs["output_attentions"]:
+            if output_attentions:
                 attentions.append(outputs[2])

         # Add last hidden state
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             hidden_states.append((output_h, output_g) if output_g is not None else output_h)

-        output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"])
+        output = self.dropout(output_g if output_g is not None else output_h, training=training)

         # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
         output = tf.transpose(output, perm=(1, 0, 2))

-        if not inputs["use_mems"]:
+        if not use_mems:
             new_mems = None
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             if output_g is not None:
                 hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
             else:
                 hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
-        if inputs["output_attentions"]:
-            if inputs["target_mapping"] is not None:
+        if output_attentions:
+            if target_mapping is not None:
                 # when target_mapping is provided, there are 2-tuple of attentions
                 attentions = tuple(
                     tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
```
```diff
@@ -805,7 +777,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             else:
                 attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)

         return TFXLNetModelOutput(
```
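When `return_dict` is false, the layer returns a plain tuple with `None` members filtered out, which is what keeps positional indexing such as `transformer_outputs[0]` valid in the head models below. The filtering expression in isolation:

```python
output, new_mems, hidden_states, attentions = "out", None, ("h0", "h1"), None

# Same expression as in the layer: Nones are dropped, order is preserved.
result = tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)
print(result)  # ('out', ('h0', 'h1'))
```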
```diff
@@ -1154,6 +1126,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.transformer = TFXLNetMainLayer(config, name="transformer")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
```
```diff
@@ -1179,9 +1152,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             mems=mems,
@@ -1196,23 +1167,6 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
         )
-        outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            mems=inputs["mems"],
-            perm_mask=inputs["perm_mask"],
-            target_mapping=inputs["target_mapping"],
-            token_type_ids=inputs["token_type_ids"],
-            input_mask=inputs["input_mask"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_mems=inputs["use_mems"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )

         return outputs
```
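With `@unpack_inputs` applied, `TFXLNetModel.call` now forwards its named arguments straight into the main layer instead of plumbing them through an `inputs` dict twice. The public API is unchanged; a typical usage example (standard transformers usage, downloads pretrained weights on first run):

```python
from transformers import TFXLNetModel, XLNetTokenizer

# Illustrative usage only; requires network access for the checkpoint.
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = TFXLNetModel.from_pretrained("xlnet-base-cased")

inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, 768)
```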
```diff
@@ -1286,6 +1240,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         return inputs

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
     def call(
```
```diff
@@ -1349,9 +1304,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         ...  0
         >>> ]  # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             mems=mems,
```
```diff
@@ -1365,34 +1318,16 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
         )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            mems=inputs["mems"],
-            perm_mask=inputs["perm_mask"],
-            target_mapping=inputs["target_mapping"],
-            token_type_ids=inputs["token_type_ids"],
-            input_mask=inputs["input_mask"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_mems=inputs["use_mems"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         hidden_state = transformer_outputs[0]
-        logits = self.lm_loss(hidden_state, training=inputs["training"])
+        logits = self.lm_loss(hidden_state, training=training)

         loss = None
-        if inputs["labels"] is not None:
-            loss = self.hf_compute_loss(inputs["labels"], logits)
+        if labels is not None:
+            loss = self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
```
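`hf_compute_loss`, provided here by `TFCausalLanguageModelingLoss`, is essentially a sparse categorical cross-entropy over the LM logits that skips ignored positions. A rough standalone equivalent, a sketch under the assumption that -100 is the ignore index as elsewhere in the library:

```python
import tensorflow as tf

def lm_loss_sketch(labels, logits):
    """Rough sketch of what `hf_compute_loss` does for causal LM heads:
    sparse categorical cross-entropy on the raw logits, skipping positions
    labeled -100 (the ignore index)."""
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    active = labels != -100
    return loss_fn(tf.boolean_mask(labels, active), tf.boolean_mask(logits, active))

labels = tf.constant([[5, -100, 7]])
logits = tf.random.normal((1, 3, 10))  # [bsz, len, vocab]
print(lm_loss_sketch(labels, logits).shape)  # (2,) — one loss per active token
```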
```diff
@@ -1432,6 +1367,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
```
```diff
@@ -1464,9 +1400,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             mems=mems,
```
```diff
@@ -1480,34 +1414,16 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
         )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            mems=inputs["mems"],
-            perm_mask=inputs["perm_mask"],
-            target_mapping=inputs["target_mapping"],
-            token_type_ids=inputs["token_type_ids"],
-            input_mask=inputs["input_mask"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_mems=inputs["use_mems"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
-            training=inputs["training"],
-        )
         output = transformer_outputs[0]

         output = self.sequence_summary(output)
         logits = self.logits_proj(output)

-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
```
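The classification head pools the sequence with `sequence_summary` (XLNet's summary is configured to take the last token) and then projects to `num_labels`. A minimal sketch of that pool-then-project shape flow:

```python
import tensorflow as tf

# Sketch of the sequence-classification head: pool, then project.
hidden = tf.random.normal((2, 5, 8))    # [bsz, len, dim] from the transformer
pooled = hidden[:, -1]                  # last-token pooling, [bsz, dim]
logits_proj = tf.keras.layers.Dense(3)  # num_labels = 3
logits = logits_proj(pooled)
print(logits.shape)                     # (2, 3)
```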
```diff
@@ -1558,6 +1474,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
         """
         return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
```
```diff
@@ -1590,72 +1507,45 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
             where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            mems=mems,
-            perm_mask=perm_mask,
-            target_mapping=target_mapping,
-            token_type_ids=token_type_ids,
-            input_mask=input_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_mems=use_mems,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None:
-            num_choices = shape_list(inputs["input_ids"])[1]
-            seq_length = shape_list(inputs["input_ids"])[2]
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
         else:
-            num_choices = shape_list(inputs["inputs_embeds"])[1]
-            seq_length = shape_list(inputs["inputs_embeds"])[2]
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]

-        flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
-        flat_attention_mask = (
-            tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
-        )
-        flat_token_type_ids = (
-            tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
-        )
-        flat_input_mask = (
-            tf.reshape(inputs["input_mask"], (-1, seq_length)) if inputs["input_mask"] is not None else None
-        )
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
         flat_inputs_embeds = (
-            tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
-            if inputs["inputs_embeds"] is not None
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
             else None
         )

         transformer_outputs = self.transformer(
             flat_input_ids,
             flat_attention_mask,
-            inputs["mems"],
-            inputs["perm_mask"],
-            inputs["target_mapping"],
+            mems,
+            perm_mask,
+            target_mapping,
             flat_token_type_ids,
             flat_input_mask,
-            inputs["head_mask"],
+            head_mask,
             flat_inputs_embeds,
-            inputs["use_mems"],
-            inputs["output_attentions"],
-            inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            use_mems,
+            output_attentions,
+            output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
         output = transformer_outputs[0]
         logits = self.sequence_summary(output)
         logits = self.logits_proj(logits)
         reshaped_logits = tf.reshape(logits, (-1, num_choices))

-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (reshaped_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
```
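For multiple choice, the `[bsz, num_choices, seq_len]` inputs are flattened so each choice is scored as an ordinary sequence, then the per-choice scores are reshaped back to one row per example. The shape arithmetic in isolation:

```python
import tensorflow as tf

# Multiple-choice inputs arrive as [bsz, num_choices, seq_len]; the model
# flattens choices into the batch dimension, scores each choice, and
# reshapes back so every row holds one example's choice logits.
bsz, num_choices, seq_length = 2, 4, 6
input_ids = tf.zeros((bsz, num_choices, seq_length), dtype=tf.int32)

flat_input_ids = tf.reshape(input_ids, (-1, seq_length))      # (8, 6)
per_choice_logits = tf.random.normal((bsz * num_choices, 1))  # one score per choice
reshaped_logits = tf.reshape(per_choice_logits, (-1, num_choices))
print(flat_input_ids.shape, reshaped_logits.shape)  # (8, 6) (2, 4)
```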
```diff
@@ -1706,6 +1596,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
```
```diff
@@ -1737,9 +1628,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             mems=mems,
```
```diff
@@ -1753,31 +1642,13 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
         )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            mems=inputs["mems"],
-            perm_mask=inputs["perm_mask"],
-            target_mapping=inputs["target_mapping"],
-            token_type_ids=inputs["token_type_ids"],
-            input_mask=inputs["input_mask"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_mems=inputs["use_mems"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         output = transformer_outputs[0]

         logits = self.classifier(output)

-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
```
```diff
@@ -1812,6 +1683,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
```
```diff
@@ -1849,9 +1721,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
             are not taken into account for computing the loss.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             mems=mems,
```
```diff
@@ -1865,26 +1735,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            start_positions=start_positions,
-            end_positions=end_positions,
             training=training,
-            kwargs_call=kwargs,
         )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            mems=inputs["mems"],
-            perm_mask=inputs["perm_mask"],
-            target_mapping=inputs["target_mapping"],
-            token_type_ids=inputs["token_type_ids"],
-            input_mask=inputs["input_mask"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_mems=inputs["use_mems"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         sequence_output = transformer_outputs[0]
```
```diff
@@ -1894,12 +1745,12 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
         end_logits = tf.squeeze(end_logits, axis=-1)

         loss = None
-        if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
-            labels = {"start_position": inputs["start_positions"]}
-            labels["end_position"] = inputs["end_positions"]
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
             loss = self.hf_compute_loss(labels, (start_logits, end_logits))

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (start_logits, end_logits) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
```
...
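The QA head, unchanged here apart from the renames, projects each token to two scores, splits them into start/end logits, and squeezes the trailing axis; the loss is then computed against the `{"start_position", "end_position"}` label dict built above. A shape-level sketch of that head:

```python
import tensorflow as tf

# Sketch of the span-extraction head: one Dense(2) over the sequence output,
# split into start/end logits, trailing axis squeezed away.
sequence_output = tf.random.normal((2, 7, 8))  # [bsz, len, dim]
qa_outputs = tf.keras.layers.Dense(2)
logits = qa_outputs(sequence_output)           # [bsz, len, 2]
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)  # [bsz, len]
end_logits = tf.squeeze(end_logits, axis=-1)
print(start_logits.shape, end_logits.shape)  # (2, 7) (2, 7)
```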