chenpangpang / transformers · Commits

Commit 4f8b5f68
authored Jun 29, 2019 by thomwolf
parent d9184620

add fix for serialization of tokenizer

Showing 2 changed files with 37 additions and 2 deletions (+37 −2):

- pytorch_pretrained_bert/tokenization_xlnet.py (+15 −0)
- tests/tokenization_xlnet_test.py (+22 −2)
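In short, this commit makes `XLNetTokenizer` picklable. Pickling previously failed because the tokenizer holds a `sentencepiece.SentencePieceProcessor` in its `sp_model` attribute, a SWIG-wrapped native object that the pickle module cannot serialize. The fix drops `sp_model` from the pickled state and rebuilds it from `self.vocab_file` when the object is restored; the new test round-trips a tokenizer through pickle and checks that tokenization is unchanged.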
pytorch_pretrained_bert/tokenization_xlnet.py

```diff
@@ -182,6 +182,21 @@ class XLNetTokenizer(object):
     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)

+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        try:
+            import sentencepiece as spm
+        except ImportError:
+            logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
+                           "pip install sentencepiece")
+        self.sp_model = spm.SentencePieceProcessor()
+        self.sp_model.Load(self.vocab_file)
+
     def set_special_tokens(self, special_tokens):
         """ Add a list of additional tokens to the encoder.
             The additional tokens are indexed starting from the last index of the
```
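The hunk above follows the standard pickle protocol: `__getstate__` returns a copy of the instance dict with the SentencePiece handle blanked out, and `__setstate__` restores the plain attributes and reloads the model from `self.vocab_file`. A minimal, self-contained sketch of the same pattern, with `TokenizerLike` as a hypothetical stand-in for the tokenizer class:

```python
import pickle


class TokenizerLike(object):
    """Minimal stand-in showing the __getstate__/__setstate__ pattern above."""

    def __init__(self, vocab_file):
        self.vocab_file = vocab_file
        # Stand-in for an unpicklable native handle such as
        # sentencepiece.SentencePieceProcessor (a SWIG-wrapped C++ object).
        self.sp_model = object()

    def __getstate__(self):
        # Pickle everything except the native handle.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        # Restore the plain attributes, then rebuild the handle from disk;
        # the real code re-imports sentencepiece and calls
        # spm.SentencePieceProcessor().Load(self.vocab_file) here.
        self.__dict__ = d
        self.sp_model = object()


restored = pickle.loads(pickle.dumps(TokenizerLike("spiece.model")))
assert restored.vocab_file == "spiece.model"
assert restored.sp_model is not None
```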
tests/tokenization_xlnet_test.py

```diff
@@ -15,11 +15,17 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import os
+import sys
 import unittest
 from io import open
 import shutil
 import pytest

+if sys.version_info[0] == 2:
+    import cPickle as pickle
+else:
+    import pickle
+
 from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP,
                                                         SPIECE_UNDERLINE)

@@ -43,8 +49,6 @@ class XLNetTokenizationTest(unittest.TestCase):

         vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
         tokenizer = tokenizer.from_pretrained(vocab_path, keep_accents=True)
-        os.remove(vocab_file)
-        os.remove(special_tokens_file)

         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
         self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
@@ -65,6 +69,22 @@ class XLNetTokenizationTest(unittest.TestCase):
                                       SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                       u'<unk>', u'.'])

+        text = "Munich and Berlin are nice cities"
+        filename = u"/tmp/tokenizer.bin"
+
+        subwords = tokenizer.tokenize(text)
+
+        pickle.dump(tokenizer, open(filename, "wb"))
+
+        tokenizer_new = pickle.load(open(filename, "rb"))
+
+        subwords_loaded = tokenizer_new.tokenize(text)
+
+        self.assertListEqual(subwords, subwords_loaded)
+
+        os.remove(filename)
+        os.remove(vocab_file)
+        os.remove(special_tokens_file)

     @pytest.mark.slow
     def test_tokenizer_from_pretrained(self):
         cache_dir = "/tmp/pytorch_pretrained_bert_test/"
```
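With this fix in place, a tokenizer can be round-tripped through pickle the same way the test does. A short usage sketch, not taken from the commit: it assumes the `XLNetTokenizer` constructor takes the vocab file path as its first argument (as the test's `from_pretrained(vocab_path)` round-trip suggests) and that `spiece.model` is a placeholder path to a real local SentencePiece model:

```python
import pickle

from pytorch_pretrained_bert.tokenization_xlnet import XLNetTokenizer

# "spiece.model" is a placeholder path to a local SentencePiece vocab file.
tokenizer = XLNetTokenizer("spiece.model")

with open("/tmp/tokenizer.bin", "wb") as f:
    pickle.dump(tokenizer, f)   # __getstate__ drops the sp_model handle

with open("/tmp/tokenizer.bin", "rb") as f:
    restored = pickle.load(f)   # __setstate__ reloads it from vocab_file

assert tokenizer.tokenize(u"Munich and Berlin are nice cities") == \
    restored.tokenize(u"Munich and Berlin are nice cities")
```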