Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
870b734b
Commit
870b734b
authored
Apr 15, 2019
by
thomwolf
Browse files
added tokenizers serialization tests
parent
3e65f255
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
51 additions
and
32 deletions
+51
-32
pytorch_pretrained_bert/tokenization.py
pytorch_pretrained_bert/tokenization.py
+1
-0
pytorch_pretrained_bert/tokenization_gpt2.py
pytorch_pretrained_bert/tokenization_gpt2.py
+5
-1
pytorch_pretrained_bert/tokenization_openai.py
pytorch_pretrained_bert/tokenization_openai.py
+5
-1
pytorch_pretrained_bert/tokenization_transfo_xl.py
pytorch_pretrained_bert/tokenization_transfo_xl.py
+1
-0
tests/tokenization_openai_test.py
tests/tokenization_openai_test.py
+16
-0
tests/tokenization_test.py
tests/tokenization_test.py
+11
-0
tests/tokenization_transfo_xl_test.py
tests/tokenization_transfo_xl_test.py
+12
-30
No files found.
pytorch_pretrained_bert/tokenization.py
View file @
870b734b
...
...
@@ -146,6 +146,7 @@ class BertTokenizer(object):
index
=
token_index
writer
.
write
(
token
+
u
'
\n
'
)
index
+=
1
return
vocab_file
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
cache_dir
=
None
,
*
inputs
,
**
kwargs
):
...
...
pytorch_pretrained_bert/tokenization_gpt2.py
View file @
870b734b
...
...
@@ -188,7 +188,10 @@ class GPT2Tokenizer(object):
return
word
def
save_vocabulary
(
self
,
vocab_path
):
"""Save the tokenizer vocabulary to a path."""
"""Save the tokenizer vocabulary and merge files to a directory."""
if
not
os
.
path
.
isdir
(
vocab_path
):
logger
.
error
(
"Vocabulary path ({}) should be a directory"
.
format
(
vocab_path
))
return
vocab_file
=
os
.
path
.
join
(
vocab_path
,
VOCAB_NAME
)
merge_file
=
os
.
path
.
join
(
vocab_path
,
MERGES_NAME
)
json
.
dump
(
self
.
encoder
,
vocab_file
)
...
...
@@ -202,6 +205,7 @@ class GPT2Tokenizer(object):
index
=
token_index
writer
.
write
(
bpe_tokens
+
u
'
\n
'
)
index
+=
1
return
vocab_file
,
merge_file
def
encode
(
self
,
text
):
bpe_tokens
=
[]
...
...
pytorch_pretrained_bert/tokenization_openai.py
View file @
870b734b
...
...
@@ -263,7 +263,10 @@ class OpenAIGPTTokenizer(object):
return
out_string
def
save_vocabulary
(
self
,
vocab_path
):
"""Save the tokenizer vocabulary to a path."""
"""Save the tokenizer vocabulary and merge files to a directory."""
if
not
os
.
path
.
isdir
(
vocab_path
):
logger
.
error
(
"Vocabulary path ({}) should be a directory"
.
format
(
vocab_path
))
return
vocab_file
=
os
.
path
.
join
(
vocab_path
,
VOCAB_NAME
)
merge_file
=
os
.
path
.
join
(
vocab_path
,
MERGES_NAME
)
json
.
dump
(
self
.
encoder
,
vocab_file
)
...
...
@@ -277,3 +280,4 @@ class OpenAIGPTTokenizer(object):
index
=
token_index
writer
.
write
(
bpe_tokens
+
u
'
\n
'
)
index
+=
1
return
vocab_file
,
merge_file
pytorch_pretrained_bert/tokenization_transfo_xl.py
View file @
870b734b
...
...
@@ -148,6 +148,7 @@ class TransfoXLTokenizer(object):
index
=
0
vocab_file
=
os
.
path
.
join
(
vocab_path
,
VOCAB_NAME
)
torch
.
save
(
self
.
__dict__
,
vocab_file
)
return
vocab_file
def
build_vocab
(
self
):
if
self
.
vocab_file
:
...
...
tests/tokenization_openai_test.py
View file @
870b734b
...
...
@@ -52,5 +52,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
vocab_file
,
merges_file
=
tokenizer
.
save_vocabulary
(
vocab_path
=
"/tmp/"
)
tokenizer
.
from_pretrained
(
"/tmp/"
)
os
.
remove
(
vocab_file
)
os
.
remove
(
merges_file
)
text
=
"lower"
bpe_tokens
=
[
"low"
,
"er</w>"
]
tokens
=
tokenizer
.
tokenize
(
text
)
self
.
assertListEqual
(
tokens
,
bpe_tokens
)
input_tokens
=
tokens
+
[
"<unk>"
]
input_bpe_tokens
=
[
14
,
15
,
20
]
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
if
__name__
==
'__main__'
:
unittest
.
main
()
tests/tokenization_test.py
View file @
870b734b
...
...
@@ -46,6 +46,17 @@ class TokenizationTest(unittest.TestCase):
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
7
,
4
,
5
,
10
,
8
,
9
])
vocab_file
=
tokenizer
.
save_vocabulary
(
vocab_path
=
"/tmp/"
)
tokenizer
.
from_pretrained
(
vocab_file
)
os
.
remove
(
vocab_file
)
tokens
=
tokenizer
.
tokenize
(
u
"UNwant
\u00E9
d,running"
)
self
.
assertListEqual
(
tokens
,
[
"un"
,
"##want"
,
"##ed"
,
","
,
"runn"
,
"##ing"
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
7
,
4
,
5
,
10
,
8
,
9
])
def
test_chinese
(
self
):
tokenizer
=
BasicTokenizer
()
...
...
tests/tokenization_transfo_xl_test.py
View file @
870b734b
...
...
@@ -18,9 +18,7 @@ import os
import
unittest
from
io
import
open
from
pytorch_pretrained_bert.tokenization_transfo_xl
import
(
TransfoXLTokenizer
,
_is_control
,
_is_punctuation
,
_is_whitespace
)
from
pytorch_pretrained_bert.tokenization_transfo_xl
import
TransfoXLTokenizer
class
TransfoXLTokenizationTest
(
unittest
.
TestCase
):
...
...
@@ -43,6 +41,17 @@ class TransfoXLTokenizationTest(unittest.TestCase):
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
0
,
4
,
8
,
7
])
vocab_file
=
tokenizer
.
save_vocabulary
(
vocab_path
=
"/tmp/"
)
tokenizer
.
from_pretrained
(
vocab_file
)
os
.
remove
(
vocab_file
)
tokens
=
tokenizer
.
tokenize
(
u
"<unk> UNwant
\u00E9
d,running"
)
self
.
assertListEqual
(
tokens
,
[
"<unk>"
,
"unwanted"
,
","
,
"running"
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
0
,
4
,
8
,
7
])
def
test_full_tokenizer_lower
(
self
):
tokenizer
=
TransfoXLTokenizer
(
lower_case
=
True
)
...
...
@@ -58,33 +67,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
tokenizer
.
tokenize
(
u
"
\t
HeLLo!how
\n
Are yoU? "
),
[
"HeLLo"
,
"!"
,
"how"
,
"Are"
,
"yoU"
,
"?"
])
def
test_is_whitespace
(
self
):
self
.
assertTrue
(
_is_whitespace
(
u
" "
))
self
.
assertTrue
(
_is_whitespace
(
u
"
\t
"
))
self
.
assertTrue
(
_is_whitespace
(
u
"
\r
"
))
self
.
assertTrue
(
_is_whitespace
(
u
"
\n
"
))
self
.
assertTrue
(
_is_whitespace
(
u
"
\u00A0
"
))
self
.
assertFalse
(
_is_whitespace
(
u
"A"
))
self
.
assertFalse
(
_is_whitespace
(
u
"-"
))
def
test_is_control
(
self
):
self
.
assertTrue
(
_is_control
(
u
"
\u0005
"
))
self
.
assertFalse
(
_is_control
(
u
"A"
))
self
.
assertFalse
(
_is_control
(
u
" "
))
self
.
assertFalse
(
_is_control
(
u
"
\t
"
))
self
.
assertFalse
(
_is_control
(
u
"
\r
"
))
def
test_is_punctuation
(
self
):
self
.
assertTrue
(
_is_punctuation
(
u
"-"
))
self
.
assertTrue
(
_is_punctuation
(
u
"$"
))
self
.
assertTrue
(
_is_punctuation
(
u
"`"
))
self
.
assertTrue
(
_is_punctuation
(
u
"."
))
self
.
assertFalse
(
_is_punctuation
(
u
"A"
))
self
.
assertFalse
(
_is_punctuation
(
u
" "
))
if
__name__
==
'__main__'
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment