Commit 870b734b authored Apr 15, 2019 by thomwolf

added tokenizers serialization tests

parent 3e65f255
Showing 7 changed files with 51 additions and 32 deletions.
pytorch_pretrained_bert/tokenization.py             +1  -0
pytorch_pretrained_bert/tokenization_gpt2.py        +5  -1
pytorch_pretrained_bert/tokenization_openai.py      +5  -1
pytorch_pretrained_bert/tokenization_transfo_xl.py  +1  -0
tests/tokenization_openai_test.py                   +16 -0
tests/tokenization_test.py                          +11 -0
tests/tokenization_transfo_xl_test.py               +12 -30
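Together, the library-side changes make each tokenizer's save_vocabulary return the path(s) of the file(s) it wrote, and the new tests use those return values to exercise a save-and-reload round trip through /tmp/.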
pytorch_pretrained_bert/tokenization.py

@@ -146,6 +146,7 @@ class BertTokenizer(object):
                     index = token_index
                 writer.write(token + u'\n')
                 index += 1
+        return vocab_file
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
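With save_vocabulary now returning the path it wrote, a save-and-reload round trip can be verified directly, which is what the new BERT test below does through /tmp/. A minimal sketch of the pattern (the vocab.txt path here is a placeholder, not a file from this commit):

import os
from pytorch_pretrained_bert.tokenization import BertTokenizer

tokenizer = BertTokenizer("vocab.txt")                      # placeholder: any existing vocab file
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")  # now returns the written file's path
reloaded = BertTokenizer.from_pretrained(vocab_file)        # reload from the saved vocab
os.remove(vocab_file)                                       # clean up, as the new test does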
pytorch_pretrained_bert/tokenization_gpt2.py

@@ -188,7 +188,10 @@ class GPT2Tokenizer(object):
         return word
 
     def save_vocabulary(self, vocab_path):
-        """Save the tokenizer vocabulary to a path."""
+        """Save the tokenizer vocabulary and merge files to a directory."""
+        if not os.path.isdir(vocab_path):
+            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+            return
         vocab_file = os.path.join(vocab_path, VOCAB_NAME)
         merge_file = os.path.join(vocab_path, MERGES_NAME)
         json.dump(self.encoder, vocab_file)

@@ -202,6 +205,7 @@ class GPT2Tokenizer(object):
                     index = token_index
                 writer.write(bpe_tokens + u'\n')
                 index += 1
+        return vocab_file, merge_file
 
     def encode(self, text):
         bpe_tokens = []
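Since BPE needs both a vocabulary and a merges table, the GPT-2 variant returns a pair of paths and insists that vocab_path be a directory. A minimal round-trip sketch (the "gpt2" shortcut downloads the pretrained files; treat the exact reload behavior as an assumption based on the test pattern in this commit):

import os
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")                       # pretrained vocab + merges
vocab_file, merge_file = tokenizer.save_vocabulary(vocab_path="/tmp/")  # directory, or an error is logged
reloaded = GPT2Tokenizer.from_pretrained("/tmp/")                       # finds both files in the directory
os.remove(vocab_file)
os.remove(merge_file)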
pytorch_pretrained_bert/tokenization_openai.py

@@ -263,7 +263,10 @@ class OpenAIGPTTokenizer(object):
         return out_string
 
     def save_vocabulary(self, vocab_path):
-        """Save the tokenizer vocabulary to a path."""
+        """Save the tokenizer vocabulary and merge files to a directory."""
+        if not os.path.isdir(vocab_path):
+            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+            return
         vocab_file = os.path.join(vocab_path, VOCAB_NAME)
         merge_file = os.path.join(vocab_path, MERGES_NAME)
         json.dump(self.encoder, vocab_file)

@@ -277,3 +280,4 @@ class OpenAIGPTTokenizer(object):
                     index = token_index
                 writer.write(bpe_tokens + u'\n')
                 index += 1
+        return vocab_file, merge_file
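OpenAIGPTTokenizer receives the identical change: the same directory check and the same (vocab_file, merge_file) return value, so the GPT-2 sketch above applies with the class name swapped.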
pytorch_pretrained_bert/tokenization_transfo_xl.py

@@ -148,6 +148,7 @@ class TransfoXLTokenizer(object):
         index = 0
         vocab_file = os.path.join(vocab_path, VOCAB_NAME)
         torch.save(self.__dict__, vocab_file)
+        return vocab_file
 
     def build_vocab(self):
         if self.vocab_file:
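Transformer-XL serializes differently: torch.save pickles the tokenizer's entire __dict__ into one file rather than writing plain-text vocab and merges. A rough reload sketch; the __dict__.update step is an approximation of what from_pretrained does with this file, not code from the commit:

import torch
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer()                            # fresh tokenizer; vocab normally built or loaded first
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")  # pickles self.__dict__; now returns the path
state = torch.load(vocab_file)                              # a plain dict of tokenizer attributes
tokenizer.__dict__.update(state)                            # assumption: approximates the from_pretrained restore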
tests/tokenization_openai_test.py

@@ -52,5 +52,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
+        vocab_file, merges_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained("/tmp/")
+        os.remove(vocab_file)
+        os.remove(merges_file)
+
+        text = "lower"
+        bpe_tokens = ["low", "er</w>"]
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)
+
+        input_tokens = tokens + ["<unk>"]
+        input_bpe_tokens = [14, 15, 20]
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
 if __name__ == '__main__':
     unittest.main()
tests/tokenization_test.py

@@ -46,6 +46,17 @@ class TokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
 
+        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained(vocab_file)
+        os.remove(vocab_file)
+
+        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+        self.assertListEqual(tokens,
+                             ["un", "##want", "##ed", ",", "runn", "##ing"])
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+
     def test_chinese(self):
         tokenizer = BasicTokenizer()
tests/tokenization_transfo_xl_test.py

@@ -18,9 +18,7 @@ import os
 import unittest
 from io import open
 
-from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
-                                                             _is_control, _is_punctuation, _is_whitespace)
+from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer
 
 class TransfoXLTokenizationTest(unittest.TestCase):

@@ -43,6 +41,17 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
 
+        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained(vocab_file)
+        os.remove(vocab_file)
+
+        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
+        self.assertListEqual(tokens,
+                             ["<unk>", "unwanted", ",", "running"])
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
+
     def test_full_tokenizer_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=True)

@@ -58,33 +67,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
             tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
             ["HeLLo", "!", "how", "Are", "yoU", "?"])
 
-    def test_is_whitespace(self):
-        self.assertTrue(_is_whitespace(u" "))
-        self.assertTrue(_is_whitespace(u"\t"))
-        self.assertTrue(_is_whitespace(u"\r"))
-        self.assertTrue(_is_whitespace(u"\n"))
-        self.assertTrue(_is_whitespace(u"\u00A0"))
-
-        self.assertFalse(_is_whitespace(u"A"))
-        self.assertFalse(_is_whitespace(u"-"))
-
-    def test_is_control(self):
-        self.assertTrue(_is_control(u"\u0005"))
-
-        self.assertFalse(_is_control(u"A"))
-        self.assertFalse(_is_control(u" "))
-        self.assertFalse(_is_control(u"\t"))
-        self.assertFalse(_is_control(u"\r"))
-
-    def test_is_punctuation(self):
-        self.assertTrue(_is_punctuation(u"-"))
-        self.assertTrue(_is_punctuation(u"$"))
-        self.assertTrue(_is_punctuation(u"`"))
-        self.assertTrue(_is_punctuation(u"."))
-
-        self.assertFalse(_is_punctuation(u"A"))
-        self.assertFalse(_is_punctuation(u" "))
-
 if __name__ == '__main__':
     unittest.main()
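The dropped _is_whitespace, _is_control, and _is_punctuation checks go away together with their imports, leaving this file focused on TransfoXLTokenizer itself.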