Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d951c14a
Unverified
Commit
d951c14a
authored
Jul 31, 2020
by
Sylvain Gugger
Committed by
GitHub
Jul 31, 2020
Browse files
Model output test (#6155)
* Use return_dict=True in all tests * Formatting
parent
86caab1e
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
222 additions
and
575 deletions
+222
-575
src/transformers/modeling_encoder_decoder.py
src/transformers/modeling_encoder_decoder.py
+2
-0
src/transformers/modeling_openai.py
src/transformers/modeling_openai.py
+3
-4
src/transformers/modeling_reformer.py
src/transformers/modeling_reformer.py
+1
-0
templates/adding_a_new_model/tests/test_modeling_xxx.py
templates/adding_a_new_model/tests/test_modeling_xxx.py
+11
-36
tests/test_modeling_albert.py
tests/test_modeling_albert.py
+15
-51
tests/test_modeling_bart.py
tests/test_modeling_bart.py
+13
-15
tests/test_modeling_bert.py
tests/test_modeling_bert.py
+27
-86
tests/test_modeling_camembert.py
tests/test_modeling_camembert.py
+2
-2
tests/test_modeling_common.py
tests/test_modeling_common.py
+0
-1
tests/test_modeling_ctrl.py
tests/test_modeling_ctrl.py
+8
-16
tests/test_modeling_distilbert.py
tests/test_modeling_distilbert.py
+10
-34
tests/test_modeling_dpr.py
tests/test_modeling_dpr.py
+10
-22
tests/test_modeling_electra.py
tests/test_modeling_electra.py
+12
-48
tests/test_modeling_flaubert.py
tests/test_modeling_flaubert.py
+19
-58
tests/test_modeling_gpt2.py
tests/test_modeling_gpt2.py
+17
-26
tests/test_modeling_longformer.py
tests/test_modeling_longformer.py
+19
-56
tests/test_modeling_mbart.py
tests/test_modeling_mbart.py
+3
-2
tests/test_modeling_mobilebert.py
tests/test_modeling_mobilebert.py
+23
-66
tests/test_modeling_openai.py
tests/test_modeling_openai.py
+10
-16
tests/test_modeling_reformer.py
tests/test_modeling_reformer.py
+17
-36
No files found.
src/transformers/modeling_encoder_decoder.py
View file @
d951c14a
...
@@ -273,6 +273,7 @@ class EncoderDecoderModel(PreTrainedModel):
...
@@ -273,6 +273,7 @@ class EncoderDecoderModel(PreTrainedModel):
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
head_mask
=
head_mask
,
head_mask
=
head_mask
,
return_dict
=
False
,
**
kwargs_encoder
,
**
kwargs_encoder
,
)
)
...
@@ -287,6 +288,7 @@ class EncoderDecoderModel(PreTrainedModel):
...
@@ -287,6 +288,7 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attention_mask
=
attention_mask
,
encoder_attention_mask
=
attention_mask
,
head_mask
=
decoder_head_mask
,
head_mask
=
decoder_head_mask
,
labels
=
labels
,
labels
=
labels
,
return_dict
=
False
,
**
kwargs_decoder
,
**
kwargs_decoder
,
)
)
...
...
src/transformers/modeling_openai.py
View file @
d951c14a
...
@@ -688,16 +688,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -688,16 +688,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
).
squeeze
(
-
1
)
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
).
squeeze
(
-
1
)
lm_loss
=
None
lm_loss
,
mc_loss
=
None
,
None
if
mc_labels
is
not
None
:
if
mc_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
lm_loss
=
loss_fct
(
mc_logits
.
view
(
-
1
,
mc_logits
.
size
(
-
1
)),
mc_labels
.
view
(
-
1
))
mc_loss
=
loss_fct
(
mc_logits
.
view
(
-
1
,
mc_logits
.
size
(
-
1
)),
mc_labels
.
view
(
-
1
))
mc_loss
=
None
if
labels
is
not
None
:
if
labels
is
not
None
:
shift_logits
=
lm_logits
[...,
:
-
1
,
:].
contiguous
()
shift_logits
=
lm_logits
[...,
:
-
1
,
:].
contiguous
()
shift_labels
=
labels
[...,
1
:].
contiguous
()
shift_labels
=
labels
[...,
1
:].
contiguous
()
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
m
c
_loss
=
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
))
l
m_loss
=
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
))
if
not
return_dict
:
if
not
return_dict
:
output
=
(
lm_logits
,
mc_logits
)
+
transformer_outputs
[
1
:]
output
=
(
lm_logits
,
mc_logits
)
+
transformer_outputs
[
1
:]
...
...
src/transformers/modeling_reformer.py
View file @
d951c14a
...
@@ -2386,6 +2386,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
...
@@ -2386,6 +2386,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
"""
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
outputs
=
self
.
reformer
(
outputs
=
self
.
reformer
(
input_ids
,
input_ids
,
...
...
templates/adding_a_new_model/tests/test_modeling_xxx.py
View file @
d951c14a
...
@@ -121,6 +121,7 @@ class XxxModelTester:
...
@@ -121,6 +121,7 @@ class XxxModelTester:
max_position_embeddings
=
self
.
max_position_embeddings
,
max_position_embeddings
=
self
.
max_position_embeddings
,
type_vocab_size
=
self
.
type_vocab_size
,
type_vocab_size
=
self
.
type_vocab_size
,
initializer_range
=
self
.
initializer_range
,
initializer_range
=
self
.
initializer_range
,
return_dict
=
True
,
)
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -134,18 +135,13 @@ class XxxModelTester:
...
@@ -134,18 +135,13 @@ class XxxModelTester:
model
=
XxxModel
(
config
=
config
)
model
=
XxxModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
sequence_output
,
pooled_output
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
d
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
r
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_and_check_xxx_for_masked_lm
(
def
create_and_check_xxx_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -153,16 +149,10 @@ class XxxModelTester:
...
@@ -153,16 +149,10 @@ class XxxModelTester:
model
=
XxxForMaskedLM
(
config
=
config
)
model
=
XxxForMaskedLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_sco
res
=
model
(
res
ult
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
masked_lm_labels
=
token_labels
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
masked_lm_labels
=
token_labels
)
)
result
=
{
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_xxx_for_question_answering
(
def
create_and_check_xxx_for_question_answering
(
...
@@ -171,18 +161,13 @@ class XxxModelTester:
...
@@ -171,18 +161,13 @@ class XxxModelTester:
model
=
XxxForQuestionAnswering
(
config
=
config
)
model
=
XxxForQuestionAnswering
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
start_logits
,
end_logits
=
model
(
result
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
start_positions
=
sequence_labels
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
)
)
result
=
{
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -194,13 +179,7 @@ class XxxModelTester:
...
@@ -194,13 +179,7 @@ class XxxModelTester:
model
=
XxxForSequenceClassification
(
config
)
model
=
XxxForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -211,11 +190,7 @@ class XxxModelTester:
...
@@ -211,11 +190,7 @@ class XxxModelTester:
model
=
XxxForTokenClassification
(
config
=
config
)
model
=
XxxForTokenClassification
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
...
tests/test_modeling_albert.py
View file @
d951c14a
...
@@ -98,6 +98,7 @@ class AlbertModelTester:
...
@@ -98,6 +98,7 @@ class AlbertModelTester:
type_vocab_size
=
self
.
type_vocab_size
,
type_vocab_size
=
self
.
type_vocab_size
,
initializer_range
=
self
.
initializer_range
,
initializer_range
=
self
.
initializer_range
,
num_hidden_groups
=
self
.
num_hidden_groups
,
num_hidden_groups
=
self
.
num_hidden_groups
,
return_dict
=
True
,
)
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -111,18 +112,13 @@ class AlbertModelTester:
...
@@ -111,18 +112,13 @@ class AlbertModelTester:
model
=
AlbertModel
(
config
=
config
)
model
=
AlbertModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
sequence_output
,
pooled_output
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
d
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
r
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_and_check_albert_for_pretraining
(
def
create_and_check_albert_for_pretraining
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -130,22 +126,17 @@ class AlbertModelTester:
...
@@ -130,22 +126,17 @@ class AlbertModelTester:
model
=
AlbertForPreTraining
(
config
=
config
)
model
=
AlbertForPreTraining
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
,
sop_sco
res
=
model
(
res
ult
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
,
labels
=
token_labels
,
sentence_order_label
=
sequence_labels
,
sentence_order_label
=
sequence_labels
,
)
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
"sop_scores"
:
sop_scores
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_
score
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
list
(
result
[
"prediction_
logit
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"sop_
score
s"
].
size
()),
[
self
.
batch_size
,
config
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"sop_
logit
s"
].
size
()),
[
self
.
batch_size
,
config
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_albert_for_masked_lm
(
def
create_and_check_albert_for_masked_lm
(
...
@@ -154,16 +145,8 @@ class AlbertModelTester:
...
@@ -154,16 +145,8 @@ class AlbertModelTester:
model
=
AlbertForMaskedLM
(
config
=
config
)
model
=
AlbertForMaskedLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_albert_for_question_answering
(
def
create_and_check_albert_for_question_answering
(
...
@@ -172,18 +155,13 @@ class AlbertModelTester:
...
@@ -172,18 +155,13 @@ class AlbertModelTester:
model
=
AlbertForQuestionAnswering
(
config
=
config
)
model
=
AlbertForQuestionAnswering
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
start_logits
,
end_logits
=
model
(
result
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
start_positions
=
sequence_labels
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
)
)
result
=
{
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -195,13 +173,7 @@ class AlbertModelTester:
...
@@ -195,13 +173,7 @@ class AlbertModelTester:
model
=
AlbertForSequenceClassification
(
config
)
model
=
AlbertForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -212,11 +184,7 @@ class AlbertModelTester:
...
@@ -212,11 +184,7 @@ class AlbertModelTester:
model
=
AlbertForTokenClassification
(
config
=
config
)
model
=
AlbertForTokenClassification
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -230,16 +198,12 @@ class AlbertModelTester:
...
@@ -230,16 +198,12 @@ class AlbertModelTester:
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
loss
,
logits
=
model
(
result
=
model
(
multiple_choice_inputs_ids
,
multiple_choice_inputs_ids
,
attention_mask
=
multiple_choice_input_mask
,
attention_mask
=
multiple_choice_input_mask
,
token_type_ids
=
multiple_choice_token_type_ids
,
token_type_ids
=
multiple_choice_token_type_ids
,
labels
=
choice_labels
,
labels
=
choice_labels
,
)
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
def
prepare_config_and_inputs_for_common
(
self
):
def
prepare_config_and_inputs_for_common
(
self
):
...
...
tests/test_modeling_bart.py
View file @
d951c14a
...
@@ -238,6 +238,7 @@ class BartHeadTests(unittest.TestCase):
...
@@ -238,6 +238,7 @@ class BartHeadTests(unittest.TestCase):
eos_token_id
=
2
,
eos_token_id
=
2
,
pad_token_id
=
1
,
pad_token_id
=
1
,
bos_token_id
=
0
,
bos_token_id
=
0
,
return_dict
=
True
,
)
)
return
config
,
input_ids
,
batch_size
return
config
,
input_ids
,
batch_size
...
@@ -247,24 +248,20 @@ class BartHeadTests(unittest.TestCase):
...
@@ -247,24 +248,20 @@ class BartHeadTests(unittest.TestCase):
model
=
BartForSequenceClassification
(
config
)
model
=
BartForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
outputs
=
model
(
input_ids
=
input_ids
,
decoder_input_ids
=
input_ids
,
labels
=
labels
)
outputs
=
model
(
input_ids
=
input_ids
,
decoder_input_ids
=
input_ids
,
labels
=
labels
)
logits
=
outputs
[
1
]
expected_shape
=
torch
.
Size
((
batch_size
,
config
.
num_labels
))
expected_shape
=
torch
.
Size
((
batch_size
,
config
.
num_labels
))
self
.
assertEqual
(
logits
.
shape
,
expected_shape
)
self
.
assertEqual
(
outputs
[
"logits"
].
shape
,
expected_shape
)
loss
=
outputs
[
0
]
self
.
assertIsInstance
(
outputs
[
"loss"
].
item
(),
float
)
self
.
assertIsInstance
(
loss
.
item
(),
float
)
def
test_question_answering_forward
(
self
):
def
test_question_answering_forward
(
self
):
config
,
input_ids
,
batch_size
=
self
.
_get_config_and_data
()
config
,
input_ids
,
batch_size
=
self
.
_get_config_and_data
()
sequence_labels
=
ids_tensor
([
batch_size
],
2
).
to
(
torch_device
)
sequence_labels
=
ids_tensor
([
batch_size
],
2
).
to
(
torch_device
)
model
=
BartForQuestionAnswering
(
config
)
model
=
BartForQuestionAnswering
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
loss
,
start_logits
,
end_logits
,
_
=
model
(
outputs
=
model
(
input_ids
=
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,)
input_ids
=
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
)
self
.
assertEqual
(
start_logits
.
shape
,
input_ids
.
shape
)
self
.
assertEqual
(
outputs
[
"
start_logits
"
]
.
shape
,
input_ids
.
shape
)
self
.
assertEqual
(
end_logits
.
shape
,
input_ids
.
shape
)
self
.
assertEqual
(
outputs
[
"
end_logits
"
]
.
shape
,
input_ids
.
shape
)
self
.
assertIsInstance
(
loss
.
item
(),
float
)
self
.
assertIsInstance
(
outputs
[
"
loss
"
]
.
item
(),
float
)
@
timeout_decorator
.
timeout
(
1
)
@
timeout_decorator
.
timeout
(
1
)
def
test_lm_forward
(
self
):
def
test_lm_forward
(
self
):
...
@@ -272,10 +269,10 @@ class BartHeadTests(unittest.TestCase):
...
@@ -272,10 +269,10 @@ class BartHeadTests(unittest.TestCase):
lm_labels
=
ids_tensor
([
batch_size
,
input_ids
.
shape
[
1
]],
self
.
vocab_size
).
to
(
torch_device
)
lm_labels
=
ids_tensor
([
batch_size
,
input_ids
.
shape
[
1
]],
self
.
vocab_size
).
to
(
torch_device
)
lm_model
=
BartForConditionalGeneration
(
config
)
lm_model
=
BartForConditionalGeneration
(
config
)
lm_model
.
to
(
torch_device
)
lm_model
.
to
(
torch_device
)
loss
,
logits
,
enc_feature
s
=
lm_model
(
input_ids
=
input_ids
,
labels
=
lm_labels
)
output
s
=
lm_model
(
input_ids
=
input_ids
,
labels
=
lm_labels
)
expected_shape
=
(
batch_size
,
input_ids
.
shape
[
1
],
config
.
vocab_size
)
expected_shape
=
(
batch_size
,
input_ids
.
shape
[
1
],
config
.
vocab_size
)
self
.
assertEqual
(
logits
.
shape
,
expected_shape
)
self
.
assertEqual
(
outputs
[
"
logits
"
]
.
shape
,
expected_shape
)
self
.
assertIsInstance
(
loss
.
item
(),
float
)
self
.
assertIsInstance
(
outputs
[
"
loss
"
]
.
item
(),
float
)
def
test_lm_uneven_forward
(
self
):
def
test_lm_uneven_forward
(
self
):
config
=
BartConfig
(
config
=
BartConfig
(
...
@@ -288,13 +285,14 @@ class BartHeadTests(unittest.TestCase):
...
@@ -288,13 +285,14 @@ class BartHeadTests(unittest.TestCase):
encoder_ffn_dim
=
8
,
encoder_ffn_dim
=
8
,
decoder_ffn_dim
=
8
,
decoder_ffn_dim
=
8
,
max_position_embeddings
=
48
,
max_position_embeddings
=
48
,
return_dict
=
True
,
)
)
lm_model
=
BartForConditionalGeneration
(
config
).
to
(
torch_device
)
lm_model
=
BartForConditionalGeneration
(
config
).
to
(
torch_device
)
context
=
torch
.
Tensor
([[
71
,
82
,
18
,
33
,
46
,
91
,
2
],
[
68
,
34
,
26
,
58
,
30
,
2
,
1
]]).
long
().
to
(
torch_device
)
context
=
torch
.
Tensor
([[
71
,
82
,
18
,
33
,
46
,
91
,
2
],
[
68
,
34
,
26
,
58
,
30
,
2
,
1
]]).
long
().
to
(
torch_device
)
summary
=
torch
.
Tensor
([[
82
,
71
,
82
,
18
,
2
],
[
58
,
68
,
2
,
1
,
1
]]).
long
().
to
(
torch_device
)
summary
=
torch
.
Tensor
([[
82
,
71
,
82
,
18
,
2
],
[
58
,
68
,
2
,
1
,
1
]]).
long
().
to
(
torch_device
)
loss
,
logits
,
enc_feature
s
=
lm_model
(
input_ids
=
context
,
decoder_input_ids
=
summary
,
labels
=
summary
)
output
s
=
lm_model
(
input_ids
=
context
,
decoder_input_ids
=
summary
,
labels
=
summary
)
expected_shape
=
(
*
summary
.
shape
,
config
.
vocab_size
)
expected_shape
=
(
*
summary
.
shape
,
config
.
vocab_size
)
self
.
assertEqual
(
logits
.
shape
,
expected_shape
)
self
.
assertEqual
(
outputs
[
"
logits
"
]
.
shape
,
expected_shape
)
def
test_generate_beam_search
(
self
):
def
test_generate_beam_search
(
self
):
input_ids
=
torch
.
Tensor
([[
71
,
82
,
2
],
[
68
,
34
,
2
]]).
long
().
to
(
torch_device
)
input_ids
=
torch
.
Tensor
([[
71
,
82
,
2
],
[
68
,
34
,
2
]]).
long
().
to
(
torch_device
)
...
...
tests/test_modeling_bert.py
View file @
d951c14a
...
@@ -120,6 +120,7 @@ class BertModelTester:
...
@@ -120,6 +120,7 @@ class BertModelTester:
type_vocab_size
=
self
.
type_vocab_size
,
type_vocab_size
=
self
.
type_vocab_size
,
is_decoder
=
False
,
is_decoder
=
False
,
initializer_range
=
self
.
initializer_range
,
initializer_range
=
self
.
initializer_range
,
return_dict
=
True
,
)
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -160,18 +161,13 @@ class BertModelTester:
...
@@ -160,18 +161,13 @@ class BertModelTester:
model
=
BertModel
(
config
=
config
)
model
=
BertModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
sequence_output
,
pooled_output
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
d
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
r
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_and_check_bert_model_as_decoder
(
def
create_and_check_bert_model_as_decoder
(
self
,
self
,
...
@@ -188,29 +184,24 @@ class BertModelTester:
...
@@ -188,29 +184,24 @@ class BertModelTester:
model
=
BertModel
(
config
)
model
=
BertModel
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
sequence_output
,
pooled_outpu
t
=
model
(
resul
t
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_attention_mask
=
encoder_attention_mask
,
encoder_attention_mask
=
encoder_attention_mask
,
)
)
sequence_output
,
pooled_outpu
t
=
model
(
resul
t
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
)
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
d
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
r
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_and_check_bert_for_causal_lm
(
def
create_and_check_bert_for_causal_lm
(
self
,
self
,
...
@@ -227,16 +218,8 @@ class BertModelTester:
...
@@ -227,16 +218,8 @@ class BertModelTester:
model
=
BertLMHeadModel
(
config
=
config
)
model
=
BertLMHeadModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_bert_for_masked_lm
(
def
create_and_check_bert_for_masked_lm
(
...
@@ -245,16 +228,8 @@ class BertModelTester:
...
@@ -245,16 +228,8 @@ class BertModelTester:
model
=
BertForMaskedLM
(
config
=
config
)
model
=
BertForMaskedLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_bert_model_for_causal_lm_as_decoder
(
def
create_and_check_bert_model_for_causal_lm_as_decoder
(
...
@@ -272,7 +247,7 @@ class BertModelTester:
...
@@ -272,7 +247,7 @@ class BertModelTester:
model
=
BertLMHeadModel
(
config
=
config
)
model
=
BertLMHeadModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_sco
res
=
model
(
res
ult
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
...
@@ -280,20 +255,14 @@ class BertModelTester:
...
@@ -280,20 +255,14 @@ class BertModelTester:
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_attention_mask
=
encoder_attention_mask
,
encoder_attention_mask
=
encoder_attention_mask
,
)
)
loss
,
prediction_sco
res
=
model
(
res
ult
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
,
labels
=
token_labels
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
)
)
result
=
{
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_bert_for_next_sequence_prediction
(
def
create_and_check_bert_for_next_sequence_prediction
(
...
@@ -302,14 +271,10 @@ class BertModelTester:
...
@@ -302,14 +271,10 @@ class BertModelTester:
model
=
BertForNextSentencePrediction
(
config
=
config
)
model
=
BertForNextSentencePrediction
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
seq_relationship_score
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
next_sentence_label
=
sequence_labels
,
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
next_sentence_label
=
sequence_labels
,
)
)
result
=
{
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
2
])
"loss"
:
loss
,
"seq_relationship_score"
:
seq_relationship_score
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"seq_relationship_score"
].
size
()),
[
self
.
batch_size
,
2
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_bert_for_pretraining
(
def
create_and_check_bert_for_pretraining
(
...
@@ -318,22 +283,17 @@ class BertModelTester:
...
@@ -318,22 +283,17 @@ class BertModelTester:
model
=
BertForPreTraining
(
config
=
config
)
model
=
BertForPreTraining
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
,
seq_relationship_score
=
model
(
result
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
,
labels
=
token_labels
,
next_sentence_label
=
sequence_labels
,
next_sentence_label
=
sequence_labels
,
)
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
"seq_relationship_score"
:
seq_relationship_score
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_
score
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
list
(
result
[
"prediction_
logit
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"seq_relationship_
score
"
].
size
()),
[
self
.
batch_size
,
2
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"seq_relationship_
logits
"
].
size
()),
[
self
.
batch_size
,
2
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_bert_for_question_answering
(
def
create_and_check_bert_for_question_answering
(
...
@@ -342,18 +302,13 @@ class BertModelTester:
...
@@ -342,18 +302,13 @@ class BertModelTester:
model
=
BertForQuestionAnswering
(
config
=
config
)
model
=
BertForQuestionAnswering
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
start_logits
,
end_logits
=
model
(
result
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
start_positions
=
sequence_labels
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
)
)
result
=
{
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -365,13 +320,7 @@ class BertModelTester:
...
@@ -365,13 +320,7 @@ class BertModelTester:
model
=
BertForSequenceClassification
(
config
)
model
=
BertForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -382,11 +331,7 @@ class BertModelTester:
...
@@ -382,11 +331,7 @@ class BertModelTester:
model
=
BertForTokenClassification
(
config
=
config
)
model
=
BertForTokenClassification
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -400,16 +345,12 @@ class BertModelTester:
...
@@ -400,16 +345,12 @@ class BertModelTester:
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
loss
,
logits
=
model
(
result
=
model
(
multiple_choice_inputs_ids
,
multiple_choice_inputs_ids
,
attention_mask
=
multiple_choice_input_mask
,
attention_mask
=
multiple_choice_input_mask
,
token_type_ids
=
multiple_choice_token_type_ids
,
token_type_ids
=
multiple_choice_token_type_ids
,
labels
=
choice_labels
,
labels
=
choice_labels
,
)
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
...
tests/test_modeling_camembert.py
View file @
d951c14a
...
@@ -28,13 +28,13 @@ if is_torch_available():
...
@@ -28,13 +28,13 @@ if is_torch_available():
class
CamembertModelIntegrationTest
(
unittest
.
TestCase
):
class
CamembertModelIntegrationTest
(
unittest
.
TestCase
):
@
slow
@
slow
def
test_output_embeds_base_model
(
self
):
def
test_output_embeds_base_model
(
self
):
model
=
CamembertModel
.
from_pretrained
(
"camembert-base"
)
model
=
CamembertModel
.
from_pretrained
(
"camembert-base"
,
return_dict
=
True
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
input_ids
=
torch
.
tensor
(
input_ids
=
torch
.
tensor
(
[[
5
,
121
,
11
,
660
,
16
,
730
,
25543
,
110
,
83
,
6
]],
device
=
torch_device
,
dtype
=
torch
.
long
,
[[
5
,
121
,
11
,
660
,
16
,
730
,
25543
,
110
,
83
,
6
]],
device
=
torch_device
,
dtype
=
torch
.
long
,
)
# J'aime le camembert !
)
# J'aime le camembert !
output
=
model
(
input_ids
)[
0
]
output
=
model
(
input_ids
)[
"last_hidden_state"
]
expected_shape
=
torch
.
Size
((
1
,
10
,
768
))
expected_shape
=
torch
.
Size
((
1
,
10
,
768
))
self
.
assertEqual
(
output
.
shape
,
expected_shape
)
self
.
assertEqual
(
output
.
shape
,
expected_shape
)
# compare the actual values for a slice.
# compare the actual values for a slice.
...
...
tests/test_modeling_common.py
View file @
d951c14a
...
@@ -74,7 +74,6 @@ class ModelTesterMixin:
...
@@ -74,7 +74,6 @@ class ModelTesterMixin:
def
test_save_load
(
self
):
def
test_save_load
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
.
return_dict
=
True
for
model_class
in
self
.
all_model_classes
:
for
model_class
in
self
.
all_model_classes
:
model
=
model_class
(
config
)
model
=
model_class
(
config
)
...
...
tests/test_modeling_ctrl.py
View file @
d951c14a
...
@@ -88,9 +88,10 @@ class CTRLModelTester:
...
@@ -88,9 +88,10 @@ class CTRLModelTester:
# hidden_dropout_prob=self.hidden_dropout_prob,
# hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions
=
self
.
max_position_embeddings
,
n_positions
=
self
.
max_position_embeddings
,
n_ctx
=
self
.
max_position_embeddings
n_ctx
=
self
.
max_position_embeddings
,
# type_vocab_size=self.type_vocab_size,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
# initializer_range=self.initializer_range,
return_dict
=
True
,
)
)
head_mask
=
ids_tensor
([
self
.
num_hidden_layers
,
self
.
num_attention_heads
],
2
)
head_mask
=
ids_tensor
([
self
.
num_hidden_layers
,
self
.
num_attention_heads
],
2
)
...
@@ -117,29 +118,20 @@ class CTRLModelTester:
...
@@ -117,29 +118,20 @@ class CTRLModelTester:
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
head_mask
=
head_mask
)
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
head_mask
=
head_mask
)
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
sequence_output
,
presents
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
"presents"
:
presents
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
self
.
parent
.
assertEqual
(
len
(
result
[
"p
resent
s"
]),
config
.
n_layer
)
self
.
parent
.
assertEqual
(
len
(
result
[
"p
ast_key_value
s"
]),
config
.
n_layer
)
def
create_and_check_lm_head_model
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
def
create_and_check_lm_head_model
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
CTRLLMHeadModel
(
config
)
model
=
CTRLLMHeadModel
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
lm_logits
,
_
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
input_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
input_ids
)
result
=
{
"loss"
:
loss
,
"lm_logits"
:
lm_logits
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
def
prepare_config_and_inputs_for_common
(
self
):
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
prepare_config_and_inputs
()
...
...
tests/test_modeling_distilbert.py
View file @
d951c14a
...
@@ -110,6 +110,7 @@ if is_torch_available():
...
@@ -110,6 +110,7 @@ if is_torch_available():
attention_dropout
=
self
.
attention_probs_dropout_prob
,
attention_dropout
=
self
.
attention_probs_dropout_prob
,
max_position_embeddings
=
self
.
max_position_embeddings
,
max_position_embeddings
=
self
.
max_position_embeddings
,
initializer_range
=
self
.
initializer_range
,
initializer_range
=
self
.
initializer_range
,
return_dict
=
True
,
)
)
return
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
return
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -123,14 +124,10 @@ if is_torch_available():
...
@@ -123,14 +124,10 @@ if is_torch_available():
model
=
DistilBertModel
(
config
=
config
)
model
=
DistilBertModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
(
sequence_output
,)
=
model
(
input_ids
,
input_mask
)
result
=
model
(
input_ids
,
input_mask
)
(
sequence_output
,)
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
def
create_and_check_distilbert_for_masked_lm
(
def
create_and_check_distilbert_for_masked_lm
(
...
@@ -139,13 +136,9 @@ if is_torch_available():
...
@@ -139,13 +136,9 @@ if is_torch_available():
model
=
DistilBertForMaskedLM
(
config
=
config
)
model
=
DistilBertForMaskedLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
token_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
prediction_score
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
list
(
result
[
"
logit
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -155,14 +148,9 @@ if is_torch_available():
...
@@ -155,14 +148,9 @@ if is_torch_available():
model
=
DistilBertForQuestionAnswering
(
config
=
config
)
model
=
DistilBertForQuestionAnswering
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
start_logits
,
end_logits
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
input_ids
,
attention_mask
=
input_mask
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
)
result
=
{
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -174,11 +162,7 @@ if is_torch_available():
...
@@ -174,11 +162,7 @@ if is_torch_available():
model
=
DistilBertForSequenceClassification
(
config
)
model
=
DistilBertForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
sequence_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
sequence_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -190,11 +174,7 @@ if is_torch_available():
...
@@ -190,11 +174,7 @@ if is_torch_available():
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
token_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
]
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
]
)
)
...
@@ -209,13 +189,9 @@ if is_torch_available():
...
@@ -209,13 +189,9 @@ if is_torch_available():
model
.
eval
()
model
.
eval
()
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
loss
,
logits
=
model
(
result
=
model
(
multiple_choice_inputs_ids
,
attention_mask
=
multiple_choice_input_mask
,
labels
=
choice_labels
,
multiple_choice_inputs_ids
,
attention_mask
=
multiple_choice_input_mask
,
labels
=
choice_labels
,
)
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
...
tests/test_modeling_dpr.py
View file @
d951c14a
...
@@ -115,6 +115,7 @@ class DPRModelTester:
...
@@ -115,6 +115,7 @@ class DPRModelTester:
type_vocab_size
=
self
.
type_vocab_size
,
type_vocab_size
=
self
.
type_vocab_size
,
is_decoder
=
False
,
is_decoder
=
False
,
initializer_range
=
self
.
initializer_range
,
initializer_range
=
self
.
initializer_range
,
return_dict
=
True
,
)
)
config
=
DPRConfig
(
projection_dim
=
self
.
projection_dim
,
**
config
.
to_dict
())
config
=
DPRConfig
(
projection_dim
=
self
.
projection_dim
,
**
config
.
to_dict
())
...
@@ -126,15 +127,11 @@ class DPRModelTester:
...
@@ -126,15 +127,11 @@ class DPRModelTester:
model
=
DPRContextEncoder
(
config
=
config
)
model
=
DPRContextEncoder
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
embeddings
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)[
0
]
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
embeddings
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)[
0
]
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
embeddings
=
model
(
input_ids
)[
0
]
result
=
model
(
input_ids
)
result
=
{
"embeddings"
:
embeddings
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
embeddings
"
].
size
()),
[
self
.
batch_size
,
self
.
projection_dim
or
self
.
hidden_size
]
list
(
result
[
"
pooler_output
"
].
size
()),
[
self
.
batch_size
,
self
.
projection_dim
or
self
.
hidden_size
]
)
)
def
create_and_check_dpr_question_encoder
(
def
create_and_check_dpr_question_encoder
(
...
@@ -143,15 +140,11 @@ class DPRModelTester:
...
@@ -143,15 +140,11 @@ class DPRModelTester:
model
=
DPRQuestionEncoder
(
config
=
config
)
model
=
DPRQuestionEncoder
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
embeddings
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)[
0
]
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
embeddings
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)[
0
]
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
embeddings
=
model
(
input_ids
)[
0
]
result
=
model
(
input_ids
)
result
=
{
"embeddings"
:
embeddings
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
embeddings
"
].
size
()),
[
self
.
batch_size
,
self
.
projection_dim
or
self
.
hidden_size
]
list
(
result
[
"
pooler_output
"
].
size
()),
[
self
.
batch_size
,
self
.
projection_dim
or
self
.
hidden_size
]
)
)
def
create_and_check_dpr_reader
(
def
create_and_check_dpr_reader
(
...
@@ -160,12 +153,7 @@ class DPRModelTester:
...
@@ -160,12 +153,7 @@ class DPRModelTester:
model
=
DPRReader
(
config
=
config
)
model
=
DPRReader
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
start_logits
,
end_logits
,
relevance_logits
,
*
_
=
model
(
input_ids
,
attention_mask
=
input_mask
,)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,)
result
=
{
"relevance_logits"
:
relevance_logits
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"relevance_logits"
].
size
()),
[
self
.
batch_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"relevance_logits"
].
size
()),
[
self
.
batch_size
])
...
...
tests/test_modeling_electra.py
View file @
d951c14a
...
@@ -97,6 +97,7 @@ class ElectraModelTester:
...
@@ -97,6 +97,7 @@ class ElectraModelTester:
type_vocab_size
=
self
.
type_vocab_size
,
type_vocab_size
=
self
.
type_vocab_size
,
is_decoder
=
False
,
is_decoder
=
False
,
initializer_range
=
self
.
initializer_range
,
initializer_range
=
self
.
initializer_range
,
return_dict
=
True
,
)
)
return
(
return
(
...
@@ -127,15 +128,11 @@ class ElectraModelTester:
...
@@ -127,15 +128,11 @@ class ElectraModelTester:
model
=
ElectraModel
(
config
=
config
)
model
=
ElectraModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
(
sequence_output
,)
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
(
sequence_output
,)
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
(
sequence_output
,)
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
def
create_and_check_electra_for_masked_lm
(
def
create_and_check_electra_for_masked_lm
(
...
@@ -152,16 +149,8 @@ class ElectraModelTester:
...
@@ -152,16 +149,8 @@ class ElectraModelTester:
model
=
ElectraForMaskedLM
(
config
=
config
)
model
=
ElectraForMaskedLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_electra_for_token_classification
(
def
create_and_check_electra_for_token_classification
(
...
@@ -179,11 +168,7 @@ class ElectraModelTester:
...
@@ -179,11 +168,7 @@ class ElectraModelTester:
model
=
ElectraForTokenClassification
(
config
=
config
)
model
=
ElectraForTokenClassification
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -202,13 +187,7 @@ class ElectraModelTester:
...
@@ -202,13 +187,7 @@ class ElectraModelTester:
model
=
ElectraForPreTraining
(
config
=
config
)
model
=
ElectraForPreTraining
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
fake_token_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
fake_token_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -227,13 +206,7 @@ class ElectraModelTester:
...
@@ -227,13 +206,7 @@ class ElectraModelTester:
model
=
ElectraForSequenceClassification
(
config
)
model
=
ElectraForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -251,18 +224,13 @@ class ElectraModelTester:
...
@@ -251,18 +224,13 @@ class ElectraModelTester:
model
=
ElectraForQuestionAnswering
(
config
=
config
)
model
=
ElectraForQuestionAnswering
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
start_logits
,
end_logits
=
model
(
result
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
start_positions
=
sequence_labels
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
)
)
result
=
{
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -285,16 +253,12 @@ class ElectraModelTester:
...
@@ -285,16 +253,12 @@ class ElectraModelTester:
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
loss
,
logits
=
model
(
result
=
model
(
multiple_choice_inputs_ids
,
multiple_choice_inputs_ids
,
attention_mask
=
multiple_choice_input_mask
,
attention_mask
=
multiple_choice_input_mask
,
token_type_ids
=
multiple_choice_token_type_ids
,
token_type_ids
=
multiple_choice_token_type_ids
,
labels
=
choice_labels
,
labels
=
choice_labels
,
)
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
...
tests/test_modeling_flaubert.py
View file @
d951c14a
...
@@ -110,6 +110,7 @@ class FlaubertModelTester(object):
...
@@ -110,6 +110,7 @@ class FlaubertModelTester(object):
initializer_range
=
self
.
initializer_range
,
initializer_range
=
self
.
initializer_range
,
summary_type
=
self
.
summary_type
,
summary_type
=
self
.
summary_type
,
use_proj
=
self
.
use_proj
,
use_proj
=
self
.
use_proj
,
return_dict
=
True
,
)
)
return
(
return
(
...
@@ -142,15 +143,11 @@ class FlaubertModelTester(object):
...
@@ -142,15 +143,11 @@ class FlaubertModelTester(object):
model
=
FlaubertModel
(
config
=
config
)
model
=
FlaubertModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
outputs
=
model
(
input_ids
,
lengths
=
input_lengths
,
langs
=
token_type_ids
)
result
=
model
(
input_ids
,
lengths
=
input_lengths
,
langs
=
token_type_ids
)
outputs
=
model
(
input_ids
,
langs
=
token_type_ids
)
result
=
model
(
input_ids
,
langs
=
token_type_ids
)
outputs
=
model
(
input_ids
)
result
=
model
(
input_ids
)
sequence_output
=
outputs
[
0
]
result
=
{
"sequence_output"
:
sequence_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
def
create_and_check_flaubert_lm_head
(
def
create_and_check_flaubert_lm_head
(
...
@@ -169,13 +166,7 @@ class FlaubertModelTester(object):
...
@@ -169,13 +166,7 @@ class FlaubertModelTester(object):
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
...
@@ -195,16 +186,9 @@ class FlaubertModelTester(object):
...
@@ -195,16 +186,9 @@ class FlaubertModelTester(object):
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
outputs
=
model
(
input_ids
)
result
=
model
(
input_ids
)
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
loss
,
start_logits
,
end_logits
=
outputs
result
=
{
result
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -225,10 +209,9 @@ class FlaubertModelTester(object):
...
@@ -225,10 +209,9 @@ class FlaubertModelTester(object):
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
outputs
=
model
(
input_ids
)
result
=
model
(
input_ids
)
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
=
outputs
output
s
=
model
(
result_with_label
s
=
model
(
input_ids
,
input_ids
,
start_positions
=
sequence_labels
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
...
@@ -237,7 +220,7 @@ class FlaubertModelTester(object):
...
@@ -237,7 +220,7 @@ class FlaubertModelTester(object):
p_mask
=
input_mask
,
p_mask
=
input_mask
,
)
)
output
s
=
model
(
result_with_label
s
=
model
(
input_ids
,
input_ids
,
start_positions
=
sequence_labels
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
...
@@ -245,22 +228,13 @@ class FlaubertModelTester(object):
...
@@ -245,22 +228,13 @@ class FlaubertModelTester(object):
is_impossible
=
is_impossible_labels
,
is_impossible
=
is_impossible_labels
,
)
)
(
total_loss
,)
=
outputs
(
total_loss
,)
=
result_with_labels
.
to_tuple
()
output
s
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
result_with_label
s
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
(
total_loss
,)
=
outputs
(
total_loss
,)
=
result_with_labels
.
to_tuple
()
result
=
{
self
.
parent
.
assertListEqual
(
list
(
result_with_labels
[
"loss"
].
size
()),
[])
"loss"
:
total_loss
,
"start_top_log_probs"
:
start_top_log_probs
,
"start_top_index"
:
start_top_index
,
"end_top_log_probs"
:
end_top_log_probs
,
"end_top_index"
:
end_top_index
,
"cls_logits"
:
cls_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
list
(
result
[
"start_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
)
)
...
@@ -292,13 +266,8 @@ class FlaubertModelTester(object):
...
@@ -292,13 +266,8 @@ class FlaubertModelTester(object):
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
(
logits
,)
=
model
(
input_ids
)
result
=
model
(
input_ids
)
loss
,
logits
=
model
(
input_ids
,
labels
=
sequence_labels
)
result
=
model
(
input_ids
,
labels
=
sequence_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
type_sequence_label_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
type_sequence_label_size
])
...
@@ -320,11 +289,7 @@ class FlaubertModelTester(object):
...
@@ -320,11 +289,7 @@ class FlaubertModelTester(object):
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
token_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -347,16 +312,12 @@ class FlaubertModelTester(object):
...
@@ -347,16 +312,12 @@ class FlaubertModelTester(object):
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
loss
,
logits
=
model
(
result
=
model
(
multiple_choice_inputs_ids
,
multiple_choice_inputs_ids
,
attention_mask
=
multiple_choice_input_mask
,
attention_mask
=
multiple_choice_input_mask
,
token_type_ids
=
multiple_choice_token_type_ids
,
token_type_ids
=
multiple_choice_token_type_ids
,
labels
=
choice_labels
,
labels
=
choice_labels
,
)
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
...
tests/test_modeling_gpt2.py
View file @
d951c14a
...
@@ -122,9 +122,10 @@ class GPT2ModelTester:
...
@@ -122,9 +122,10 @@ class GPT2ModelTester:
n_positions
=
self
.
max_position_embeddings
,
n_positions
=
self
.
max_position_embeddings
,
n_ctx
=
self
.
max_position_embeddings
,
n_ctx
=
self
.
max_position_embeddings
,
# type_vocab_size=self.type_vocab_size,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
# initializer_range=self.initializer_range
,
bos_token_id
=
self
.
bos_token_id
,
bos_token_id
=
self
.
bos_token_id
,
eos_token_id
=
self
.
eos_token_id
,
eos_token_id
=
self
.
eos_token_id
,
return_dict
=
True
,
)
)
head_mask
=
ids_tensor
([
self
.
num_hidden_layers
,
self
.
num_attention_heads
],
2
)
head_mask
=
ids_tensor
([
self
.
num_hidden_layers
,
self
.
num_attention_heads
],
2
)
...
@@ -149,18 +150,14 @@ class GPT2ModelTester:
...
@@ -149,18 +150,14 @@ class GPT2ModelTester:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
head_mask
=
head_mask
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
head_mask
=
head_mask
)
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
sequence_output
,
presents
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
"presents"
:
presents
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
],
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
],
)
)
self
.
parent
.
assertEqual
(
len
(
result
[
"p
resent
s"
]),
config
.
n_layer
)
self
.
parent
.
assertEqual
(
len
(
result
[
"p
ast_key_value
s"
]),
config
.
n_layer
)
def
create_and_check_gpt2_model_past
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
def
create_and_check_gpt2_model_past
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
GPT2Model
(
config
=
config
)
model
=
GPT2Model
(
config
=
config
)
...
@@ -175,7 +172,7 @@ class GPT2ModelTester:
...
@@ -175,7 +172,7 @@ class GPT2ModelTester:
self
.
parent
.
assertTrue
(
len
(
outputs
)
==
len
(
outputs_use_cache_conf
))
self
.
parent
.
assertTrue
(
len
(
outputs
)
==
len
(
outputs_use_cache_conf
))
self
.
parent
.
assertTrue
(
len
(
outputs
)
==
len
(
outputs_no_past
)
+
1
)
self
.
parent
.
assertTrue
(
len
(
outputs
)
==
len
(
outputs_no_past
)
+
1
)
output
,
past
=
outputs
output
,
past
=
outputs
.
to_tuple
()
# create hypothetical next token and extent to next_input_ids
# create hypothetical next token and extent to next_input_ids
next_tokens
=
ids_tensor
((
self
.
batch_size
,
1
),
config
.
vocab_size
)
next_tokens
=
ids_tensor
((
self
.
batch_size
,
1
),
config
.
vocab_size
)
...
@@ -185,8 +182,8 @@ class GPT2ModelTester:
...
@@ -185,8 +182,8 @@ class GPT2ModelTester:
next_input_ids
=
torch
.
cat
([
input_ids
,
next_tokens
],
dim
=-
1
)
next_input_ids
=
torch
.
cat
([
input_ids
,
next_tokens
],
dim
=-
1
)
next_token_type_ids
=
torch
.
cat
([
token_type_ids
,
next_token_types
],
dim
=-
1
)
next_token_type_ids
=
torch
.
cat
([
token_type_ids
,
next_token_types
],
dim
=-
1
)
output_from_no_past
,
_
=
model
(
next_input_ids
,
token_type_ids
=
next_token_type_ids
)
output_from_no_past
=
model
(
next_input_ids
,
token_type_ids
=
next_token_type_ids
)
[
"last_hidden_state"
]
output_from_past
,
_
=
model
(
next_tokens
,
token_type_ids
=
next_token_types
,
past
=
past
)
output_from_past
=
model
(
next_tokens
,
token_type_ids
=
next_token_types
,
past
=
past
)
[
"last_hidden_state"
]
# select random slice
# select random slice
random_slice_idx
=
ids_tensor
((
1
,),
output_from_past
.
shape
[
-
1
]).
item
()
random_slice_idx
=
ids_tensor
((
1
,),
output_from_past
.
shape
[
-
1
]).
item
()
...
@@ -209,7 +206,7 @@ class GPT2ModelTester:
...
@@ -209,7 +206,7 @@ class GPT2ModelTester:
attn_mask
[:,
half_seq_length
:]
=
0
attn_mask
[:,
half_seq_length
:]
=
0
# first forward pass
# first forward pass
output
,
past
=
model
(
input_ids
,
attention_mask
=
attn_mask
)
output
,
past
=
model
(
input_ids
,
attention_mask
=
attn_mask
)
.
to_tuple
()
# create hypothetical next token and extent to next_input_ids
# create hypothetical next token and extent to next_input_ids
next_tokens
=
ids_tensor
((
self
.
batch_size
,
1
),
config
.
vocab_size
)
next_tokens
=
ids_tensor
((
self
.
batch_size
,
1
),
config
.
vocab_size
)
...
@@ -226,8 +223,8 @@ class GPT2ModelTester:
...
@@ -226,8 +223,8 @@ class GPT2ModelTester:
)
)
# get two different outputs
# get two different outputs
output_from_no_past
,
_
=
model
(
next_input_ids
,
attention_mask
=
attn_mask
)
output_from_no_past
=
model
(
next_input_ids
,
attention_mask
=
attn_mask
)
[
"last_hidden_state"
]
output_from_past
,
_
=
model
(
next_tokens
,
past
=
past
,
attention_mask
=
attn_mask
)
output_from_past
=
model
(
next_tokens
,
past
=
past
,
attention_mask
=
attn_mask
)
[
"last_hidden_state"
]
# select random slice
# select random slice
random_slice_idx
=
ids_tensor
((
1
,),
output_from_past
.
shape
[
-
1
]).
item
()
random_slice_idx
=
ids_tensor
((
1
,),
output_from_past
.
shape
[
-
1
]).
item
()
...
@@ -242,13 +239,10 @@ class GPT2ModelTester:
...
@@ -242,13 +239,10 @@ class GPT2ModelTester:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
lm_logits
,
_
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
input_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
input_ids
)
result
=
{
"loss"
:
loss
,
"lm_logits"
:
lm_logits
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
lm_
logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
],
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
],
)
)
def
create_and_check_double_lm_head_model
(
def
create_and_check_double_lm_head_model
(
...
@@ -270,11 +264,8 @@ class GPT2ModelTester:
...
@@ -270,11 +264,8 @@ class GPT2ModelTester:
"labels"
:
multiple_choice_inputs_ids
,
"labels"
:
multiple_choice_inputs_ids
,
}
}
loss
,
lm_logits
,
mc_logits
,
_
=
model
(
**
inputs
)
result
=
model
(
**
inputs
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_loss"
].
size
()),
[])
result
=
{
"loss"
:
loss
,
"lm_logits"
:
lm_logits
,
"mc_logits"
:
mc_logits
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
,
self
.
seq_length
,
self
.
vocab_size
],
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
,
self
.
seq_length
,
self
.
vocab_size
],
)
)
...
...
tests/test_modeling_longformer.py
View file @
d951c14a
...
@@ -108,6 +108,7 @@ class LongformerModelTester:
...
@@ -108,6 +108,7 @@ class LongformerModelTester:
type_vocab_size
=
self
.
type_vocab_size
,
type_vocab_size
=
self
.
type_vocab_size
,
initializer_range
=
self
.
initializer_range
,
initializer_range
=
self
.
initializer_range
,
attention_window
=
self
.
attention_window
,
attention_window
=
self
.
attention_window
,
return_dict
=
True
,
)
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -123,8 +124,8 @@ class LongformerModelTester:
...
@@ -123,8 +124,8 @@ class LongformerModelTester:
model
.
eval
()
model
.
eval
()
attention_mask
=
torch
.
ones
(
input_ids
.
shape
,
dtype
=
torch
.
long
,
device
=
torch_device
)
attention_mask
=
torch
.
ones
(
input_ids
.
shape
,
dtype
=
torch
.
long
,
device
=
torch_device
)
output_with_mask
=
model
(
input_ids
,
attention_mask
=
attention_mask
)[
0
]
output_with_mask
=
model
(
input_ids
,
attention_mask
=
attention_mask
)[
"last_hidden_state"
]
output_without_mask
=
model
(
input_ids
)[
0
]
output_without_mask
=
model
(
input_ids
)[
"last_hidden_state"
]
self
.
parent
.
assertTrue
(
torch
.
allclose
(
output_with_mask
[
0
,
0
,
:
5
],
output_without_mask
[
0
,
0
,
:
5
],
atol
=
1e-4
))
self
.
parent
.
assertTrue
(
torch
.
allclose
(
output_with_mask
[
0
,
0
,
:
5
],
output_without_mask
[
0
,
0
,
:
5
],
atol
=
1e-4
))
def
create_and_check_longformer_model
(
def
create_and_check_longformer_model
(
...
@@ -133,18 +134,13 @@ class LongformerModelTester:
...
@@ -133,18 +134,13 @@ class LongformerModelTester:
model
=
LongformerModel
(
config
=
config
)
model
=
LongformerModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
sequence_output
,
pooled_output
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
d
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
r
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_and_check_longformer_model_with_global_attention_mask
(
def
create_and_check_longformer_model_with_global_attention_mask
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -156,25 +152,19 @@ class LongformerModelTester:
...
@@ -156,25 +152,19 @@ class LongformerModelTester:
global_attention_mask
[:,
input_mask
.
shape
[
-
1
]
//
2
]
=
0
global_attention_mask
[:,
input_mask
.
shape
[
-
1
]
//
2
]
=
0
global_attention_mask
=
global_attention_mask
.
to
(
torch_device
)
global_attention_mask
=
global_attention_mask
.
to
(
torch_device
)
sequence_output
,
pooled_outpu
t
=
model
(
resul
t
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
global_attention_mask
=
global_attention_mask
,
global_attention_mask
=
global_attention_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
)
)
sequence_output
,
pooled_output
=
model
(
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
global_attention_mask
=
global_attention_mask
)
input_ids
,
token_type_ids
=
token_type_ids
,
global_attention_mask
=
global_attention_mask
result
=
model
(
input_ids
,
global_attention_mask
=
global_attention_mask
)
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
global_attention_mask
=
global_attention_mask
)
result
=
{
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
d
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
r
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_and_check_longformer_for_masked_lm
(
def
create_and_check_longformer_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -182,16 +172,8 @@ class LongformerModelTester:
...
@@ -182,16 +172,8 @@ class LongformerModelTester:
model
=
LongformerForMaskedLM
(
config
=
config
)
model
=
LongformerForMaskedLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_longformer_for_question_answering
(
def
create_and_check_longformer_for_question_answering
(
...
@@ -200,7 +182,7 @@ class LongformerModelTester:
...
@@ -200,7 +182,7 @@ class LongformerModelTester:
model
=
LongformerForQuestionAnswering
(
config
=
config
)
model
=
LongformerForQuestionAnswering
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
start_logits
,
end_logits
=
model
(
result
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
global_attention_mask
=
input_mask
,
global_attention_mask
=
input_mask
,
...
@@ -208,11 +190,6 @@ class LongformerModelTester:
...
@@ -208,11 +190,6 @@ class LongformerModelTester:
start_positions
=
sequence_labels
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
)
)
result
=
{
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -224,13 +201,7 @@ class LongformerModelTester:
...
@@ -224,13 +201,7 @@ class LongformerModelTester:
model
=
LongformerForSequenceClassification
(
config
)
model
=
LongformerForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -241,11 +212,7 @@ class LongformerModelTester:
...
@@ -241,11 +212,7 @@ class LongformerModelTester:
model
=
LongformerForTokenClassification
(
config
=
config
)
model
=
LongformerForTokenClassification
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -260,17 +227,13 @@ class LongformerModelTester:
...
@@ -260,17 +227,13 @@ class LongformerModelTester:
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
loss
,
logits
=
model
(
result
=
model
(
multiple_choice_inputs_ids
,
multiple_choice_inputs_ids
,
attention_mask
=
multiple_choice_input_mask
,
attention_mask
=
multiple_choice_input_mask
,
global_attention_mask
=
multiple_choice_input_mask
,
global_attention_mask
=
multiple_choice_input_mask
,
token_type_ids
=
multiple_choice_token_type_ids
,
token_type_ids
=
multiple_choice_token_type_ids
,
labels
=
choice_labels
,
labels
=
choice_labels
,
)
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
...
tests/test_modeling_mbart.py
View file @
d951c14a
...
@@ -114,13 +114,14 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
...
@@ -114,13 +114,14 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
decoder_ffn_dim
=
32
,
decoder_ffn_dim
=
32
,
max_position_embeddings
=
48
,
max_position_embeddings
=
48
,
add_final_layer_norm
=
True
,
add_final_layer_norm
=
True
,
return_dict
=
True
,
)
)
lm_model
=
BartForConditionalGeneration
(
config
).
to
(
torch_device
)
lm_model
=
BartForConditionalGeneration
(
config
).
to
(
torch_device
)
context
=
torch
.
Tensor
([[
71
,
82
,
18
,
33
,
46
,
91
,
2
],
[
68
,
34
,
26
,
58
,
30
,
2
,
1
]]).
long
().
to
(
torch_device
)
context
=
torch
.
Tensor
([[
71
,
82
,
18
,
33
,
46
,
91
,
2
],
[
68
,
34
,
26
,
58
,
30
,
2
,
1
]]).
long
().
to
(
torch_device
)
summary
=
torch
.
Tensor
([[
82
,
71
,
82
,
18
,
2
],
[
58
,
68
,
2
,
1
,
1
]]).
long
().
to
(
torch_device
)
summary
=
torch
.
Tensor
([[
82
,
71
,
82
,
18
,
2
],
[
58
,
68
,
2
,
1
,
1
]]).
long
().
to
(
torch_device
)
loss
,
logits
,
enc_featu
res
=
lm_model
(
input_ids
=
context
,
decoder_input_ids
=
summary
,
labels
=
summary
)
res
ult
=
lm_model
(
input_ids
=
context
,
decoder_input_ids
=
summary
,
labels
=
summary
)
expected_shape
=
(
*
summary
.
shape
,
config
.
vocab_size
)
expected_shape
=
(
*
summary
.
shape
,
config
.
vocab_size
)
self
.
assertEqual
(
logits
.
shape
,
expected_shape
)
self
.
assertEqual
(
result
[
"
logits
"
]
.
shape
,
expected_shape
)
@
require_torch
@
require_torch
...
...
tests/test_modeling_mobilebert.py
View file @
d951c14a
...
@@ -122,6 +122,7 @@ class MobileBertModelTester:
...
@@ -122,6 +122,7 @@ class MobileBertModelTester:
type_vocab_size
=
self
.
type_vocab_size
,
type_vocab_size
=
self
.
type_vocab_size
,
is_decoder
=
False
,
is_decoder
=
False
,
initializer_range
=
self
.
initializer_range
,
initializer_range
=
self
.
initializer_range
,
return_dict
=
True
,
)
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -162,18 +163,14 @@ class MobileBertModelTester:
...
@@ -162,18 +163,14 @@ class MobileBertModelTester:
model
=
MobileBertModel
(
config
=
config
)
model
=
MobileBertModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
sequence_output
,
pooled_output
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
d
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
r
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_and_check_mobilebert_model_as_decoder
(
def
create_and_check_mobilebert_model_as_decoder
(
self
,
self
,
...
@@ -190,29 +187,25 @@ class MobileBertModelTester:
...
@@ -190,29 +187,25 @@ class MobileBertModelTester:
model
=
MobileBertModel
(
config
)
model
=
MobileBertModel
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
sequence_output
,
pooled_outpu
t
=
model
(
resul
t
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_attention_mask
=
encoder_attention_mask
,
encoder_attention_mask
=
encoder_attention_mask
,
)
)
sequence_output
,
pooled_outpu
t
=
model
(
resul
t
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
)
)
sequence_output
,
pooled_outpu
t
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
resul
t
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
"pooled_output"
:
pooled_output
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
d
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"poole
r
_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_and_check_mobilebert_for_masked_lm
(
def
create_and_check_mobilebert_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
@@ -220,16 +213,8 @@ class MobileBertModelTester:
...
@@ -220,16 +213,8 @@ class MobileBertModelTester:
model
=
MobileBertForMaskedLM
(
config
=
config
)
model
=
MobileBertForMaskedLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_mobilebert_for_next_sequence_prediction
(
def
create_and_check_mobilebert_for_next_sequence_prediction
(
...
@@ -238,14 +223,10 @@ class MobileBertModelTester:
...
@@ -238,14 +223,10 @@ class MobileBertModelTester:
model
=
MobileBertForNextSentencePrediction
(
config
=
config
)
model
=
MobileBertForNextSentencePrediction
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
seq_relationship_score
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
next_sentence_label
=
sequence_labels
,
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
next_sentence_label
=
sequence_labels
,
)
)
result
=
{
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
2
])
"loss"
:
loss
,
"seq_relationship_score"
:
seq_relationship_score
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"seq_relationship_score"
].
size
()),
[
self
.
batch_size
,
2
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_mobilebert_for_pretraining
(
def
create_and_check_mobilebert_for_pretraining
(
...
@@ -254,22 +235,17 @@ class MobileBertModelTester:
...
@@ -254,22 +235,17 @@ class MobileBertModelTester:
model
=
MobileBertForPreTraining
(
config
=
config
)
model
=
MobileBertForPreTraining
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
,
seq_relationship_score
=
model
(
result
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
,
labels
=
token_labels
,
next_sentence_label
=
sequence_labels
,
next_sentence_label
=
sequence_labels
,
)
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
"seq_relationship_score"
:
seq_relationship_score
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_
score
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
list
(
result
[
"prediction_
logit
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"seq_relationship_
score
"
].
size
()),
[
self
.
batch_size
,
2
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"seq_relationship_
logits
"
].
size
()),
[
self
.
batch_size
,
2
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
def
create_and_check_mobilebert_for_question_answering
(
def
create_and_check_mobilebert_for_question_answering
(
...
@@ -278,18 +254,13 @@ class MobileBertModelTester:
...
@@ -278,18 +254,13 @@ class MobileBertModelTester:
model
=
MobileBertForQuestionAnswering
(
config
=
config
)
model
=
MobileBertForQuestionAnswering
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
start_logits
,
end_logits
=
model
(
result
=
model
(
input_ids
,
input_ids
,
attention_mask
=
input_mask
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
start_positions
=
sequence_labels
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
)
)
result
=
{
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -301,13 +272,7 @@ class MobileBertModelTester:
...
@@ -301,13 +272,7 @@ class MobileBertModelTester:
model
=
MobileBertForSequenceClassification
(
config
)
model
=
MobileBertForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
sequence_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -318,11 +283,7 @@ class MobileBertModelTester:
...
@@ -318,11 +283,7 @@ class MobileBertModelTester:
model
=
MobileBertForTokenClassification
(
config
=
config
)
model
=
MobileBertForTokenClassification
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -336,16 +297,12 @@ class MobileBertModelTester:
...
@@ -336,16 +297,12 @@ class MobileBertModelTester:
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_inputs_ids
=
input_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_token_type_ids
=
token_type_ids
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
multiple_choice_input_mask
=
input_mask
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_choices
,
-
1
).
contiguous
()
loss
,
logits
=
model
(
result
=
model
(
multiple_choice_inputs_ids
,
multiple_choice_inputs_ids
,
attention_mask
=
multiple_choice_input_mask
,
attention_mask
=
multiple_choice_input_mask
,
token_type_ids
=
multiple_choice_token_type_ids
,
token_type_ids
=
multiple_choice_token_type_ids
,
labels
=
choice_labels
,
labels
=
choice_labels
,
)
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_choices
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
...
tests/test_modeling_openai.py
View file @
d951c14a
...
@@ -85,9 +85,10 @@ class OpenAIGPTModelTester:
...
@@ -85,9 +85,10 @@ class OpenAIGPTModelTester:
# hidden_dropout_prob=self.hidden_dropout_prob,
# hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions
=
self
.
max_position_embeddings
,
n_positions
=
self
.
max_position_embeddings
,
n_ctx
=
self
.
max_position_embeddings
n_ctx
=
self
.
max_position_embeddings
,
# type_vocab_size=self.type_vocab_size,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
# initializer_range=self.initializer_range
return_dict
=
True
,
)
)
head_mask
=
ids_tensor
([
self
.
num_hidden_layers
,
self
.
num_attention_heads
],
2
)
head_mask
=
ids_tensor
([
self
.
num_hidden_layers
,
self
.
num_attention_heads
],
2
)
...
@@ -110,13 +111,12 @@ class OpenAIGPTModelTester:
...
@@ -110,13 +111,12 @@ class OpenAIGPTModelTester:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
head_mask
=
head_mask
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
head_mask
=
head_mask
)
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
(
sequence_output
,)
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
],
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
],
)
)
def
create_and_check_lm_head_model
(
self
,
config
,
input_ids
,
head_mask
,
token_type_ids
,
*
args
):
def
create_and_check_lm_head_model
(
self
,
config
,
input_ids
,
head_mask
,
token_type_ids
,
*
args
):
...
@@ -124,13 +124,10 @@ class OpenAIGPTModelTester:
...
@@ -124,13 +124,10 @@ class OpenAIGPTModelTester:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
lm_logits
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
input_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
input_ids
)
result
=
{
"loss"
:
loss
,
"lm_logits"
:
lm_logits
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
lm_
logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
],
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
],
)
)
def
create_and_check_double_lm_head_model
(
self
,
config
,
input_ids
,
head_mask
,
token_type_ids
,
*
args
):
def
create_and_check_double_lm_head_model
(
self
,
config
,
input_ids
,
head_mask
,
token_type_ids
,
*
args
):
...
@@ -138,11 +135,8 @@ class OpenAIGPTModelTester:
...
@@ -138,11 +135,8 @@ class OpenAIGPTModelTester:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
lm_logits
,
mc_logits
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
input_ids
)
result
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
input_ids
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_loss"
].
size
()),
[])
result
=
{
"loss"
:
loss
,
"lm_logits"
:
lm_logits
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
],
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
],
)
)
...
...
tests/test_modeling_reformer.py
View file @
d951c14a
...
@@ -165,6 +165,7 @@ class ReformerModelTester:
...
@@ -165,6 +165,7 @@ class ReformerModelTester:
attn_layers
=
self
.
attn_layers
,
attn_layers
=
self
.
attn_layers
,
pad_token_id
=
self
.
pad_token_id
,
pad_token_id
=
self
.
pad_token_id
,
hash_seed
=
self
.
hash_seed
,
hash_seed
=
self
.
hash_seed
,
return_dict
=
True
,
)
)
return
(
return
(
...
@@ -181,15 +182,12 @@ class ReformerModelTester:
...
@@ -181,15 +182,12 @@ class ReformerModelTester:
model
=
ReformerModel
(
config
=
config
)
model
=
ReformerModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
sequence_output
,
_
=
model
(
input_ids
,
attention_mask
=
input_mask
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
)
sequence_output
,
_
=
model
(
input_ids
)
result
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
}
# 2 * hidden_size because we use reversible resnet layers
# 2 * hidden_size because we use reversible resnet layers
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
sequence_output
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
2
*
self
.
hidden_size
],
list
(
result
[
"
last_hidden_state
"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
2
*
self
.
hidden_size
],
)
)
def
create_and_check_reformer_model_with_lm_backward
(
self
,
config
,
input_ids
,
input_mask
,
choice_labels
):
def
create_and_check_reformer_model_with_lm_backward
(
self
,
config
,
input_ids
,
input_mask
,
choice_labels
):
...
@@ -198,7 +196,7 @@ class ReformerModelTester:
...
@@ -198,7 +196,7 @@ class ReformerModelTester:
model
=
ReformerForMaskedLM
(
config
=
config
)
model
=
ReformerForMaskedLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
input_ids
)[
0
]
loss
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
input_ids
)[
"loss"
]
loss
.
backward
()
loss
.
backward
()
def
create_and_check_reformer_with_lm
(
self
,
config
,
input_ids
,
input_mask
,
choice_labels
):
def
create_and_check_reformer_with_lm
(
self
,
config
,
input_ids
,
input_mask
,
choice_labels
):
...
@@ -207,13 +205,9 @@ class ReformerModelTester:
...
@@ -207,13 +205,9 @@ class ReformerModelTester:
model
=
ReformerModelWithLMHead
(
config
=
config
)
model
=
ReformerModelWithLMHead
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
,
_
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
input_ids
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
input_ids
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
prediction_score
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
],
list
(
result
[
"
logit
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
],
)
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -222,13 +216,9 @@ class ReformerModelTester:
...
@@ -222,13 +216,9 @@ class ReformerModelTester:
model
=
ReformerForMaskedLM
(
config
=
config
)
model
=
ReformerForMaskedLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
prediction_scores
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
input_ids
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
input_ids
)
result
=
{
"loss"
:
loss
,
"prediction_scores"
:
prediction_scores
,
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"
prediction_score
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
],
list
(
result
[
"
logit
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
],
)
)
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -325,7 +315,7 @@ class ReformerModelTester:
...
@@ -325,7 +315,7 @@ class ReformerModelTester:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
hidden_states_with_chunk
=
model
(
input_ids
,
attention_mask
=
input_mask
)[
0
]
hidden_states_with_chunk
=
model
(
input_ids
,
attention_mask
=
input_mask
)[
"last_hidden_state"
]
self
.
parent
.
assertTrue
(
torch
.
allclose
(
hidden_states_no_chunk
,
hidden_states_with_chunk
,
atol
=
1e-3
))
self
.
parent
.
assertTrue
(
torch
.
allclose
(
hidden_states_no_chunk
,
hidden_states_with_chunk
,
atol
=
1e-3
))
def
create_and_check_reformer_feed_backward_chunking
(
self
,
config
,
input_ids
,
input_mask
,
choice_labels
):
def
create_and_check_reformer_feed_backward_chunking
(
self
,
config
,
input_ids
,
input_mask
,
choice_labels
):
...
@@ -408,7 +398,7 @@ class ReformerModelTester:
...
@@ -408,7 +398,7 @@ class ReformerModelTester:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
half
()
model
.
half
()
model
.
eval
()
model
.
eval
()
output
=
model
(
input_ids
,
attention_mask
=
input_mask
)[
0
]
output
=
model
(
input_ids
,
attention_mask
=
input_mask
)[
"last_input_state"
]
self
.
parent
.
assertFalse
(
torch
.
isnan
(
output
).
any
().
item
())
self
.
parent
.
assertFalse
(
torch
.
isnan
(
output
).
any
().
item
())
def
create_and_check_reformer_model_generate
(
self
,
config
,
input_ids
,
input_mask
,
choice_labels
):
def
create_and_check_reformer_model_generate
(
self
,
config
,
input_ids
,
input_mask
,
choice_labels
):
...
@@ -444,21 +434,16 @@ class ReformerModelTester:
...
@@ -444,21 +434,16 @@ class ReformerModelTester:
model
=
ReformerForMaskedLM
(
config
=
config
)
model
=
ReformerForMaskedLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
output_logits
=
model
(
input_ids
,
attention_mask
=
input_mask
)[
0
]
output_logits
=
model
(
input_ids
,
attention_mask
=
input_mask
)[
"logits"
]
self
.
parent
.
assertTrue
(
output_logits
.
shape
[
1
]
==
input_ids
.
shape
[
-
1
])
self
.
parent
.
assertTrue
(
output_logits
.
shape
[
1
]
==
input_ids
.
shape
[
-
1
])
def
create_and_check_reformer_for_question_answering
(
self
,
config
,
input_ids
,
input_mask
,
choice_labels
):
def
create_and_check_reformer_for_question_answering
(
self
,
config
,
input_ids
,
input_mask
,
choice_labels
):
model
=
ReformerForQuestionAnswering
(
config
=
config
)
model
=
ReformerForQuestionAnswering
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
start_logits
,
end_logits
=
model
(
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
start_positions
=
choice_labels
,
end_positions
=
choice_labels
,
input_ids
,
attention_mask
=
input_mask
,
start_positions
=
choice_labels
,
end_positions
=
choice_labels
,
)
)
result
=
{
"loss"
:
loss
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
@@ -474,11 +459,11 @@ class ReformerModelTester:
...
@@ -474,11 +459,11 @@ class ReformerModelTester:
input_ids_second
=
input_ids
[:,
-
1
:]
input_ids_second
=
input_ids
[:,
-
1
:]
# return saved cache
# return saved cache
_
,
past_buckets_states
=
model
(
input_ids_first
,
use_cache
=
True
)
past_buckets_states
=
model
(
input_ids_first
,
use_cache
=
True
)
[
"past_buckets_states"
]
# calculate last output with and without cache
# calculate last output with and without cache
outputs_with_cache
,
_
=
model
(
input_ids_second
,
past_buckets_states
=
past_buckets_states
,
use_cache
=
True
)
outputs_with_cache
=
model
(
input_ids_second
,
past_buckets_states
=
past_buckets_states
,
use_cache
=
True
)
[
"logits"
]
outputs_without_cache
=
model
(
input_ids
)[
0
][:,
-
1
]
outputs_without_cache
=
model
(
input_ids
)[
"logits"
][:,
-
1
]
# select random slice idx
# select random slice idx
random_slice_idx
=
torch
.
randint
(
outputs_without_cache
.
shape
[
-
1
],
(
1
,
1
),
device
=
torch_device
).
item
()
random_slice_idx
=
torch
.
randint
(
outputs_without_cache
.
shape
[
-
1
],
(
1
,
1
),
device
=
torch_device
).
item
()
...
@@ -504,11 +489,7 @@ class ReformerModelTester:
...
@@ -504,11 +489,7 @@ class ReformerModelTester:
model
=
ReformerForSequenceClassification
(
config
)
model
=
ReformerForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss
,
logits
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
sequence_labels
)
result
=
model
(
input_ids
,
attention_mask
=
input_mask
,
labels
=
sequence_labels
)
result
=
{
"loss"
:
loss
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
num_labels
])
self
.
check_loss_output
(
result
)
self
.
check_loss_output
(
result
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment