chenpangpang / transformers · Commit c852036b

Unverified commit c852036b, authored Jun 16, 2020 by Amil Khare, committed by GitHub on Jun 16, 2020.
Parent: 0c55a384

[cleanup] Hoist ModelTester objects to top level (#4939)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

Showing 5 changed files with 1297 additions and 1410 deletions.
Files changed:

  tests/test_modeling_tf_xlm.py      +198  -229
  tests/test_modeling_tf_xlnet.py    +275  -298
  tests/test_modeling_transfo_xl.py  +132  -150
  tests/test_modeling_xlm.py         +294  -331
  tests/test_modeling_xlnet.py       +398  -402
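Every file in this commit applies the same refactor: the ModelTester helper, previously nested inside its unittest.TestCase, is hoisted to module top level, and its long list of constructor keyword arguments (which call sites left at the defaults) is replaced by hardcoded attributes in __init__. A minimal sketch of the shape of the change, using a hypothetical FooModelTester rather than any class from this commit:

    import unittest

    # Before: tester nested inside the test case, configured through kwargs.
    class FooModelTestBefore(unittest.TestCase):
        class FooModelTester(object):
            def __init__(self, parent, batch_size=13, seq_length=7):
                self.parent = parent
                self.batch_size = batch_size
                self.seq_length = seq_length

        def setUp(self):
            self.model_tester = FooModelTestBefore.FooModelTester(self)

    # After: tester hoisted to top level, the defaults hardcoded in __init__.
    class FooModelTester:
        def __init__(self, parent):
            self.parent = parent
            self.batch_size = 13
            self.seq_length = 7

    class FooModelTestAfter(unittest.TestCase):
        def setUp(self):
            self.model_tester = FooModelTester(self)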
--- a/tests/test_modeling_tf_xlm.py
+++ b/tests/test_modeling_tf_xlm.py
@@ -35,81 +35,39 @@ if is_tf_available():
     )


-@require_tf
-class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (
-        (TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple)
-        if is_tf_available()
-        else ()
-    )
-    all_generative_model_classes = (
-        (TFXLMWithLMHeadModel,) if is_tf_available() else ()
-    )  # TODO (PVP): Check other models whether language generation is also applicable
-
-    class TFXLMModelTester(object):
-        def __init__(
-            self,
-            parent,
-            batch_size=13,
-            seq_length=7,
-            is_training=True,
-            use_input_lengths=True,
-            use_token_type_ids=True,
-            use_labels=True,
-            gelu_activation=True,
-            sinusoidal_embeddings=False,
-            causal=False,
-            asm=False,
-            n_langs=2,
-            vocab_size=99,
-            n_special=0,
-            hidden_size=32,
-            num_hidden_layers=5,
-            num_attention_heads=4,
-            hidden_dropout_prob=0.1,
-            attention_probs_dropout_prob=0.1,
-            max_position_embeddings=512,
-            type_vocab_size=16,
-            type_sequence_label_size=2,
-            initializer_range=0.02,
-            num_labels=3,
-            num_choices=4,
-            summary_type="last",
-            use_proj=True,
-            scope=None,
-            bos_token_id=0,
-        ):
-            self.parent = parent
-            self.batch_size = batch_size
-            self.seq_length = seq_length
-            self.is_training = is_training
-            self.use_input_lengths = use_input_lengths
-            self.use_token_type_ids = use_token_type_ids
-            self.use_labels = use_labels
-            self.gelu_activation = gelu_activation
-            self.sinusoidal_embeddings = sinusoidal_embeddings
-            self.asm = asm
-            self.n_langs = n_langs
-            self.vocab_size = vocab_size
-            self.n_special = n_special
-            self.summary_type = summary_type
-            self.causal = causal
-            self.use_proj = use_proj
-            self.hidden_size = hidden_size
-            self.num_hidden_layers = num_hidden_layers
-            self.num_attention_heads = num_attention_heads
-            self.hidden_dropout_prob = hidden_dropout_prob
-            self.attention_probs_dropout_prob = attention_probs_dropout_prob
-            self.max_position_embeddings = max_position_embeddings
-            self.n_langs = n_langs
-            self.type_sequence_label_size = type_sequence_label_size
-            self.initializer_range = initializer_range
-            self.summary_type = summary_type
-            self.num_labels = num_labels
-            self.num_choices = num_choices
-            self.scope = scope
-            self.bos_token_id = bos_token_id
+class TFXLMModelTester:
+    def __init__(
+        self,
+        parent,
+    ):
+        self.parent = parent
+        self.batch_size = 13
+        self.seq_length = 7
+        self.is_training = True
+        self.use_input_lengths = True
+        self.use_token_type_ids = True
+        self.use_labels = True
+        self.gelu_activation = True
+        self.sinusoidal_embeddings = False
+        self.causal = False
+        self.asm = False
+        self.n_langs = 2
+        self.vocab_size = 99
+        self.n_special = 0
+        self.hidden_size = 32
+        self.num_hidden_layers = 5
+        self.num_attention_heads = 4
+        self.hidden_dropout_prob = 0.1
+        self.attention_probs_dropout_prob = 0.1
+        self.max_position_embeddings = 512
+        self.type_vocab_size = 16
+        self.type_sequence_label_size = 2
+        self.initializer_range = 0.02
+        self.num_labels = 3
+        self.num_choices = 4
+        self.summary_type = "last"
+        self.use_proj = True
+        self.scope = None
+        self.bos_token_id = 0

     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -211,9 +169,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
             "logits": logits.numpy(),
         }

-            self.parent.assertListEqual(
-                list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
-            )
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])

     def create_and_check_xlm_qa(
         self,
@@ -283,8 +239,21 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
         }
         return config, inputs_dict


+@require_tf
+class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple)
+        if is_tf_available()
+        else ()
+    )
+    all_generative_model_classes = (
+        (TFXLMWithLMHeadModel,) if is_tf_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable
+
     def setUp(self):
-        self.model_tester = TFXLMModelTest.TFXLMModelTester(self)
+        self.model_tester = TFXLMModelTester(self)
         self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)

     def test_config(self):
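A side effect worth noting (my reading of the cleanup, not stated in the commit message): a top-level tester is importable on its own, so other test modules can reuse it without reaching through the TestCase that owns it. A hypothetical sketch, assuming this file's module path:

    import unittest

    # Hypothetical reuse from another test module: the hoisted class can be
    # imported directly instead of via TFXLMModelTest.TFXLMModelTester.
    from tests.test_modeling_tf_xlm import TFXLMModelTester

    class MyXLMVariantTest(unittest.TestCase):
        def setUp(self):
            self.model_tester = TFXLMModelTester(self)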
--- a/tests/test_modeling_tf_xlnet.py
+++ b/tests/test_modeling_tf_xlnet.py
@@ -37,78 +37,35 @@ if is_tf_available():
     )


-@require_tf
-class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (
-        (
-            TFXLNetModel,
-            TFXLNetLMHeadModel,
-            TFXLNetForSequenceClassification,
-            TFXLNetForTokenClassification,
-            TFXLNetForQuestionAnsweringSimple,
-        )
-        if is_tf_available()
-        else ()
-    )
-    all_generative_model_classes = (
-        (TFXLNetLMHeadModel,) if is_tf_available() else ()
-    )  # TODO (PVP): Check other models whether language generation is also applicable
-    test_pruning = False
-
-    class TFXLNetModelTester(object):
-        def __init__(
-            self,
-            parent,
-            batch_size=13,
-            seq_length=7,
-            mem_len=10,
-            clamp_len=-1,
-            reuse_len=15,
-            is_training=True,
-            use_labels=True,
-            vocab_size=99,
-            cutoffs=[10, 50, 80],
-            hidden_size=32,
-            num_attention_heads=4,
-            d_inner=128,
-            num_hidden_layers=5,
-            type_sequence_label_size=2,
-            untie_r=True,
-            bi_data=False,
-            same_length=False,
-            initializer_range=0.05,
-            seed=1,
-            type_vocab_size=2,
-            bos_token_id=1,
-            eos_token_id=2,
-            pad_token_id=5,
-        ):
-            self.parent = parent
-            self.batch_size = batch_size
-            self.seq_length = seq_length
-            self.mem_len = mem_len
-            # self.key_len = seq_length + mem_len
-            self.clamp_len = clamp_len
-            self.reuse_len = reuse_len
-            self.is_training = is_training
-            self.use_labels = use_labels
-            self.vocab_size = vocab_size
-            self.cutoffs = cutoffs
-            self.hidden_size = hidden_size
-            self.num_attention_heads = num_attention_heads
-            self.d_inner = d_inner
-            self.num_hidden_layers = num_hidden_layers
-            self.bi_data = bi_data
-            self.untie_r = untie_r
-            self.same_length = same_length
-            self.initializer_range = initializer_range
-            self.seed = seed
-            self.type_vocab_size = type_vocab_size
-            self.type_sequence_label_size = type_sequence_label_size
-            self.bos_token_id = bos_token_id
-            self.pad_token_id = pad_token_id
-            self.eos_token_id = eos_token_id
+class TFXLNetModelTester:
+    def __init__(
+        self,
+        parent,
+    ):
+        self.parent = parent
+        self.batch_size = 13
+        self.seq_length = 7
+        self.mem_len = 10
+        # self.key_len = seq_length + mem_len
+        self.clamp_len = -1
+        self.reuse_len = 15
+        self.is_training = True
+        self.use_labels = True
+        self.vocab_size = 99
+        self.cutoffs = [10, 50, 80]
+        self.hidden_size = 32
+        self.num_attention_heads = 4
+        self.d_inner = 128
+        self.num_hidden_layers = 5
+        self.type_sequence_label_size = 2
+        self.untie_r = True
+        self.bi_data = False
+        self.same_length = False
+        self.initializer_range = 0.05
+        self.seed = 1
+        self.type_vocab_size = 2
+        self.bos_token_id = 1
+        self.eos_token_id = 2
+        self.pad_token_id = 5

     def prepare_config_and_inputs(self):
         input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -377,8 +334,28 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
         inputs_dict = {"input_ids": input_ids_1}
         return config, inputs_dict


+@require_tf
+class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (
+            TFXLNetModel,
+            TFXLNetLMHeadModel,
+            TFXLNetForSequenceClassification,
+            TFXLNetForTokenClassification,
+            TFXLNetForQuestionAnsweringSimple,
+        )
+        if is_tf_available()
+        else ()
+    )
+    all_generative_model_classes = (
+        (TFXLNetLMHeadModel,) if is_tf_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable
+    test_pruning = False
+
     def setUp(self):
-        self.model_tester = TFXLNetModelTest.TFXLNetModelTester(self)
+        self.model_tester = TFXLNetModelTester(self)
         self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)

     def test_config(self):
--- a/tests/test_modeling_transfo_xl.py
+++ b/tests/test_modeling_transfo_xl.py
@@ -29,58 +29,30 @@ if is_torch_available():
     from transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST


-@require_torch
-class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
-    all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
-    test_pruning = False
-    test_torchscript = False
-    test_resize_embeddings = True
-
-    class TransfoXLModelTester(object):
-        def __init__(
-            self,
-            parent,
-            batch_size=14,
-            seq_length=7,
-            mem_len=30,
-            clamp_len=15,
-            is_training=True,
-            use_labels=True,
-            vocab_size=99,
-            cutoffs=[10, 50, 80],
-            hidden_size=32,
-            d_embed=32,
-            num_attention_heads=4,
-            d_head=8,
-            d_inner=128,
-            div_val=2,
-            num_hidden_layers=5,
-            scope=None,
-            seed=1,
-            eos_token_id=0,
-        ):
-            self.parent = parent
-            self.batch_size = batch_size
-            self.seq_length = seq_length
-            self.mem_len = mem_len
-            self.key_length = seq_length + mem_len
-            self.clamp_len = clamp_len
-            self.is_training = is_training
-            self.use_labels = use_labels
-            self.vocab_size = vocab_size
-            self.cutoffs = cutoffs
-            self.hidden_size = hidden_size
-            self.d_embed = d_embed
-            self.num_attention_heads = num_attention_heads
-            self.d_head = d_head
-            self.d_inner = d_inner
-            self.div_val = div_val
-            self.num_hidden_layers = num_hidden_layers
-            self.scope = scope
-            self.seed = seed
-            self.eos_token_id = eos_token_id
+class TransfoXLModelTester:
+    def __init__(
+        self,
+        parent,
+    ):
+        self.parent = parent
+        self.batch_size = 14
+        self.seq_length = 7
+        self.mem_len = 30
+        self.key_length = self.seq_length + self.mem_len
+        self.clamp_len = 15
+        self.is_training = True
+        self.use_labels = True
+        self.vocab_size = 99
+        self.cutoffs = [10, 50, 80]
+        self.hidden_size = 32
+        self.d_embed = 32
+        self.num_attention_heads = 4
+        self.d_head = 8
+        self.d_inner = 128
+        self.div_val = 2
+        self.num_hidden_layers = 5
+        self.scope = None
+        self.seed = 1
+        self.eos_token_id = 0

     def prepare_config_and_inputs(self):
         input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -187,6 +159,16 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
         inputs_dict = {"input_ids": input_ids_1}
         return config, inputs_dict


+@require_torch
+class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
+    all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
+    test_pruning = False
+    test_torchscript = False
+    test_resize_embeddings = True
+
     def check_cutoffs_and_n_token(
         self, copied_cutoffs, layer, model_embed, model, model_class, resized_value, vocab_size
     ):
@@ -210,7 +192,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
         self.assertEqual(model.crit.n_token, vocab_size + resized_value)

     def setUp(self):
-        self.model_tester = TransfoXLModelTest.TransfoXLModelTester(self)
+        self.model_tester = TransfoXLModelTester(self)
         self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)

     def test_config(self):
--- a/tests/test_modeling_xlm.py
+++ b/tests/test_modeling_xlm.py
@@ -37,87 +37,38 @@ if is_torch_available():
    from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_LIST


-@require_torch
-class XLMModelTest(ModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (
-        (
-            XLMModel,
-            XLMWithLMHeadModel,
-            XLMForQuestionAnswering,
-            XLMForSequenceClassification,
-            XLMForQuestionAnsweringSimple,
-        )
-        if is_torch_available()
-        else ()
-    )
-    all_generative_model_classes = (
-        (XLMWithLMHeadModel,) if is_torch_available() else ()
-    )  # TODO (PVP): Check other models whether language generation is also applicable
-
-    class XLMModelTester(object):
-        def __init__(
-            self,
-            parent,
-            batch_size=13,
-            seq_length=7,
-            is_training=True,
-            use_input_lengths=True,
-            use_token_type_ids=True,
-            use_labels=True,
-            gelu_activation=True,
-            sinusoidal_embeddings=False,
-            causal=False,
-            asm=False,
-            n_langs=2,
-            vocab_size=99,
-            n_special=0,
-            hidden_size=32,
-            num_hidden_layers=5,
-            num_attention_heads=4,
-            hidden_dropout_prob=0.1,
-            attention_probs_dropout_prob=0.1,
-            max_position_embeddings=512,
-            type_vocab_size=16,
-            type_sequence_label_size=2,
-            initializer_range=0.02,
-            num_labels=3,
-            num_choices=4,
-            summary_type="last",
-            use_proj=True,
-            scope=None,
-            bos_token_id=0,
-        ):
-            self.parent = parent
-            self.batch_size = batch_size
-            self.seq_length = seq_length
-            self.is_training = is_training
-            self.use_input_lengths = use_input_lengths
-            self.use_token_type_ids = use_token_type_ids
-            self.use_labels = use_labels
-            self.gelu_activation = gelu_activation
-            self.sinusoidal_embeddings = sinusoidal_embeddings
-            self.asm = asm
-            self.n_langs = n_langs
-            self.vocab_size = vocab_size
-            self.n_special = n_special
-            self.summary_type = summary_type
-            self.causal = causal
-            self.use_proj = use_proj
-            self.hidden_size = hidden_size
-            self.num_hidden_layers = num_hidden_layers
-            self.num_attention_heads = num_attention_heads
-            self.hidden_dropout_prob = hidden_dropout_prob
-            self.attention_probs_dropout_prob = attention_probs_dropout_prob
-            self.max_position_embeddings = max_position_embeddings
-            self.n_langs = n_langs
-            self.type_sequence_label_size = type_sequence_label_size
-            self.initializer_range = initializer_range
-            self.summary_type = summary_type
-            self.num_labels = num_labels
-            self.num_choices = num_choices
-            self.scope = scope
-            self.bos_token_id = bos_token_id
+class XLMModelTester:
+    def __init__(
+        self,
+        parent,
+    ):
+        self.parent = parent
+        self.batch_size = 13
+        self.seq_length = 7
+        self.is_training = True
+        self.use_input_lengths = True
+        self.use_token_type_ids = True
+        self.use_labels = True
+        self.gelu_activation = True
+        self.sinusoidal_embeddings = False
+        self.causal = False
+        self.asm = False
+        self.n_langs = 2
+        self.vocab_size = 99
+        self.n_special = 0
+        self.hidden_size = 32
+        self.num_hidden_layers = 5
+        self.num_attention_heads = 4
+        self.hidden_dropout_prob = 0.1
+        self.attention_probs_dropout_prob = 0.1
+        self.max_position_embeddings = 512
+        self.type_sequence_label_size = 2
+        self.initializer_range = 0.02
+        self.num_labels = 3
+        self.num_choices = 4
+        self.summary_type = "last"
+        self.use_proj = True
+        self.scope = None
+        self.bos_token_id = 0

     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -223,9 +174,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
         }

         self.parent.assertListEqual(list(result["loss"].size()), [])
-            self.parent.assertListEqual(
-                list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
-            )
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])

     def create_and_check_xlm_simple_qa(
         self,
@@ -318,8 +267,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
             [self.batch_size, model.config.start_n_top * model.config.end_n_top],
         )
         self.parent.assertListEqual(
-            list(result["end_top_index"].size()),
-            [self.batch_size, model.config.start_n_top * model.config.end_n_top],
+            list(result["end_top_index"].size()), [self.batch_size, model.config.start_n_top * model.config.end_n_top],
         )
         self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
@@ -347,9 +295,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
         }

         self.parent.assertListEqual(list(result["loss"].size()), [])
-            self.parent.assertListEqual(
-                list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
-            )
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size])

     def create_and_check_xlm_for_token_classification(
         self,
@@ -372,9 +318,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
             "loss": loss,
             "logits": logits,
         }

-            self.parent.assertListEqual(
-                list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
-            )
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
         self.check_loss_output(result)

     def prepare_config_and_inputs_for_common(self):
@@ -392,8 +336,27 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
         inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
         return config, inputs_dict


+@require_torch
+class XLMModelTest(ModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (
+            XLMModel,
+            XLMWithLMHeadModel,
+            XLMForQuestionAnswering,
+            XLMForSequenceClassification,
+            XLMForQuestionAnsweringSimple,
+        )
+        if is_torch_available()
+        else ()
+    )
+    all_generative_model_classes = (
+        (XLMWithLMHeadModel,) if is_torch_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable
+
     def setUp(self):
-        self.model_tester = XLMModelTest.XLMModelTester(self)
+        self.model_tester = XLMModelTester(self)
         self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)

     def test_config(self):
--- a/tests/test_modeling_xlnet.py
+++ b/tests/test_modeling_xlnet.py
@@ -39,27 +39,7 @@ if is_torch_available():
     from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_LIST


-@require_torch
-class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (
-        (
-            XLNetModel,
-            XLNetLMHeadModel,
-            XLNetForTokenClassification,
-            XLNetForSequenceClassification,
-            XLNetForQuestionAnswering,
-            XLNetForMultipleChoice,
-        )
-        if is_torch_available()
-        else ()
-    )
-    all_generative_model_classes = (
-        (XLNetLMHeadModel,) if is_torch_available() else ()
-    )  # TODO (PVP): Check other models whether language generation is also applicable
-    test_pruning = False
-
-    class XLNetModelTester(object):
-        def __init__(
-            self,
-            parent,
+class XLNetModelTester:
+    def __init__(
+        self,
+        parent,
@@ -89,31 +69,31 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         num_choices=4,
     ):
         self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.mem_len = mem_len
+        self.batch_size = 14
+        self.seq_length = 7
+        self.mem_len = 10
         # self.key_len = seq_length + mem_len
-        self.clamp_len = clamp_len
-        self.reuse_len = reuse_len
-        self.is_training = is_training
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.cutoffs = cutoffs
-        self.hidden_size = hidden_size
-        self.num_attention_heads = num_attention_heads
-        self.d_inner = d_inner
-        self.num_hidden_layers = num_hidden_layers
-        self.bi_data = bi_data
-        self.untie_r = untie_r
-        self.same_length = same_length
-        self.initializer_range = initializer_range
-        self.seed = seed
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.bos_token_id = bos_token_id
-        self.pad_token_id = pad_token_id
-        self.eos_token_id = eos_token_id
-        self.num_choices = num_choices
+        self.clamp_len = -1
+        self.reuse_len = 15
+        self.is_training = True
+        self.use_labels = True
+        self.vocab_size = 99
+        self.cutoffs = [10, 50, 80]
+        self.hidden_size = 32
+        self.num_attention_heads = 4
+        self.d_inner = 128
+        self.num_hidden_layers = 5
+        self.type_sequence_label_size = 2
+        self.untie_r = True
+        self.bi_data = False
+        self.same_length = False
+        self.initializer_range = 0.05
+        self.seed = 1
+        self.type_vocab_size = 2
+        self.bos_token_id = 1
+        self.eos_token_id = 2
+        self.pad_token_id = 5
+        self.num_choices = 4

     def prepare_config_and_inputs(self):
         input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -126,9 +106,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
             self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
         )
         perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-        target_mapping = torch.zeros(
-            self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
-        )
+        target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,)
         target_mapping[:, 0, -1] = 1.0  # predict last token

         sequence_labels = None
@@ -270,9 +248,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)

-        loss_2, all_logits_2, mems_2 = model(
-            input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1
-        )
+        loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)

         logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
@@ -370,8 +346,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
             [self.batch_size, model.config.start_n_top * model.config.end_n_top],
         )
         self.parent.assertListEqual(
-            list(result["end_top_index"].size()),
-            [self.batch_size, model.config.start_n_top * model.config.end_n_top],
+            list(result["end_top_index"].size()), [self.batch_size, model.config.start_n_top * model.config.end_n_top],
         )
         self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
         self.parent.assertListEqual(
@@ -472,8 +447,29 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         inputs_dict = {"input_ids": input_ids_1}
         return config, inputs_dict


+@require_torch
+class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (
+            XLNetModel,
+            XLNetLMHeadModel,
+            XLNetForTokenClassification,
+            XLNetForSequenceClassification,
+            XLNetForQuestionAnswering,
+            XLNetForMultipleChoice,
+        )
+        if is_torch_available()
+        else ()
+    )
+    all_generative_model_classes = (
+        (XLNetLMHeadModel,) if is_torch_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable
+    test_pruning = False
+
     def setUp(self):
-        self.model_tester = XLNetModelTest.XLNetModelTester(self)
+        self.model_tester = XLNetModelTester(self)
         self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)

     def test_config(self):
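Because each tester only requires a `parent` exposing unittest-style assertion helpers (assertListEqual and friends), the hoisted classes can also be driven outside a test runner, for example when debugging a single check. A minimal sketch, assuming torch is installed and that XLNetModelTester keeps the interface shown in the diffs above:

    import unittest

    class DebugParent(unittest.TestCase):
        """Stand-in parent: any object with the unittest assertion methods works."""

        def runTest(self):  # lets TestCase be instantiated directly
            pass

    # Hypothetical driver for the hoisted tester from tests/test_modeling_xlnet.py:
    # tester = XLNetModelTester(DebugParent())
    # config_and_inputs = tester.prepare_config_and_inputs()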