chenpangpang / transformers · Commits

Commit 4f8b5f68
Authored Jun 29, 2019 by thomwolf

add fix for serialization of tokenizer

Parent: d9184620
Changes: 2 changed files, with 37 additions and 2 deletions (+37 -2)

pytorch_pretrained_bert/tokenization_xlnet.py    +15 -0
tests/tokenization_xlnet_test.py                 +22 -2
pytorch_pretrained_bert/tokenization_xlnet.py  (view file @ 4f8b5f68)

@@ -182,6 +182,21 @@ class XLNetTokenizer(object):

     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)

+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        try:
+            import sentencepiece as spm
+        except ImportError:
+            logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
+                           "pip install sentencepiece")
+        self.sp_model = spm.SentencePieceProcessor()
+        self.sp_model.Load(self.vocab_file)
+
     def set_special_tokens(self, special_tokens):
         """ Add a list of additional tokens to the encoder.
             The additional tokens are indexed starting from the last index of the
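The change to tokenization_xlnet.py above follows the usual recipe for pickling an object that holds a non-picklable native handle (here the SentencePiece processor): __getstate__ drops the handle from the serialized state, and __setstate__ restores the attribute dictionary and rebuilds the handle from the data that was kept (the vocab_file path). A minimal, dependency-free sketch of the same pattern, using a hypothetical _load_model helper in place of spm.SentencePieceProcessor().Load():

import pickle

class ResourceHolder(object):
    """Toy stand-in for a tokenizer that wraps a non-picklable resource."""

    def __init__(self, vocab_file):
        self.vocab_file = vocab_file
        self.sp_model = self._load_model(vocab_file)  # not picklable in the real case

    def _load_model(self, vocab_file):
        # Hypothetical loader; the real tokenizer builds a
        # SentencePieceProcessor and calls Load(vocab_file) here.
        return {"loaded_from": vocab_file}

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None          # drop the non-picklable member
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self.sp_model = self._load_model(self.vocab_file)  # rebuild it on unpickle

holder = pickle.loads(pickle.dumps(ResourceHolder("spiece.model")))
assert holder.sp_model == {"loaded_from": "spiece.model"}

Dropping sp_model in __getstate__ is what lets pickle.dump succeed; everything needed to rebuild it (the vocabulary path) survives the round trip.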
tests/tokenization_xlnet_test.py  (view file @ 4f8b5f68)

@@ -15,11 +15,17 @@

 from __future__ import absolute_import, division, print_function, unicode_literals

 import os
+import sys
 import unittest
 from io import open
 import shutil

 import pytest

+if sys.version_info[0] == 2:
+    import cPickle as pickle
+else:
+    import pickle
+
 from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer,
                                                         PRETRAINED_VOCAB_ARCHIVE_MAP,
                                                         SPIECE_UNDERLINE)

@@ -43,8 +49,6 @@ class XLNetTokenizationTest(unittest.TestCase):

         vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
         tokenizer = tokenizer.from_pretrained(vocab_path,
                                               keep_accents=True)
-        os.remove(vocab_file)
-        os.remove(special_tokens_file)

         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
         self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',

@@ -65,6 +69,22 @@ class XLNetTokenizationTest(unittest.TestCase):

                                       SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                       u'<unk>', u'.'])
+
+        text = "Munich and Berlin are nice cities"
+        filename = u"/tmp/tokenizer.bin"
+
+        subwords = tokenizer.tokenize(text)
+
+        pickle.dump(tokenizer, open(filename, "wb"))
+
+        tokenizer_new = pickle.load(open(filename, "rb"))
+        subwords_loaded = tokenizer_new.tokenize(text)
+
+        self.assertListEqual(subwords, subwords_loaded)
+
+        os.remove(filename)
+        os.remove(vocab_file)
+        os.remove(special_tokens_file)

     @pytest.mark.slow
     def test_tokenizer_from_pretrained(self):
         cache_dir = "/tmp/pytorch_pretrained_bert_test/"
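The new test pins down the behavior this commit enables: an XLNetTokenizer can be pickled, and the reloaded instance produces the same segmentation. A short usage sketch along the same lines (the spiece.model vocabulary path and the temp file name are placeholders; the import path is the pytorch_pretrained_bert package as it exists at this commit):

import os
import pickle

from pytorch_pretrained_bert.tokenization_xlnet import XLNetTokenizer

tokenizer = XLNetTokenizer("spiece.model")   # placeholder SentencePiece vocab file
text = "Munich and Berlin are nice cities"
subwords = tokenizer.tokenize(text)

# Serialize and reload the tokenizer; sp_model is dropped and rebuilt transparently.
with open("/tmp/tokenizer.bin", "wb") as f:
    pickle.dump(tokenizer, f)
with open("/tmp/tokenizer.bin", "rb") as f:
    tokenizer_new = pickle.load(f)

assert tokenizer_new.tokenize(text) == subwords
os.remove("/tmp/tokenizer.bin")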