Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e645dcbb
Commit
e645dcbb
authored
Feb 25, 2020
by
Patrick von Platen
Browse files
add special tokens to pretrain configs of respective lm head models
parent
e693cd1e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
5 additions
and
49 deletions
+5
-49
tests/test_modeling_gpt2.py
tests/test_modeling_gpt2.py
+2
-17
tests/test_modeling_transfo_xl.py
tests/test_modeling_transfo_xl.py
+1
-8
tests/test_modeling_xlm.py
tests/test_modeling_xlm.py
+1
-11
tests/test_modeling_xlnet.py
tests/test_modeling_xlnet.py
+1
-13
No files found.
tests/test_modeling_gpt2.py
View file @
e645dcbb
...
...
@@ -263,14 +263,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
self
.
assertIsNotNone
(
model
)
def
prepare_generation_special_tokens
():
return
{
"bos_token_id"
:
50256
,
"eos_token_id"
:
50256
}
class
GPT2ModelLanguageGenerationTest
(
unittest
.
TestCase
):
special_tokens
=
prepare_generation_special_tokens
()
@
slow
def
test_lm_generate_gpt2
(
self
):
model
=
GPT2LMHeadModel
.
from_pretrained
(
"gpt2"
)
...
...
@@ -299,11 +293,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
]
# The dog is cute too. It likes to rub on me and is good for me (the dog
torch
.
manual_seed
(
0
)
output_ids
=
model
.
generate
(
input_ids
,
bos_token_id
=
self
.
special_tokens
[
"bos_token_id"
],
eos_token_ids
=
self
.
special_tokens
[
"eos_token_id"
],
)
output_ids
=
model
.
generate
(
input_ids
)
self
.
assertListEqual
(
output_ids
[
0
].
tolist
(),
expected_output_ids
)
...
...
@@ -335,10 +325,5 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
]
# The dog is cute though he can sometimes just walk in the park. It is not very nice to
torch
.
manual_seed
(
0
)
output_ids
=
model
.
generate
(
input_ids
,
bos_token_id
=
self
.
special_tokens
[
"bos_token_id"
],
eos_token_ids
=
self
.
special_tokens
[
"eos_token_id"
],
)
output_ids
=
model
.
generate
(
input_ids
)
self
.
assertListEqual
(
output_ids
[
0
].
tolist
(),
expected_output_ids
)
tests/test_modeling_transfo_xl.py
View file @
e645dcbb
...
...
@@ -214,14 +214,8 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
self
.
assertIsNotNone
(
model
)
def
prepare_generation_special_tokens
():
return
{
"eos_token_id"
:
0
}
class
TransfoXLModelLanguageGenerationTest
(
unittest
.
TestCase
):
special_tokens
=
prepare_generation_special_tokens
()
@
slow
def
test_lm_generate_transfo_xl_wt103
(
self
):
model
=
TransfoXLLMHeadModel
.
from_pretrained
(
"transfo-xl-wt103"
)
...
...
@@ -578,6 +572,5 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
torch
.
manual_seed
(
0
)
output_ids
=
model
.
generate
(
input_ids
,
eos_token_ids
=
self
.
special_tokens
[
"eos_token_id"
],
max_length
=
200
)
output_ids
=
model
.
generate
(
input_ids
,
max_length
=
200
)
self
.
assertListEqual
(
output_ids
[
0
].
tolist
(),
expected_output_ids
)
tests/test_modeling_xlm.py
View file @
e645dcbb
...
...
@@ -399,14 +399,8 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
self
.
assertIsNotNone
(
model
)
def
prepare_generation_special_tokens
():
return
{
"bos_token_id"
:
0
,
"pad_token_id"
:
2
}
class
XLMModelLanguageGenerationTest
(
unittest
.
TestCase
):
special_tokens
=
prepare_generation_special_tokens
()
@
slow
def
test_lm_generate_xlm_mlm_en_2048
(
self
):
model
=
XLMWithLMHeadModel
.
from_pretrained
(
"xlm-mlm-en-2048"
)
...
...
@@ -435,10 +429,6 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
]
# The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
torch
.
manual_seed
(
0
)
output_ids
=
model
.
generate
(
input_ids
,
bos_token_id
=
self
.
special_tokens
[
"bos_token_id"
],
pad_token_id
=
self
.
special_tokens
[
"pad_token_id"
],
)
output_ids
=
model
.
generate
(
input_ids
)
self
.
assertListEqual
(
output_ids
[
0
].
tolist
(),
expected_output_ids
)
tests/test_modeling_xlnet.py
View file @
e645dcbb
...
...
@@ -513,14 +513,8 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self
.
assertIsNotNone
(
model
)
def
prepare_generation_special_tokens
():
return
{
"bos_token_id"
:
1
,
"pad_token_id"
:
5
,
"eos_token_id"
:
2
}
class
XLNetModelLanguageGenerationTest
(
unittest
.
TestCase
):
special_tokens
=
prepare_generation_special_tokens
()
@
slow
def
test_lm_generate_xlnet_base_cased
(
self
):
model
=
XLNetLMHeadModel
.
from_pretrained
(
"xlnet-base-cased"
)
...
...
@@ -917,12 +911,6 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# Since, however, he has had difficulty walking with Maria
torch
.
manual_seed
(
0
)
output_ids
=
model
.
generate
(
input_ids
,
bos_token_id
=
self
.
special_tokens
[
"bos_token_id"
],
pad_token_id
=
self
.
special_tokens
[
"pad_token_id"
],
eos_token_ids
=
self
.
special_tokens
[
"eos_token_id"
],
max_length
=
200
,
)
output_ids
=
model
.
generate
(
input_ids
,
max_length
=
200
)
self
.
assertListEqual
(
output_ids
[
0
].
tolist
(),
expected_output_ids
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment