chenpangpang / transformers / Commits / c3d9ac76

Unverified commit c3d9ac76, authored Jul 21, 2021 by Lysandre Debut, committed via GitHub on Jul 21, 2021.

Expose get_config() on ModelTesters (#12812)

* Expose get_config() on ModelTesters
* Typo
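Every file in this commit applies the same refactor: the config construction that used to live inline inside each tester's prepare_config_and_inputs() moves into a dedicated get_config() method, which prepare_config_and_inputs() then calls. A minimal sketch of the before/after shape follows; FooConfig and the FooModelTester classes are placeholder names for illustration only, not classes touched by this commit.

class FooConfig:
    def __init__(self, vocab_size, hidden_size):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size


# Before: the config can only be obtained together with the dummy inputs.
class FooModelTesterBefore:
    def __init__(self, parent, vocab_size=99, hidden_size=32, seq_length=7):
        self.parent = parent
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.seq_length = seq_length

    def prepare_config_and_inputs(self):
        # Stand-in for the ids_tensor(...) helper used by the real testers.
        input_ids = [[i % self.vocab_size for i in range(self.seq_length)]]
        config = FooConfig(vocab_size=self.vocab_size, hidden_size=self.hidden_size)
        return config, input_ids


# After: config construction is exposed on its own as get_config(), and
# prepare_config_and_inputs() reuses it.
class FooModelTesterAfter(FooModelTesterBefore):
    def prepare_config_and_inputs(self):
        input_ids = [[i % self.vocab_size for i in range(self.seq_length)]]
        config = self.get_config()
        return config, input_ids

    def get_config(self):
        return FooConfig(vocab_size=self.vocab_size, hidden_size=self.hidden_size)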
Parent: cabcc751

Changes: 53 in the commit. Showing 13 changed files on this page, with 385 additions and 406 deletions (+385 -406).
tests/test_modeling_reformer.py        +84  -122
tests/test_modeling_roberta.py         +7   -5
tests/test_modeling_roformer.py        +7   -5
tests/test_modeling_speech_to_text.py  +13  -15
tests/test_modeling_squeezebert.py     +176 -175
tests/test_modeling_t5.py              +15  -12
tests/test_modeling_tapas.py           +19  -17
tests/test_modeling_transfo_xl.py      +8   -5
tests/test_modeling_visual_bert.py     +4   -6
tests/test_modeling_vit.py             +8   -4
tests/test_modeling_wav2vec2.py        +7   -5
tests/test_modeling_xlm.py             +17  -16
tests/test_modeling_xlnet.py           +20  -19
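All 13 diffs below follow the pattern sketched above. One practical effect is that a test class can fetch a small, fully populated config on its own, without prepare_config_and_inputs() also materializing dummy tensors. The snippet below continues the placeholder FooModelTesterAfter sketch, so it is illustrative rather than code from the commit; the setUp() style mirrors the test classes in these files.

import unittest


class FooModelTest(unittest.TestCase):
    def setUp(self):
        self.model_tester = FooModelTesterAfter(self)

    def test_config_hidden_size(self):
        # get_config() is reachable directly; no dummy inputs are created here.
        config = self.model_tester.get_config()
        self.assertEqual(config.hidden_size, 32)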
tests/test_modeling_reformer.py

@@ -15,7 +15,7 @@
 import unittest
 
-from transformers import is_torch_available
+from transformers import ReformerConfig, is_torch_available
 from transformers.testing_utils import (
     require_sentencepiece,
     require_tokenizers,
@@ -36,7 +36,6 @@ if is_torch_available():
     from transformers import (
         REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
-        ReformerConfig,
         ReformerForMaskedLM,
         ReformerForQuestionAnswering,
         ReformerForSequenceClassification,
@@ -51,44 +50,44 @@ class ReformerModelTester:
     def __init__(
         self,
         parent,
-        batch_size=None,
-        seq_length=None,
-        is_training=None,
-        is_decoder=None,
-        use_input_mask=None,
-        use_labels=None,
-        vocab_size=None,
-        attention_head_size=None,
-        hidden_size=None,
-        num_attention_heads=None,
-        local_attn_chunk_length=None,
-        local_num_chunks_before=None,
-        local_num_chunks_after=None,
+        batch_size=13,
+        seq_length=32,
+        is_training=True,
+        is_decoder=True,
+        use_input_mask=True,
+        use_labels=True,
+        vocab_size=32,
+        attention_head_size=16,
+        hidden_size=32,
+        num_attention_heads=2,
+        local_attn_chunk_length=4,
+        local_num_chunks_before=1,
+        local_num_chunks_after=0,
         num_buckets=None,
         num_hashes=1,
         lsh_attn_chunk_length=None,
         lsh_num_chunks_before=None,
         lsh_num_chunks_after=None,
-        chunk_size_lm_head=None,
-        chunk_size_feed_forward=None,
-        feed_forward_size=None,
-        hidden_act=None,
-        hidden_dropout_prob=None,
-        local_attention_probs_dropout_prob=None,
+        chunk_size_lm_head=0,
+        chunk_size_feed_forward=0,
+        feed_forward_size=32,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        local_attention_probs_dropout_prob=0.1,
         lsh_attention_probs_dropout_prob=None,
-        max_position_embeddings=None,
-        initializer_range=None,
-        axial_norm_std=None,
-        layer_norm_eps=None,
-        axial_pos_embds=None,
-        axial_pos_shape=None,
-        axial_pos_embds_dim=None,
-        attn_layers=None,
-        pad_token_id=None,
-        eos_token_id=None,
+        max_position_embeddings=512,
+        initializer_range=0.02,
+        axial_norm_std=1.0,
+        layer_norm_eps=1e-12,
+        axial_pos_embds=True,
+        axial_pos_shape=[4, 8],
+        axial_pos_embds_dim=[16, 16],
+        attn_layers=["local", "local", "local", "local"],
+        pad_token_id=0,
+        eos_token_id=2,
         scope=None,
-        hash_seed=None,
-        num_labels=None,
+        hash_seed=0,
+        num_labels=2,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -101,7 +100,7 @@ class ReformerModelTester:
         self.attention_head_size = attention_head_size
         self.hidden_size = hidden_size
         self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = len(attn_layers)
+        self.num_hidden_layers = len(attn_layers) if attn_layers is not None else 0
         self.local_attn_chunk_length = local_attn_chunk_length
         self.local_num_chunks_after = local_num_chunks_after
         self.local_num_chunks_before = local_num_chunks_before
@@ -149,7 +148,17 @@ class ReformerModelTester:
         if self.use_labels:
             choice_labels = ids_tensor([self.batch_size], 2)
 
-        config = ReformerConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            input_mask,
+            choice_labels,
+        )
+
+    def get_config(self):
+        return ReformerConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
@@ -177,13 +186,6 @@ class ReformerModelTester:
             hash_seed=self.hash_seed,
         )
 
-        return (
-            config,
-            input_ids,
-            input_mask,
-            choice_labels,
-        )
-
     def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModel(config=config)
         model.to(torch_device)
@@ -593,45 +595,8 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
     test_torchscript = False
     test_sequence_classification_problem_types = True
 
-    def prepare_kwargs(self):
-        return {
-            "batch_size": 13,
-            "seq_length": 32,
-            "is_training": True,
-            "is_decoder": True,
-            "use_input_mask": True,
-            "use_labels": True,
-            "vocab_size": 32,
-            "attention_head_size": 16,
-            "hidden_size": 32,
-            "num_attention_heads": 2,
-            "local_attn_chunk_length": 4,
-            "local_num_chunks_before": 1,
-            "local_num_chunks_after": 0,
-            "chunk_size_lm_head": 0,
-            "chunk_size_feed_forward": 0,
-            "feed_forward_size": 32,
-            "hidden_act": "gelu",
-            "hidden_dropout_prob": 0.1,
-            "local_attention_probs_dropout_prob": 0.1,
-            "max_position_embeddings": 512,
-            "initializer_range": 0.02,
-            "axial_norm_std": 1.0,
-            "layer_norm_eps": 1e-12,
-            "axial_pos_embds": True,
-            "axial_pos_shape": [4, 8],
-            "axial_pos_embds_dim": [16, 16],
-            "attn_layers": ["local", "local", "local", "local"],
-            "pad_token_id": 0,
-            "eos_token_id": 2,
-            "scope": None,
-            "hash_seed": 0,
-            "num_labels": 2,
-        }
-
     def setUp(self):
-        tester_kwargs = self.prepare_kwargs()
-        self.model_tester = ReformerModelTester(self, **tester_kwargs)
+        self.model_tester = ReformerModelTester(self)
         self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
 
     @slow
@@ -716,49 +681,46 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, Generation
     test_headmasking = False
     test_torchscript = False
 
-    def prepare_kwargs(self):
-        return {
-            "batch_size": 13,
-            "seq_length": 13,
-            "use_input_mask": True,
-            "use_labels": True,
-            "is_training": False,
-            "is_decoder": True,
-            "vocab_size": 32,
-            "attention_head_size": 16,
-            "hidden_size": 64,
-            "num_attention_heads": 2,
-            "num_buckets": 2,
-            "num_hashes": 4,
-            "lsh_attn_chunk_length": 4,
-            "lsh_num_chunks_before": 1,
-            "lsh_num_chunks_after": 0,
-            "chunk_size_lm_head": 5,
-            "chunk_size_feed_forward": 6,
-            "feed_forward_size": 32,
-            "hidden_act": "relu",
-            "hidden_dropout_prob": 0.1,
-            "lsh_attention_probs_dropout_prob": 0.1,
-            "max_position_embeddings": 512,
-            "initializer_range": 0.02,
-            "axial_norm_std": 1.0,
-            "layer_norm_eps": 1e-12,
-            "axial_pos_embds": True,
-            "axial_pos_shape": [4, 8],
-            "axial_pos_embds_dim": [16, 48],
-            # sanotheu
-            # "attn_layers": ["lsh", "lsh", "lsh", "lsh"],
-            "attn_layers": ["lsh"],
-            "pad_token_id": 0,
-            "eos_token_id": 2,
-            "scope": None,
-            "hash_seed": 0,
-            "num_labels": 2,
-        }
-
     def setUp(self):
-        tester_kwargs = self.prepare_kwargs()
-        self.model_tester = ReformerModelTester(self, **tester_kwargs)
+        self.model_tester = ReformerModelTester(
+            self,
+            batch_size=13,
+            seq_length=13,
+            use_input_mask=True,
+            use_labels=True,
+            is_training=False,
+            is_decoder=True,
+            vocab_size=32,
+            attention_head_size=16,
+            hidden_size=64,
+            num_attention_heads=2,
+            num_buckets=2,
+            num_hashes=4,
+            lsh_attn_chunk_length=4,
+            lsh_num_chunks_before=1,
+            lsh_num_chunks_after=0,
+            chunk_size_lm_head=5,
+            chunk_size_feed_forward=6,
+            feed_forward_size=32,
+            hidden_act="relu",
+            hidden_dropout_prob=0.1,
+            lsh_attention_probs_dropout_prob=0.1,
+            max_position_embeddings=512,
+            initializer_range=0.02,
+            axial_norm_std=1.0,
+            layer_norm_eps=1e-12,
+            axial_pos_embds=True,
+            axial_pos_shape=[4, 8],
+            axial_pos_embds_dim=[16, 48],
+            # sanotheu
+            # attn_layers=[lsh,lsh,lsh,lsh],
+            attn_layers=["lsh"],
+            pad_token_id=0,
+            eos_token_id=2,
+            scope=None,
+            hash_seed=0,
+            num_labels=2,
+        )
         self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
 
     def _check_attentions_for_generate(
tests/test_modeling_roberta.py

@@ -17,7 +17,7 @@
 import unittest
 
 from copy import deepcopy
 
-from transformers import is_torch_available
+from transformers import RobertaConfig, is_torch_available
 from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -29,7 +29,6 @@ if is_torch_available():
     import torch
 
     from transformers import (
-        RobertaConfig,
         RobertaForCausalLM,
         RobertaForMaskedLM,
         RobertaForMultipleChoice,
@@ -94,7 +93,12 @@ class RobertaModelTester:
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
 
-        config = RobertaConfig(
+        config = self.get_config()
+
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return RobertaConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
@@ -108,8 +112,6 @@ class RobertaModelTester:
             initializer_range=self.initializer_range,
         )
 
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
tests/test_modeling_roformer.py

@@ -18,7 +18,7 @@
 import unittest
 
 from tests.test_modeling_common import floats_tensor
-from transformers import is_torch_available
+from transformers import RoFormerConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -29,7 +29,6 @@ if is_torch_available():
     import torch
 
     from transformers import (
-        RoFormerConfig,
         RoFormerForCausalLM,
         RoFormerForMaskedLM,
         RoFormerForMultipleChoice,
@@ -113,7 +112,12 @@ class RoFormerModelTester:
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
 
-        config = RoFormerConfig(
+        config = self.get_config()
+
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return RoFormerConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
@@ -128,8 +132,6 @@ class RoFormerModelTester:
             initializer_range=self.initializer_range,
         )
 
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
tests/test_modeling_speech_to_text.py

@@ -14,13 +14,13 @@
 # limitations under the License.
 """ Testing suite for the PyTorch Speech2Text model. """
 
 import copy
 import inspect
 import os
 import tempfile
 import unittest
 
+from transformers import Speech2TextConfig
 from transformers.file_utils import cached_property
 from transformers.testing_utils import (
     is_torch_available,
@@ -40,12 +40,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_te
 if is_torch_available():
     import torch
 
-    from transformers import (
-        Speech2TextConfig,
-        Speech2TextForConditionalGeneration,
-        Speech2TextModel,
-        Speech2TextProcessor,
-    )
+    from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
     from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder
@@ -142,7 +137,17 @@ class Speech2TextModelTester:
         attention_mask = torch.ones([self.batch_size, self.seq_length], dtype=torch.long, device=torch_device)
         decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2)
 
-        config = Speech2TextConfig(
+        config = self.get_config()
+        inputs_dict = prepare_speech_to_text_inputs_dict(
+            config,
+            input_features=input_features,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+        )
+        return config, inputs_dict
+
+    def get_config(self):
+        return Speech2TextConfig(
             vocab_size=self.vocab_size,
             d_model=self.hidden_size,
             encoder_layers=self.num_hidden_layers,
@@ -165,13 +170,6 @@ class Speech2TextModelTester:
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )
 
-        inputs_dict = prepare_speech_to_text_inputs_dict(
-            config,
-            input_features=input_features,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-        )
-        return config, inputs_dict
-
     def prepare_config_and_inputs_for_common(self):
         config, inputs_dict = self.prepare_config_and_inputs()
tests/test_modeling_squeezebert.py

@@ -16,7 +16,7 @@
 import unittest
 
-from transformers import is_torch_available
+from transformers import SqueezeBertConfig, is_torch_available
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -28,7 +28,6 @@ if is_torch_available():
     from transformers import (
         SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
-        SqueezeBertConfig,
         SqueezeBertForMaskedLM,
         SqueezeBertForMultipleChoice,
         SqueezeBertForQuestionAnswering,
@@ -37,179 +36,181 @@ if is_torch_available():
         SqueezeBertModel,
     )
 
 class SqueezeBertModelTester(object):
     def __init__(
         self,
         parent,
         batch_size=13,
         seq_length=7,
         is_training=True,
         use_input_mask=True,
         use_token_type_ids=False,
         use_labels=True,
         vocab_size=99,
         hidden_size=32,
         num_hidden_layers=5,
         num_attention_heads=4,
         intermediate_size=64,
         hidden_act="gelu",
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
         max_position_embeddings=512,
         type_vocab_size=16,
         type_sequence_label_size=2,
         initializer_range=0.02,
         num_labels=3,
         num_choices=4,
         scope=None,
         q_groups=2,
         k_groups=2,
         v_groups=2,
         post_attention_groups=2,
         intermediate_groups=4,
         output_groups=1,
     ):
         self.parent = parent
         self.batch_size = batch_size
         self.seq_length = seq_length
         self.is_training = is_training
         self.use_input_mask = use_input_mask
         self.use_token_type_ids = use_token_type_ids
         self.use_labels = use_labels
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.intermediate_size = intermediate_size
         self.hidden_act = hidden_act
         self.hidden_dropout_prob = hidden_dropout_prob
         self.attention_probs_dropout_prob = attention_probs_dropout_prob
         self.max_position_embeddings = max_position_embeddings
         self.type_vocab_size = type_vocab_size
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
         self.num_labels = num_labels
         self.num_choices = num_choices
         self.scope = scope
         self.q_groups = q_groups
         self.k_groups = k_groups
         self.v_groups = v_groups
         self.post_attention_groups = post_attention_groups
         self.intermediate_groups = intermediate_groups
         self.output_groups = output_groups
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
         input_mask = None
         if self.use_input_mask:
             input_mask = random_attention_mask([self.batch_size, self.seq_length])
 
         sequence_labels = None
         token_labels = None
         choice_labels = None
         if self.use_labels:
             sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
             token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
             choice_labels = ids_tensor([self.batch_size], self.num_choices)
 
-        config = SqueezeBertConfig(
+        config = self.get_config()
+
+        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return SqueezeBertConfig(
             embedding_size=self.hidden_size,
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads,
             intermediate_size=self.intermediate_size,
             hidden_act=self.hidden_act,
             attention_probs_dropout_prob=self.hidden_dropout_prob,
             attention_dropout=self.attention_probs_dropout_prob,
             max_position_embeddings=self.max_position_embeddings,
             initializer_range=self.initializer_range,
             q_groups=self.q_groups,
             k_groups=self.k_groups,
             v_groups=self.v_groups,
             post_attention_groups=self.post_attention_groups,
             intermediate_groups=self.intermediate_groups,
             output_groups=self.output_groups,
         )
 
-        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
-
     def create_and_check_squeezebert_model(
         self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = SqueezeBertModel(config=config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, input_mask)
         result = model(input_ids)
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
 
     def create_and_check_squeezebert_for_masked_lm(
         self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = SqueezeBertForMaskedLM(config=config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
     def create_and_check_squeezebert_for_question_answering(
         self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = SqueezeBertForQuestionAnswering(config=config)
         model.to(torch_device)
         model.eval()
         result = model(
             input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels
         )
         self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
         self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
 
     def create_and_check_squeezebert_for_sequence_classification(
         self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.num_labels = self.num_labels
         model = SqueezeBertForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
 
     def create_and_check_squeezebert_for_token_classification(
         self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.num_labels = self.num_labels
         model = SqueezeBertForTokenClassification(config=config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
 
     def create_and_check_squeezebert_for_multiple_choice(
         self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.num_choices = self.num_choices
         model = SqueezeBertForMultipleChoice(config=config)
         model.to(torch_device)
         model.eval()
         multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
         multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
         result = model(
             multiple_choice_inputs_ids,
             attention_mask=multiple_choice_input_mask,
             labels=choice_labels,
         )
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
 
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
         inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
         return config, inputs_dict
 
 @require_torch
tests/test_modeling_t5.py

@@ -18,7 +18,7 @@ import copy
 import tempfile
 import unittest
 
-from transformers import is_torch_available
+from transformers import T5Config, is_torch_available
 from transformers.file_utils import cached_property
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -30,7 +30,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
 if is_torch_available():
     import torch
 
-    from transformers import ByT5Tokenizer, T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer
+    from transformers import ByT5Tokenizer, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer
     from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -100,7 +100,19 @@ class T5ModelTester:
         if self.use_labels:
             lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
 
-        config = T5Config(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            decoder_input_ids,
+            attention_mask,
+            decoder_attention_mask,
+            lm_labels,
+        )
+
+    def get_config(self):
+        return T5Config(
             vocab_size=self.vocab_size,
             d_model=self.hidden_size,
             d_ff=self.d_ff,
@@ -117,15 +129,6 @@ class T5ModelTester:
             decoder_start_token_id=self.decoder_start_token_id,
         )
 
-        return (
-            config,
-            input_ids,
-            decoder_input_ids,
-            attention_mask,
-            decoder_attention_mask,
-            lm_labels,
-        )
-
     def check_prepare_lm_labels_via_shift_left(
         self,
         config,
tests/test_modeling_tapas.py

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import unittest
@@ -29,6 +28,7 @@ from transformers import (
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
     MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
     MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+    TapasConfig,
     is_torch_available,
 )
 from transformers.file_utils import cached_property
@@ -43,7 +43,6 @@ if is_torch_available():
     import torch
 
     from transformers import (
-        TapasConfig,
         TapasForMaskedLM,
         TapasForQuestionAnswering,
         TapasForSequenceClassification,
@@ -183,7 +182,24 @@ class TapasModelTester:
         float_answer = floats_tensor([self.batch_size]).to(torch_device)
         aggregation_labels = ids_tensor([self.batch_size], self.num_aggregation_labels).to(torch_device)
 
-        config = TapasConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            input_mask,
+            token_type_ids,
+            sequence_labels,
+            token_labels,
+            labels,
+            numeric_values,
+            numeric_values_scale,
+            float_answer,
+            aggregation_labels,
+        )
+
+    def get_config(self):
+        return TapasConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
@@ -220,20 +236,6 @@ class TapasModelTester:
             disable_per_token_loss=self.disable_per_token_loss,
         )
 
-        return (
-            config,
-            input_ids,
-            input_mask,
-            token_type_ids,
-            sequence_labels,
-            token_labels,
-            labels,
-            numeric_values,
-            numeric_values_scale,
-            float_answer,
-            aggregation_labels,
-        )
-
     def create_and_check_model(
         self,
         config,
tests/test_modeling_transfo_xl.py

@@ -17,7 +17,7 @@ import copy
 import random
 import unittest
 
-from transformers import is_torch_available
+from transformers import TransfoXLConfig, is_torch_available
 from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -29,7 +29,7 @@ if is_torch_available():
     import torch
     from torch import nn
 
-    from transformers import TransfoXLConfig, TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
+    from transformers import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
     from transformers.models.transfo_xl.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -69,7 +69,12 @@ class TransfoXLModelTester:
         if self.use_labels:
             lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
-        config = TransfoXLConfig(
+        config = self.get_config()
+
+        return (config, input_ids_1, input_ids_2, lm_labels)
+
+    def get_config(self):
+        return TransfoXLConfig(
             vocab_size=self.vocab_size,
             mem_len=self.mem_len,
             clamp_len=self.clamp_len,
@@ -85,8 +90,6 @@ class TransfoXLModelTester:
             pad_token_id=self.pad_token_id,
         )
 
-        return (config, input_ids_1, input_ids_2, lm_labels)
-
     def set_seed(self):
         random.seed(self.seed)
         torch.manual_seed(self.seed)
tests/test_modeling_visual_bert.py

@@ -14,12 +14,11 @@
 # limitations under the License.
 """ Testing suite for the PyTorch VisualBERT model. """
 
-import copy
 import unittest
 
 from tests.test_modeling_common import floats_tensor
-from transformers import is_torch_available
+from transformers import VisualBertConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -30,7 +29,6 @@ if is_torch_available():
     import torch
 
     from transformers import (
-        VisualBertConfig,
         VisualBertForMultipleChoice,
         VisualBertForPreTraining,
         VisualBertForQuestionAnswering,
@@ -98,7 +96,7 @@ class VisualBertModelTester:
         self.num_choices = num_choices
         self.scope = scope
 
-    def prepare_config(self):
+    def get_config(self):
         return VisualBertConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
@@ -138,7 +136,7 @@ class VisualBertModelTester:
         if self.use_visual_token_type_ids:
             visual_token_type_ids = ids_tensor([self.batch_size, self.visual_seq_length], self.type_vocab_size)
 
-        config = self.prepare_config()
+        config = self.get_config()
         return config, {
             "input_ids": input_ids,
             "token_type_ids": token_type_ids,
@@ -198,7 +196,7 @@ class VisualBertModelTester:
         if self.use_labels:
             labels = ids_tensor([self.batch_size], self.num_choices)
 
-        config = self.prepare_config()
+        config = self.get_config()
         return config, {
             "input_ids": input_ids,
             "token_type_ids": token_type_ids,
tests/test_modeling_vit.py

@@ -18,6 +18,7 @@
 import inspect
 import unittest
 
+from transformers import ViTConfig
 from transformers.file_utils import cached_property, is_torch_available, is_vision_available
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
@@ -29,7 +30,7 @@ if is_torch_available():
     import torch
     from torch import nn
 
-    from transformers import ViTConfig, ViTForImageClassification, ViTModel
+    from transformers import ViTForImageClassification, ViTModel
     from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
@@ -86,7 +87,12 @@ class ViTModelTester:
         if self.use_labels:
             labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
 
-        config = ViTConfig(
+        config = self.get_config()
+
+        return config, pixel_values, labels
+
+    def get_config(self):
+        return ViTConfig(
             image_size=self.image_size,
             patch_size=self.patch_size,
             num_channels=self.num_channels,
@@ -101,8 +107,6 @@ class ViTModelTester:
             initializer_range=self.initializer_range,
         )
 
-        return config, pixel_values, labels
-
     def create_and_check_model(self, config, pixel_values, labels):
         model = ViTModel(config=config)
         model.to(torch_device)
tests/test_modeling_wav2vec2.py

@@ -21,7 +21,7 @@ import unittest
 import pytest
 
 from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
-from transformers import is_torch_available
+from transformers import Wav2Vec2Config, is_torch_available
 from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -32,7 +32,6 @@ if is_torch_available():
     import torch
 
     from transformers import (
-        Wav2Vec2Config,
         Wav2Vec2FeatureExtractor,
         Wav2Vec2ForCTC,
         Wav2Vec2ForMaskedLM,
@@ -106,7 +105,12 @@ class Wav2Vec2ModelTester:
         input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
         attention_mask = random_attention_mask([self.batch_size, self.seq_length])
 
-        config = Wav2Vec2Config(
+        config = self.get_config()
+
+        return config, input_values, attention_mask
+
+    def get_config(self):
+        return Wav2Vec2Config(
             hidden_size=self.hidden_size,
             feat_extract_norm=self.feat_extract_norm,
             feat_extract_dropout=self.feat_extract_dropout,
@@ -127,8 +131,6 @@ class Wav2Vec2ModelTester:
             vocab_size=self.vocab_size,
         )
 
-        return config, input_values, attention_mask
-
     def create_and_check_model(self, config, input_values, attention_mask):
         model = Wav2Vec2Model(config=config)
         model.to(torch_device)
tests/test_modeling_xlm.py

@@ -13,10 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import unittest
 
-from transformers import is_torch_available
+from transformers import XLMConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -28,7 +27,6 @@ if is_torch_available():
     import torch
 
     from transformers import (
-        XLMConfig,
         XLMForMultipleChoice,
         XLMForQuestionAnswering,
         XLMForQuestionAnsweringSimple,
@@ -97,7 +95,22 @@ class XLMModelTester:
         is_impossible_labels = ids_tensor([self.batch_size], 2).float()
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
 
-        config = XLMConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            token_type_ids,
+            input_lengths,
+            sequence_labels,
+            token_labels,
+            is_impossible_labels,
+            choice_labels,
+            input_mask,
+        )
+
+    def get_config(self):
+        return XLMConfig(
             vocab_size=self.vocab_size,
             n_special=self.n_special,
             emb_dim=self.hidden_size,
@@ -118,18 +131,6 @@ class XLMModelTester:
             bos_token_id=self.bos_token_id,
         )
 
-        return (
-            config,
-            input_ids,
-            token_type_ids,
-            input_lengths,
-            sequence_labels,
-            token_labels,
-            is_impossible_labels,
-            choice_labels,
-            input_mask,
-        )
-
     def create_and_check_xlm_model(
         self,
         config,
tests/test_modeling_xlnet.py

@@ -13,11 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import random
 import unittest
 
-from transformers import is_torch_available
+from transformers import XLNetConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -29,7 +28,6 @@ if is_torch_available():
     import torch
 
     from transformers import (
-        XLNetConfig,
         XLNetForMultipleChoice,
         XLNetForQuestionAnswering,
         XLNetForQuestionAnsweringSimple,
@@ -131,7 +129,25 @@ class XLNetModelTester:
         is_impossible_labels = ids_tensor([self.batch_size], 2).float()
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
 
-        config = XLNetConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids_1,
+            input_ids_2,
+            input_ids_q,
+            perm_mask,
+            input_mask,
+            target_mapping,
+            segment_ids,
+            lm_labels,
+            sequence_labels,
+            is_impossible_labels,
+            token_labels,
+        )
+
+    def get_config(self):
+        return XLNetConfig(
             vocab_size=self.vocab_size,
             d_model=self.hidden_size,
             n_head=self.num_attention_heads,
@@ -150,21 +166,6 @@ class XLNetModelTester:
             eos_token_id=self.eos_token_id,
         )
 
-        return (
-            config,
-            input_ids_1,
-            input_ids_2,
-            input_ids_q,
-            perm_mask,
-            input_mask,
-            target_mapping,
-            segment_ids,
-            lm_labels,
-            sequence_labels,
-            is_impossible_labels,
-            token_labels,
-        )
-
     def set_seed(self):
         random.seed(self.seed)
         torch.manual_seed(self.seed)