Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b12541c4
Commit
b12541c4
authored
Mar 09, 2020
by
Patrick von Platen
Browse files
test ctrl
parent
b73dd1a0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
39 deletions
+39
-39
tests/test_modeling_ctrl.py
tests/test_modeling_ctrl.py
+20
-20
tests/test_modeling_tf_ctrl.py
tests/test_modeling_tf_ctrl.py
+19
-19
No files found.
tests/test_modeling_ctrl.py
View file @
b12541c4
...
...
@@ -220,30 +220,30 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
def
test_lm_generate_ctrl
(
self
):
model
=
CTRLLMHeadModel
.
from_pretrained
(
"ctrl"
)
input_ids
=
torch
.
tensor
(
[[
1185
8
,
586
,
20984
,
8
]],
dtype
=
torch
.
long
,
device
=
torch_device
)
# Legal
My neighbor
is
[[
1185
9
,
0
,
1611
,
8
]],
dtype
=
torch
.
long
,
device
=
torch_device
)
# Legal
the president
is
expected_output_ids
=
[
11859
,
586
,
20984
,
0
,
1611
,
8
,
13391
,
3
,
980
,
8258
,
72
,
327
,
148
,
5
,
150
,
26449
,
2
,
53
,
29
,
226
,
3
,
780
,
49
,
19
,
348
,
469
,
3
,
980
,
]
# Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
2595
,
48
,
20740
,
246533
,
246533
,
19
,
30
,
5
,
]
# Legal the president is a good guy and I don't want to lose my job. \n \n I have a
output_ids
=
model
.
generate
(
input_ids
,
do_sample
=
False
)
self
.
assertListEqual
(
output_ids
[
0
].
numpy
().
tolist
(),
expected_output_ids
)
self
.
assertListEqual
(
output_ids
[
0
].
tolist
(),
expected_output_ids
)
tests/test_modeling_tf_ctrl.py
View file @
b12541c4
...
...
@@ -209,29 +209,29 @@ class TFCTRLModelLanguageGenerationTest(unittest.TestCase):
@
slow
def
test_lm_generate_ctrl
(
self
):
model
=
TFCTRLLMHeadModel
.
from_pretrained
(
"ctrl"
)
input_ids
=
tf
.
convert_to_tensor
([[
1185
8
,
586
,
20984
,
8
]],
dtype
=
tf
.
int32
)
input_ids
=
tf
.
convert_to_tensor
([[
1185
9
,
0
,
1611
,
8
]],
dtype
=
tf
.
int32
)
# Legal the president is
expected_output_ids
=
[
11859
,
586
,
20984
,
0
,
1611
,
8
,
13391
,
3
,
980
,
8258
,
72
,
327
,
148
,
5
,
150
,
26449
,
2
,
53
,
29
,
226
,
3
,
780
,
49
,
19
,
348
,
469
,
3
,
980
,
]
# Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
2595
,
48
,
20740
,
246533
,
246533
,
19
,
30
,
5
,
]
# Legal the president is a good guy and I don't want to lose my job. \n \n I have a
output_ids
=
model
.
generate
(
input_ids
,
do_sample
=
False
)
self
.
assertListEqual
(
output_ids
[
0
].
tolist
(),
expected_output_ids
)
self
.
assertListEqual
(
output_ids
[
0
].
numpy
().
tolist
(),
expected_output_ids
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment