chenpangpang/transformers - Commits
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "455c6390938a5c737fa63e78396cedae41e4e87e"
Commit 808bb8da authored Dec 09, 2019 by thomwolf
fix transfo xl tests
Parent b016dd16
Showing 3 changed files with 14 additions and 8 deletions (+14 / -8)
transformers/tests/modeling_common_test.py         +12 -6
transformers/tests/modeling_tf_transfo_xl_test.py   +1 -1
transformers/tests/modeling_transfo_xl_test.py      +1 -1
transformers/tests/modeling_common_test.py  (view file @ 808bb8da)

@@ -125,6 +125,11 @@ class CommonTestCases:
         def test_attention_outputs(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

+            decoder_seq_length = self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length
+            encoder_seq_length = self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length
+            decoder_key_length = self.model_tester.key_length if hasattr(self.model_tester, 'key_length') else decoder_seq_length
+            encoder_key_length = self.model_tester.key_length if hasattr(self.model_tester, 'key_length') else encoder_seq_length
+
             for model_class in self.all_model_classes:
                 config.output_attentions = True
                 config.output_hidden_states = False
@@ -138,8 +143,8 @@ class CommonTestCases:
                 self.assertListEqual(
                     list(attentions[0].shape[-3:]),
                     [self.model_tester.num_attention_heads,
-                     self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
-                     self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length])
+                     encoder_seq_length,
+                     encoder_key_length])
                 out_len = len(outputs)

                 if self.is_encoder_decoder:
@@ -151,8 +156,9 @@ class CommonTestCases:
                 self.assertListEqual(
                     list(decoder_attentions[0].shape[-3:]),
                     [self.model_tester.num_attention_heads,
-                     self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length,
-                     self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length])
+                     decoder_seq_length,
+                     decoder_key_length])
+

                 # Check attention is always last and order is fine
                 config.output_attentions = True
@@ -169,8 +175,8 @@ class CommonTestCases:
                 self.assertListEqual(
                     list(self_attentions[0].shape[-3:]),
                     [self.model_tester.num_attention_heads,
-                     self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
-                     self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length])
+                     encoder_seq_length,
+                     encoder_key_length])

         def test_torchscript(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
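The hasattr fallbacks added to test_attention_outputs let a model tester opt in to separate encoder/decoder sequence lengths and a distinct key length simply by defining those attributes; testers that only define seq_length behave exactly as before. A minimal sketch of how the resolution plays out, using hypothetical stand-in tester classes (the class names and values below are illustrative, not part of the commit):

# Hypothetical stand-ins for a model tester; only the attribute names mirror the commit.
class PlainTester:
    seq_length = 7                       # no key_length: key lengths fall back to seq_length

class TransfoXLStyleTester:
    seq_length = 7
    mem_len = 30
    key_length = seq_length + mem_len    # found by hasattr(tester, 'key_length')

def resolve_lengths(tester):
    # Mirrors the fallback chain added to test_attention_outputs above.
    decoder_seq_length = tester.decoder_seq_length if hasattr(tester, 'decoder_seq_length') else tester.seq_length
    encoder_seq_length = tester.encoder_seq_length if hasattr(tester, 'encoder_seq_length') else tester.seq_length
    decoder_key_length = tester.key_length if hasattr(tester, 'key_length') else decoder_seq_length
    encoder_key_length = tester.key_length if hasattr(tester, 'key_length') else encoder_seq_length
    return encoder_seq_length, encoder_key_length, decoder_seq_length, decoder_key_length

print(resolve_lengths(PlainTester()))           # (7, 7, 7, 7)
print(resolve_lengths(TransfoXLStyleTester()))  # (7, 37, 7, 37)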
transformers/tests/modeling_tf_transfo_xl_test.py  (view file @ 808bb8da)

@@ -68,7 +68,7 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
             self.batch_size = batch_size
             self.seq_length = seq_length
             self.mem_len = mem_len
-            self.key_len = seq_length + mem_len
+            self.key_length = seq_length + mem_len
             self.clamp_len = clamp_len
             self.is_training = is_training
             self.use_labels = use_labels
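The rename above (key_len to key_length) reflects a real property of Transformer-XL: attention is computed over the cached memory of mem_len previous hidden states in addition to the current segment, so a layer's attention map has trailing shape (num_heads, seq_length, seq_length + mem_len) rather than a square (seq_length, seq_length). A rough shape-only sketch, with arbitrary sizes rather than the values used by the tests:

import torch

# Illustrative sizes only: why the key dimension is mem_len + seq_length for Transformer-XL.
batch_size, num_heads, seq_length, mem_len, head_dim = 2, 4, 7, 30, 8

q = torch.randn(batch_size, num_heads, seq_length, head_dim)            # queries: current segment only
k = torch.randn(batch_size, num_heads, mem_len + seq_length, head_dim)  # keys: cached memory + current segment

scores = q @ k.transpose(-2, -1)                                        # attention logits
assert list(scores.shape[-3:]) == [num_heads, seq_length, mem_len + seq_length]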
transformers/tests/modeling_transfo_xl_test.py  (view file @ 808bb8da)

@@ -66,7 +66,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
             self.batch_size = batch_size
             self.seq_length = seq_length
             self.mem_len = mem_len
-            self.key_len = seq_length + mem_len
+            self.key_length = seq_length + mem_len
             self.clamp_len = clamp_len
             self.is_training = is_training
             self.use_labels = use_labels
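Both the TF and PyTorch testers get the same rename because the common test looks the attribute up by the exact name 'key_length'; any other spelling, including the old key_len, is invisible to hasattr and silently falls back to seq_length. A tiny illustration with made-up stand-ins (not the real tester classes):

# Made-up stand-ins; only the attribute spellings matter here.
class Before:
    seq_length, mem_len = 7, 30
    key_len = seq_length + mem_len        # old spelling: not seen by the common test

class After:
    seq_length, mem_len = 7, 30
    key_length = seq_length + mem_len     # new spelling: picked up via hasattr(..., 'key_length')

for tester in (Before(), After()):
    key_length = tester.key_length if hasattr(tester, 'key_length') else tester.seq_length
    print(type(tester).__name__, key_length)  # Before 7, After 37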