transformers · Commit c9591f6f

Authored Sep 23, 2019 by thomwolf

updated models input format + tests

Parent: c014d1f0

Showing 10 changed files with 246 additions and 238 deletions (+246 -238)
pytorch_transformers/modeling_tf_bert.py                 +44 -46
pytorch_transformers/modeling_tf_distilbert.py           +24 -25
pytorch_transformers/modeling_tf_gpt2.py                 +36 -38
pytorch_transformers/modeling_tf_openai.py               +32 -34
pytorch_transformers/modeling_tf_roberta.py              +13 -12
pytorch_transformers/modeling_tf_transfo_xl.py           +22 -24
pytorch_transformers/modeling_tf_xlm.py                  +29 -29
pytorch_transformers/modeling_tf_xlnet.py                +28 -29
pytorch_transformers/tests/modeling_tf_bert_test.py       +1 -1
pytorch_transformers/tests/modeling_tf_common_test.py    +17 -0
pytorch_transformers/modeling_tf_bert.py

@@ -456,24 +456,23 @@ class TFBertMainLayer(tf.keras.layers.Layer):
     # def call(self, input_ids, attention_mask=None, token_type_ids=None,
     #          position_ids=None, head_mask=None, training=False):
-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, token_type_ids=None,
+             position_ids=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            token_type_ids = inputs[2] if len(inputs) > 2 else None
-            position_ids = inputs[3] if len(inputs) > 3 else None
-            head_mask = inputs[4] if len(inputs) > 4 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
             assert len(inputs) <= 5, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 5, "Too many inputs."
+        else:
+            input_ids = inputs

         if attention_mask is None:
             attention_mask = tf.fill(tf.shape(input_ids), 1)

@@ -637,8 +636,8 @@ class TFBertModel(TFBertPreTrainedModel):
         super(TFBertModel, self).__init__(config, *inputs, **kwargs)
         self.bert = TFBertMainLayer(config, name='bert')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)
         return outputs

@@ -676,11 +675,11 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         self.nsp = TFBertNSPHead(config, name='nsp___cls')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)

         sequence_output, pooled_output = outputs[:2]
-        prediction_scores = self.mlm(sequence_output, training=training)
+        prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))
         seq_relationship_score = self.nsp(pooled_output)

         outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

@@ -718,11 +717,11 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)

         sequence_output = outputs[0]
-        prediction_scores = self.mlm(sequence_output, training=training)
+        prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))

         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here

@@ -761,8 +760,8 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.nsp = TFBertNSPHead(config, name='nsp___cls')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)

         pooled_output = outputs[1]
         seq_relationship_score = self.nsp(pooled_output)

@@ -805,12 +804,12 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
         self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)

         pooled_output = outputs[1]

-        pooled_output = self.dropout(pooled_output, training=training)
+        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
         logits = self.classifier(pooled_output)

         outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

@@ -852,24 +851,23 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
         self.classifier = tf.keras.layers.Dense(1, name='classifier')

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, token_type_ids=None,
+             position_ids=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            token_type_ids = inputs[2] if len(inputs) > 2 else None
-            position_ids = inputs[3] if len(inputs) > 3 else None
-            head_mask = inputs[4] if len(inputs) > 4 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
             assert len(inputs) <= 5, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 5, "Too many inputs."
+        else:
+            input_ids = inputs

         num_choices = tf.shape(input_ids)[1]
         seq_length = tf.shape(input_ids)[2]

@@ -927,12 +925,12 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
         self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)

         sequence_output = outputs[0]

-        sequence_output = self.dropout(sequence_output, training=training)
+        sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
         logits = self.classifier(sequence_output)

         outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

@@ -976,8 +974,8 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)

         sequence_output = outputs[0]
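The reworked TFBertMainLayer.call above is what lets every TF BERT head accept its optional tensors three ways: as keyword arguments, as an ordered list/tuple, or as a dict of named tensors. A minimal sketch of the three equivalent call styles, assuming TF 2.0 eager execution and a freshly initialised model (the token ids and shapes below are illustrative, not taken from this commit):

import tensorflow as tf
from pytorch_transformers import BertConfig
from pytorch_transformers.modeling_tf_bert import TFBertModel

config = BertConfig(vocab_size_or_config_json_file=32000)
model = TFBertModel(config)  # randomly initialised weights, built on first call

input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
attention_mask = tf.constant([[1, 1, 1], [1, 1, 0]])

# 1) keyword arguments, forwarded through **kwargs to TFBertMainLayer.call
outputs = model(input_ids, attention_mask=attention_mask)
# 2) ordered list/tuple: (input_ids, attention_mask, token_type_ids, position_ids, head_mask)
outputs = model([input_ids, attention_mask])
# 3) dict of named tensors, as a Keras model would receive from tf.data
outputs = model({'input_ids': input_ids, 'attention_mask': attention_mask})

sequence_output, pooled_output = outputs[:2]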
pytorch_transformers/modeling_tf_distilbert.py

@@ -418,20 +418,19 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            (attention_mask, head_mask) = None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            head_mask = inputs[2] if len(inputs) > 2 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            head_mask = inputs[2] if len(inputs) > 2 else head_mask
             assert len(inputs) <= 3, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 3, "Too many inputs."
+        else:
+            input_ids = inputs

         if attention_mask is None:
             attention_mask = tf.ones(shape_list(input_ids))  # (bs, seq_length)

@@ -532,8 +531,8 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
         super(TFDistilBertModel, self).__init__(config, *inputs, **kwargs)
         self.distilbert = TFDistilBertMainLayer(config, name="distilbert")  # Embeddings

-    def call(self, inputs, training=False):
-        outputs = self.distilbert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.distilbert(inputs, **kwargs)
         return outputs

@@ -603,18 +602,17 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
         self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
         self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")

-    def call(self, inputs, training=False):
-        dlbrt_output = self.distilbert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        distilbert_output = self.distilbert(inputs, **kwargs)

-        hidden_states = dlbrt_output[0]    # (bs, seq_length, dim)
+        hidden_states = distilbert_output[0]    # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
         prediction_logits = self.act(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_projector(prediction_logits)

-        outputs = (prediction_logits, ) + dlbrt_output[1:]
-        return outputs  # prediction_logits, (all hidden_states), (all attentions)
+        outputs = (prediction_logits,) + distilbert_output[1:]
+        return outputs  # logits, (hidden_states), (attentions)


 @add_start_docstrings("""DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of

@@ -660,12 +658,13 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
         self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
         self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)

-    def call(self, inputs, training=False):
-        distilbert_output = self.distilbert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        distilbert_output = self.distilbert(inputs, **kwargs)
+
         hidden_state = distilbert_output[0]                    # (bs, seq_len, dim)
         pooled_output = hidden_state[:, 0]                     # (bs, dim)
         pooled_output = self.pre_classifier(pooled_output)     # (bs, dim)
-        pooled_output = self.dropout(pooled_output, training=training)  # (bs, dim)
+        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))  # (bs, dim)
         logits = self.classifier(pooled_output)                # (bs, dim)

         outputs = (logits,) + distilbert_output[1:]

@@ -720,11 +719,11 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
         assert config.num_labels == 2
         self.dropout = tf.keras.layers.Dropout(config.qa_dropout)

-    def call(self, inputs, training=False):
-        distilbert_output = self.distilbert(inputs, training=training)
-        hidden_states = distilbert_output[0]                               # (bs, max_query_len, dim)
-        hidden_states = self.dropout(hidden_states, training=training)     # (bs, max_query_len, dim)
+    def call(self, inputs, **kwargs):
+        distilbert_output = self.distilbert(inputs, **kwargs)
+
+        hidden_states = distilbert_output[0]                               # (bs, max_query_len, dim)
+        hidden_states = self.dropout(hidden_states, training=kwargs.get('training', False))  # (bs, max_query_len, dim)
         logits = self.qa_outputs(hidden_states)                            # (bs, max_query_len, 2)
         start_logits, end_logits = tf.split(logits, 2, axis=-1)
         start_logits = tf.squeeze(start_logits, axis=-1)
pytorch_transformers/modeling_tf_gpt2.py

@@ -230,26 +230,25 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None,
+             position_ids=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            past = inputs[1] if len(inputs) > 1 else None
-            attention_mask = inputs[2] if len(inputs) > 2 else None
-            token_type_ids = inputs[3] if len(inputs) > 3 else None
-            position_ids = inputs[4] if len(inputs) > 4 else None
-            head_mask = inputs[5] if len(inputs) > 5 else None
+            past = inputs[1] if len(inputs) > 1 else past
+            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
+            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
+            position_ids = inputs[4] if len(inputs) > 4 else position_ids
+            head_mask = inputs[5] if len(inputs) > 5 else head_mask
             assert len(inputs) <= 6, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            past = inputs.get('past', None)
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            past = inputs.get('past', past)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 6, "Too many inputs."
+        else:
+            input_ids = inputs

         if past is None:
             past_length = 0

@@ -442,8 +441,8 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
         super(TFGPT2Model, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFGPT2MainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.transformer(inputs, **kwargs)
         return outputs

@@ -483,8 +482,8 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
         super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFGPT2MainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]

         lm_logits = self.transformer.wte(hidden_states, mode="linear")

@@ -551,28 +550,27 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         self.transformer = TFGPT2MainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            mc_token_ids, past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None,
+             position_ids=None, head_mask=None, mc_token_ids=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            mc_token_ids = inputs[1] if len(inputs) > 1 else None
-            past = inputs[2] if len(inputs) > 2 else None
-            attention_mask = inputs[3] if len(inputs) > 3 else None
-            token_type_ids = inputs[4] if len(inputs) > 4 else None
-            position_ids = inputs[5] if len(inputs) > 5 else None
-            head_mask = inputs[6] if len(inputs) > 6 else None
+            past = inputs[1] if len(inputs) > 1 else past
+            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
+            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
+            position_ids = inputs[4] if len(inputs) > 4 else position_ids
+            head_mask = inputs[5] if len(inputs) > 5 else head_mask
+            mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
             assert len(inputs) <= 7, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            mc_token_ids = inputs.get('mc_token_ids', None)
-            past = inputs.get('past', None)
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            past = inputs.get('past', past)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
+            mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
             assert len(inputs) <= 7, "Too many inputs."
+        else:
+            input_ids = inputs

         input_shapes = shape_list(input_ids)
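One consequence of the GPT-2 hunks above: in list/tuple form, the optional inputs of TFGPT2DoubleHeadsModel now follow the new keyword order (past, attention_mask, token_type_ids, position_ids, head_mask, mc_token_ids), so mc_token_ids moved from slot 1 to the last slot. Passing it by name sidesteps the ordering question; a sketch under the same assumptions as the BERT example above (shapes and ids are illustrative, and the (lm_logits, mc_logits) output order is assumed to mirror the PyTorch model):

import tensorflow as tf
from pytorch_transformers import GPT2Config
from pytorch_transformers.modeling_tf_gpt2 import TFGPT2DoubleHeadsModel

config = GPT2Config()
model = TFGPT2DoubleHeadsModel(config)

# one example with two choices of five tokens each: (batch, num_choices, seq_len)
input_ids = tf.constant([[[31, 51, 99, 17, 2], [31, 51, 99, 21, 2]]])
mc_token_ids = tf.constant([[4, 4]])  # position of the classification token in each choice

# dict form: keys match the keyword argument names
outputs = model({'input_ids': input_ids, 'mc_token_ids': mc_token_ids})
# keyword form: mc_token_ids is now the last optional argument
outputs = model(input_ids, mc_token_ids=mc_token_ids)

lm_logits, mc_logits = outputs[:2]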
pytorch_transformers/modeling_tf_openai.py

@@ -229,24 +229,23 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, token_type_ids=None,
+             position_ids=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            token_type_ids = inputs[2] if len(inputs) > 2 else None
-            position_ids = inputs[3] if len(inputs) > 3 else None
-            head_mask = inputs[4] if len(inputs) > 4 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
             assert len(inputs) <= 5, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 5, "Too many inputs."
+        else:
+            input_ids = inputs

         if position_ids is None:
             position_ids = tf.range(shape_list(input_ids)[-1], dtype=tf.int32)[tf.newaxis, :]

@@ -420,8 +419,8 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
         super(TFOpenAIGPTModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.transformer(inputs, **kwargs)
         return outputs

@@ -455,8 +454,8 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
         super(TFOpenAIGPTLMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]

         lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")

@@ -511,26 +510,25 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            mc_token_ids, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None,
+             head_mask=None, mc_token_ids=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            mc_token_ids = inputs[1] if len(inputs) > 1 else None
-            attention_mask = inputs[2] if len(inputs) > 2 else None
-            token_type_ids = inputs[3] if len(inputs) > 3 else None
-            position_ids = inputs[4] if len(inputs) > 4 else None
-            head_mask = inputs[5] if len(inputs) > 5 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
+            mc_token_ids = inputs[5] if len(inputs) > 5 else mc_token_ids
             assert len(inputs) <= 6, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            mc_token_ids = inputs.get('mc_token_ids', None)
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
+            mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
             assert len(inputs) <= 6, "Too many inputs."
+        else:
+            input_ids = inputs

         input_shapes = shape_list(input_ids)
pytorch_transformers/modeling_tf_roberta.py

@@ -73,21 +73,21 @@ class TFRobertaMainLayer(TFBertMainLayer):
         super(TFRobertaMainLayer, self).__init__(config, **kwargs)
         self.embeddings = TFRobertaEmbeddings(config, name='embeddings')

-    def call(self, inputs, training=False):
+    def call(self, inputs, **kwargs):
         # Check that input_ids starts with control token
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-        elif isinstance(inputs, (tuple, list)):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
+        else:
+            input_ids = inputs

         if tf.not_equal(tf.reduce_sum(input_ids[:, 0]), 0):
             logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
                            "This model requires special tokens in order to work. "
                            "Please specify add_special_tokens=True in your encoding.")

-        return super(TFRobertaMainLayer, self).call(inputs, training=training)
+        return super(TFRobertaMainLayer, self).call(inputs, **kwargs)


 class TFRobertaPreTrainedModel(TFPreTrainedModel):

@@ -203,8 +203,8 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
         super(TFRobertaModel, self).__init__(config, *inputs, **kwargs)
         self.roberta = TFRobertaMainLayer(config, name='roberta')

-    def call(self, inputs, training=False):
-        outputs = self.roberta(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.roberta(inputs, **kwargs)
         return outputs

@@ -277,8 +277,8 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
         self.roberta = TFRobertaMainLayer(config, name="roberta")
         self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")

-    def call(self, inputs, training=False):
-        outputs = self.roberta(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.roberta(inputs, **kwargs)

         sequence_output = outputs[0]
         prediction_scores = self.lm_head(sequence_output)

@@ -347,8 +347,9 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
         self.roberta = TFRobertaMainLayer(config, name="roberta")
         self.classifier = TFRobertaClassificationHead(config, name="classifier")

-    def call(self, inputs, training=False):
-        outputs = self.roberta(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.roberta(inputs, **kwargs)
+
         sequence_output = outputs[0]
         logits = self.classifier(sequence_output, training=training)
pytorch_transformers/modeling_tf_transfo_xl.py

@@ -447,20 +447,19 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
         return new_mems

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            mems, head_mask = None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, mems=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            mems = inputs[1] if len(inputs) > 1 else None
-            head_mask = inputs[2] if len(inputs) > 2 else None
+            mems = inputs[1] if len(inputs) > 1 else mems
+            head_mask = inputs[2] if len(inputs) > 2 else head_mask
             assert len(inputs) <= 3, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            mems = inputs.get('mems', None)
-            head_mask = inputs.get('head_mask', None)
+            mems = inputs.get('mems', mems)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 3, "Too many inputs."
+        else:
+            input_ids = inputs

         # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
         # so we transpose here from shape [bsz, len] to shape [len, bsz]

@@ -632,8 +631,8 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
         super(TFTransfoXLModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFTransfoXLMainLayer(config, name='transformer')

-    def call(self, inputs, training=False, **kwargs):
-        outputs = self.transformer(inputs, training=training, **kwargs)
+    def call(self, inputs, **kwargs):
+        outputs = self.transformer(inputs, **kwargs)
         return outputs

@@ -694,22 +693,21 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
     def init_mems(self, data):
         return self.transformer.init_mems(data)

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            mems, head_mask, labels = None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, mems=None, head_mask=None, labels=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            mems = inputs[1] if len(inputs) > 1 else None
-            head_mask = inputs[2] if len(inputs) > 2 else None
-            labels = inputs[3] if len(inputs) > 3 else None
+            mems = inputs[1] if len(inputs) > 1 else mems
+            head_mask = inputs[2] if len(inputs) > 2 else head_mask
+            labels = inputs[3] if len(inputs) > 3 else labels
             assert len(inputs) <= 4, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            mems = inputs.get('mems', None)
-            head_mask = inputs.get('head_mask', None)
-            labels = inputs.get('labels', None)
+            mems = inputs.get('mems', mems)
+            head_mask = inputs.get('head_mask', head_mask)
+            labels = inputs.get('labels', labels)
             assert len(inputs) <= 4, "Too many inputs."
+        else:
+            input_ids = inputs

         bsz, tgt_len = shape_list(input_ids)[:2]
pytorch_transformers/modeling_tf_xlm.py

@@ -294,31 +294,31 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

-    def call(self, inputs, training=False):  # removed: src_enc=None, src_len=None
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            (attention_mask, langs, token_type_ids, position_ids,
-             lengths, cache, head_mask) = None, None, None, None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+             lengths=None, cache=None, head_mask=None, training=False):  # removed: src_enc=None, src_len=None
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            langs = inputs[2] if len(inputs) > 2 else None
-            token_type_ids = inputs[3] if len(inputs) > 3 else None
-            position_ids = inputs[4] if len(inputs) > 4 else None
-            lengths = inputs[5] if len(inputs) > 5 else None
-            cache = inputs[6] if len(inputs) > 6 else None
-            head_mask = inputs[7] if len(inputs) > 7 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            langs = inputs[2] if len(inputs) > 2 else langs
+            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
+            position_ids = inputs[4] if len(inputs) > 4 else position_ids
+            lengths = inputs[5] if len(inputs) > 5 else lengths
+            cache = inputs[6] if len(inputs) > 6 else cache
+            head_mask = inputs[7] if len(inputs) > 7 else head_mask
             assert len(inputs) <= 8, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            langs = inputs.get('langs', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            lengths = inputs.get('lengths', None)
-            cache = inputs.get('cache', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            langs = inputs.get('langs', langs)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            lengths = inputs.get('lengths', lengths)
+            cache = inputs.get('cache', cache)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 8, "Too many inputs."
+        else:
+            input_ids = inputs

         if lengths is None:
             lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)

@@ -538,8 +538,8 @@ class TFXLMModel(TFXLMPreTrainedModel):
         super(TFXLMModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFXLMMainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.transformer(inputs, **kwargs)
         return outputs

@@ -619,8 +619,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
         self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)

         output = transformer_outputs[0]
         outputs = self.pred_layer(output)

@@ -670,8 +670,8 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
         self.transformer = TFXLMMainLayer(config, name='transformer')
         self.sequence_summary = TFSequenceSummary(config, name='sequence_summary')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)

         output = transformer_outputs[0]
         logits = self.sequence_summary(output)

@@ -731,8 +731,8 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
         self.transformer = TFXLMMainLayer(config, name='transformer')
         self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)

         sequence_output = transformer_outputs[0]
pytorch_transformers/modeling_tf_xlnet.py

@@ -489,31 +489,30 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         return pos_emb

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            (attention_mask, mems, perm_mask, target_mapping,
-             token_type_ids, input_mask, head_mask) = None, None, None, None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+             token_type_ids=None, input_mask=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            mems = inputs[2] if len(inputs) > 2 else None
-            perm_mask = inputs[3] if len(inputs) > 3 else None
-            target_mapping = inputs[4] if len(inputs) > 4 else None
-            token_type_ids = inputs[5] if len(inputs) > 5 else None
-            input_mask = inputs[6] if len(inputs) > 6 else None
-            head_mask = inputs[7] if len(inputs) > 7 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            mems = inputs[2] if len(inputs) > 2 else mems
+            perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
+            target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
+            token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
+            input_mask = inputs[6] if len(inputs) > 6 else input_mask
+            head_mask = inputs[7] if len(inputs) > 7 else head_mask
             assert len(inputs) <= 8, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            mems = inputs.get('mems', None)
-            perm_mask = inputs.get('perm_mask', None)
-            target_mapping = inputs.get('target_mapping', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            input_mask = inputs.get('input_mask', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            mems = inputs.get('mems', mems)
+            perm_mask = inputs.get('perm_mask', perm_mask)
+            target_mapping = inputs.get('target_mapping', target_mapping)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            input_mask = inputs.get('input_mask', input_mask)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 8, "Too many inputs."
+        else:
+            input_ids = inputs

         # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
         # but we want a unified interface in the library with the batch size on the first dimension

@@ -784,8 +783,8 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
         super(TFXLNetModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFXLNetMainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.transformer(inputs, **kwargs)
         return outputs

@@ -829,8 +828,8 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
         self.transformer = TFXLNetMainLayer(config, name='transformer')
         self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_state = transformer_outputs[0]
         logits = self.lm_loss(hidden_state)

@@ -886,8 +885,8 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
         self.sequence_summary = TFSequenceSummary(config, name='sequence_summary')
         self.logits_proj = tf.keras.layers.Dense(config.num_labels, name='logits_proj')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         output = transformer_outputs[0]

         output = self.sequence_summary(output)

@@ -933,8 +932,8 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
         self.transformer = TFXLNetMainLayer(config, name='transformer')
         self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)

         sequence_output = transformer_outputs[0]
pytorch_transformers/tests/modeling_tf_bert_test.py

@@ -138,7 +138,7 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
             inputs = {'input_ids': input_ids,
                       'attention_mask': input_mask,
                       'token_type_ids': token_type_ids}
-            sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+            sequence_output, pooled_output = model(inputs)

             inputs = [input_ids, input_mask]
             sequence_output, pooled_output = model(inputs)
pytorch_transformers/tests/modeling_tf_common_test.py

@@ -29,6 +29,7 @@ from pytorch_transformers import is_tf_available
 if is_tf_available():
     import tensorflow as tf
+    import numpy as np
     from pytorch_transformers import TFPreTrainedModel
     # from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 else:

@@ -65,6 +66,22 @@ class TFCommonTestCases:
         #                     msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))

+    def test_keyword_and_dict_args(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            outputs_dict = model(inputs_dict)
+
+            inputs_keywords = copy.deepcopy(inputs_dict)
+            input_ids = inputs_keywords.pop('input_ids')
+            outputs_keywords = model(input_ids, **inputs_keywords)
+
+            output_dict = outputs_dict[0].numpy()
+            output_keywords = outputs_keywords[0].numpy()
+
+            self.assertLess(np.sum(np.abs(output_dict - output_keywords)), 1e-6)
+
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
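The new test_keyword_and_dict_args test asserts that the dict and keyword call styles produce numerically identical first outputs for every TF model class. The same check can be written against a single small BERT model outside the test harness; a sketch, with tiny illustrative config values rather than the ones used by the test suite:

import copy
import numpy as np
import tensorflow as tf
from pytorch_transformers import BertConfig
from pytorch_transformers.modeling_tf_bert import TFBertModel

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4, intermediate_size=37)
model = TFBertModel(config)

inputs_dict = {'input_ids': tf.constant([[1, 2, 3, 4, 5]]),
               'attention_mask': tf.constant([[1, 1, 1, 1, 0]])}
outputs_dict = model(inputs_dict)

inputs_keywords = copy.deepcopy(inputs_dict)
input_ids = inputs_keywords.pop('input_ids')
outputs_keywords = model(input_ids, **inputs_keywords)

# both call styles should hit the same code path and return the same tensors
assert np.sum(np.abs(outputs_dict[0].numpy() - outputs_keywords[0].numpy())) < 1e-6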
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment