Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b4a3a647
Commit
b4a3a647
authored
Mar 08, 2020
by
patrickvonplaten
Browse files
fix xlnet & transfotests
parent
66c82765
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
64 deletions
+24
-64
tests/test_modeling_tf_transfo_xl.py
tests/test_modeling_tf_transfo_xl.py
+3
-13
tests/test_modeling_tf_xlnet.py
tests/test_modeling_tf_xlnet.py
+3
-13
tests/test_modeling_transfo_xl.py
tests/test_modeling_transfo_xl.py
+4
-13
tests/test_modeling_xlnet.py
tests/test_modeling_xlnet.py
+14
-25
No files found.
tests/test_modeling_tf_transfo_xl.py
View file @
b4a3a647
...
...
@@ -519,20 +519,10 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
24
,
24
,
0
,
29546
,
40
,
1092
,
18
,
8
,
5854
,
7
,
1143
,
2
,
7
,
33
,
1
,
159
,
99
,
16
,
1857
,
2
,
1
,
1009
,
4
,
...
...
tests/test_modeling_tf_xlnet.py
View file @
b4a3a647
...
...
@@ -760,20 +760,10 @@ class TFXLNetModelLanguageGenerationTest(unittest.TestCase):
9
,
4
,
3
,
1722
,
19
,
24
,
6348
,
61
,
977
,
176
,
1772
,
33
,
45
,
970
,
19
,
4185
,
19
,
12943
,
4354
,
153
,
27
,
442
,
22
,
...
...
tests/test_modeling_transfo_xl.py
View file @
b4a3a647
...
...
@@ -376,6 +376,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
# father initially slaps him for making such an accusation , Rasputin watches as the
# man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
# the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
# with people , even a bishop , begging for his blessing . <eod> </s> <eos>
expected_output_ids
=
[
...
...
@@ -520,20 +521,10 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
24
,
24
,
0
,
29546
,
40
,
1092
,
18
,
8
,
5854
,
7
,
1143
,
2
,
7
,
33
,
1
,
159
,
99
,
16
,
1857
,
2
,
1
,
1009
,
4
,
...
...
tests/test_modeling_xlnet.py
View file @
b4a3a647
...
...
@@ -119,11 +119,11 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
perm_mask
=
torch
.
zeros
(
self
.
batch_size
,
self
.
seq_length
+
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
,
device
=
torch_device
self
.
batch_size
,
self
.
seq_length
+
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
,
device
=
torch_device
,
)
perm_mask
[:,
:,
-
1
]
=
1.0
# Previous tokens don't see last token
target_mapping
=
torch
.
zeros
(
self
.
batch_size
,
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
,
device
=
torch_device
self
.
batch_size
,
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
,
device
=
torch_device
,
)
target_mapping
[:,
0
,
-
1
]
=
1.0
# predict last token
...
...
@@ -212,7 +212,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self
.
parent
.
assertEqual
(
len
(
no_mems_outputs
),
1
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"outputs"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
list
(
result
[
"outputs"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
,
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
...
...
@@ -283,7 +283,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_1"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"all_logits_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
list
(
result
[
"all_logits_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
,
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
...
...
@@ -292,7 +292,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_2"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"all_logits_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
list
(
result
[
"all_logits_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
,
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2"
]),
...
...
@@ -319,7 +319,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
model
.
eval
()
outputs
=
model
(
input_ids_1
)
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
,
mems
=
outputs
(
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
,
mems
,)
=
outputs
outputs
=
model
(
input_ids_1
,
...
...
@@ -340,7 +340,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
total_loss
,
mems
=
outputs
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
)
total_loss
,
mems
=
outputs
...
...
@@ -356,10 +356,10 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
list
(
result
[
"start_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
,
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_top_index"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
list
(
result
[
"start_top_index"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
,
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_top_log_probs"
].
size
()),
...
...
@@ -405,7 +405,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
type_sequence_label_size
]
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
type_sequence_label_size
]
,
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
...
...
@@ -442,7 +442,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
type_sequence_label_size
]
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
type_sequence_label_size
]
,
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
...
...
@@ -859,20 +859,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
9
,
4
,
3
,
1722
,
19
,
24
,
6348
,
61
,
977
,
176
,
1772
,
33
,
45
,
970
,
19
,
4185
,
19
,
12943
,
4354
,
153
,
27
,
442
,
22
,
...
...
@@ -922,5 +912,4 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# the men are forced to leave the monastery. Rasputin is forced to return to
output_ids
=
model
.
generate
(
input_ids
,
max_length
=
200
,
do_sample
=
False
)
self
.
assertListEqual
(
output_ids
[
0
].
tolist
(),
expected_output_ids
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment