chenpangpang / transformers · Commit 36bca545
Authored Jul 05, 2019 by thomwolf
Commit message: tokenization abstract class - tests for examples
Parent: a4f98054
Showing 12 of the commit's 32 changed files (page 1 of 2), with 349 additions and 468 deletions on this page.
pytorch_transformers/tests/tokenization_openai_test.py        +1    -9
pytorch_transformers/tests/tokenization_transfo_xl_test.py    +1    -8
pytorch_transformers/tests/tokenization_utils_test.py         +36   -0
pytorch_transformers/tests/tokenization_xlm_test.py           +2    -10
pytorch_transformers/tests/tokenization_xlnet_test.py         +1    -11
pytorch_transformers/tokenization_bert.py                     +19   -47
pytorch_transformers/tokenization_gpt2.py                     +41   -76
pytorch_transformers/tokenization_openai.py                   +38   -72
pytorch_transformers/tokenization_transfo_xl.py               +27   -51
pytorch_transformers/tokenization_utils.py                    +114  -0
pytorch_transformers/tokenization_xlm.py                      +44   -78
pytorch_transformers/tokenization_xlnet.py                    +25   -106
pytorch_transformers/tests/tokenization_openai_test.py

@@ -20,7 +20,7 @@ import json
 import shutil
 import pytest

-from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer

 from .tokenization_tests_commons import create_and_check_tokenizer_commons

@@ -58,14 +58,6 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
-
 if __name__ == '__main__':
     unittest.main()
pytorch_transformers/tests/tokenization_transfo_xl_test.py

@@ -20,7 +20,7 @@ from io import open
 import shutil
 import pytest

-from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer

 from .tokenization_tests_commons import create_and_check_tokenizer_commons

@@ -59,13 +59,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
             tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
             ["HeLLo", "!", "how", "Are", "yoU", "?"])

-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = TransfoXLTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
-
 if __name__ == '__main__':
     unittest.main()
pytorch_transformers/tests/tokenization_utils_test.py (new file, mode 100644)

# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

from pytorch_transformers import PreTrainedTokenizer
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer


class TokenizerUtilsTest(unittest.TestCase):
    def check_tokenizer_from_pretrained(self, tokenizer_class):
        s3_models = list(tokenizer_class.max_model_input_sizes.keys())
        for model_name in s3_models[:1]:
            tokenizer = tokenizer_class.from_pretrained(model_name)
            self.assertIsNotNone(tokenizer)
            self.assertIsInstance(tokenizer, PreTrainedTokenizer)

    def test_pretrained_tokenizers(self):
        self.check_tokenizer_from_pretrained(GPT2Tokenizer)

if __name__ == "__main__":
    unittest.main()
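As an aside, the new test is written so the same check can be pointed at any tokenizer class that fills in the class-level maps. A minimal sketch of such an extension follows; the extra test class below is illustrative only and is not part of this commit:

from pytorch_transformers.tokenization_bert import BertTokenizer

class MoreTokenizerUtilsTest(TokenizerUtilsTest):
    def test_pretrained_bert_tokenizer(self):
        # Reuses check_tokenizer_from_pretrained(): downloads the first hosted
        # vocabulary listed in BertTokenizer.max_model_input_sizes and checks
        # that the result is a PreTrainedTokenizer instance.
        self.check_tokenizer_from_pretrained(BertTokenizer)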
pytorch_transformers/tests/tokenization_xlm_test.py

@@ -20,9 +20,9 @@ import json
 import shutil
 import pytest

-from pytorch_transformers.tokenization_xlm import XLMTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+from pytorch_transformers.tokenization_xlm import XLMTokenizer

 from .tokenization_tests_commons import create_and_check_tokenizer_commons

 class XLMTokenizationTest(unittest.TestCase):

@@ -57,14 +57,6 @@ class XLMTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = XLMTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
-
 if __name__ == '__main__':
     unittest.main()
pytorch_transformers/tests/tokenization_xlnet_test.py

@@ -19,9 +19,7 @@ import unittest
 import shutil
 import pytest

-from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer,
-                                                     PRETRAINED_VOCAB_ARCHIVE_MAP,
-                                                     SPIECE_UNDERLINE)
+from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)

 from .tokenization_tests_commons import create_and_check_tokenizer_commons

@@ -60,14 +58,6 @@ class XLNetTokenizationTest(unittest.TestCase):
                               SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                               u'<unk>', u'.'])

-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = XLNetTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
-
     def test_tokenizer_lower(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
pytorch_transformers/tokenization_bert.py

@@ -23,11 +23,15 @@ import unicodedata
 from io import open

 from .file_utils import cached_path
-from .model_utils import clean_up_tokenization
+from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
+VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
+    {
     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",

@@ -41,8 +45,9 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
     'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
     'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
-}
+    }
+}

-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'bert-base-uncased': 512,
     'bert-large-uncased': 512,
     'bert-base-cased': 512,

@@ -57,7 +62,6 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'bert-large-cased-whole-word-masking-finetuned-squad': 512,
     'bert-base-cased-finetuned-mrpc': 512,
 }
-VOCAB_NAME = 'vocab.txt'

 def load_vocab(vocab_file):
     """Loads a vocabulary file into a dictionary."""

@@ -83,8 +87,11 @@ def whitespace_tokenize(text):
     return tokens

-class BertTokenizer(object):
+class BertTokenizer(PreTrainedTokenizer):
     """Runs end-to-end tokenization: punctuation splitting + wordpiece"""

+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
     def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
                  never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):

@@ -203,7 +210,7 @@ class BertTokenizer(object):
         """Save the tokenizer vocabulary to a directory or file."""
         index = 0
         if os.path.isdir(vocab_path):
-            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
         with open(vocab_file, "w", encoding="utf-8") as writer:
             for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                 if index != token_index:

@@ -215,13 +222,10 @@ class BertTokenizer(object):
         return (vocab_file,)

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+        """ Instantiate a BertTokenizer from pre-trained vocabulary files.
         Download and cache the pre-trained model file if needed.
         """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+        if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
             if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
                 logger.warning("The pre-trained model you are loading is a cased model but you have not set "
                                "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "

@@ -232,40 +236,8 @@ class BertTokenizer(object):
                                "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
                                "but you may want to check this behavior.")
                 kwargs['do_lower_case'] = True
-        else:
-            vocab_file = pretrained_model_name_or_path
-        if os.path.isdir(vocab_file):
-            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download vocabulary.".format(vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        vocab_file))
-            return None
-        if resolved_vocab_file == vocab_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
-        # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
-        return tokenizer
+
+        return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

 class BasicTokenizer(object):
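The user-facing call is unchanged by this refactor. Roughly, loading a BERT tokenizer still looks like the sketch below (the model name and sample sentence are just an illustration), with the cased/uncased warning logic running before the shared _from_pretrained() resolver takes over:

from pytorch_transformers.tokenization_bert import BertTokenizer

# from_pretrained() adjusts do_lower_case for cased models, then delegates to
# PreTrainedTokenizer._from_pretrained(), which resolves and caches vocab.txt.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(u"Hello, how are you?"))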
pytorch_transformers/tokenization_gpt2.py

@@ -23,8 +23,6 @@ import os
 import regex as re
 from io import open

-from .model_utils import clean_up_tokenization
-
 try:
     from functools import lru_cache
 except ImportError:

@@ -33,24 +31,38 @@ except ImportError:
     def lru_cache():
         return lambda func: func

-from .file_utils import cached_path
+from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
-    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
-}
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
-    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
-}
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+VOCAB_FILES_NAMES = {
+    'vocab_file': 'vocab.json',
+    'merges_file': 'merges.txt',
+    'special_tokens_file': 'special_tokens.txt'}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
+    {
+        'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+        'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
+    },
+    'merges_file':
+    {
+        'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+        'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
+    },
+    'special_tokens_file':
+    {
+        'gpt2': None,
+        'gpt2-medium': None,
+    }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'gpt2': 1024,
     'gpt2-medium': 1024,
 }
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
-SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 @lru_cache()
 def bytes_to_unicode():

@@ -87,70 +99,16 @@ def get_pairs(word):
         prev_char = char
     return pairs

-class GPT2Tokenizer(object):
+class GPT2Tokenizer(PreTrainedTokenizer):
     """
     GPT-2 BPE tokenizer. Peculiarities:
         - Byte-level BPE
     """
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a GPT2Tokenizer from a pre-trained model file.
-        Download and cache the pre-trained model file if needed.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
-            special_tokens_file = None
-        else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
-            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
-            if not os.path.exists(special_tokens_file):
-                special_tokens_file = None
-            else:
-                logger.info("loading special tokens file {}".format(special_tokens_file))
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download vocabulary.".format(vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path,
-                        vocab_file, merges_file))
-            return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
-        # Instantiate tokenizer.
-        if special_tokens_file and 'special_tokens' not in kwargs:
-            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
-        else:
-            special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
-        return tokenizer
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
+    def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, errors='replace', max_len=None):
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v: k for k, v in self.encoder.items()}

@@ -165,9 +123,16 @@ class GPT2Tokenizer(object):
         # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

+        all_special_tokens = []
+        if special_tokens_file is not None:
+            special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+            all_special_tokens.extend(special_tokens_to_add)
+        if special_tokens is not None and special_tokens:
+            all_special_tokens.extend(special_tokens)
+
         self.special_tokens = {}
         self.special_tokens_decoder = {}
-        self.set_special_tokens(special_tokens)
+        self.set_special_tokens(all_special_tokens)

     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)

@@ -285,9 +250,9 @@ class GPT2Tokenizer(object):
         if not os.path.isdir(vocab_path):
             logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
             return
-        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
-        merge_file = os.path.join(vocab_path, MERGES_NAME)
-        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+        vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
+        merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
+        special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])

         with open(vocab_file, 'w', encoding='utf-8') as f:
             f.write(json.dumps(self.encoder, ensure_ascii=False))
pytorch_transformers/tokenization_openai.py

@@ -26,23 +26,35 @@ from io import open
 from tqdm import tqdm

 from .file_utils import cached_path
-from .model_utils import clean_up_tokenization
+from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
 from .tokenization_bert import BasicTokenizer

 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
-}
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
-}
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+VOCAB_FILES_NAMES = {
+    'vocab_file': 'vocab.json',
+    'merges_file': 'merges.txt',
+    'special_tokens_file': 'special_tokens.txt'}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
+    {
+        'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
+    },
+    'merges_file':
+    {
+        'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
+    },
+    'special_tokens_file':
+    {
+        'openai-gpt': None,
+    }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'openai-gpt': 512,
 }
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
-SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 def get_pairs(word):
     """

@@ -71,7 +83,7 @@ def text_standardize(text):
     text = re.sub(r'[^\S\n]+', ' ', text)
     return text.strip()

-class OpenAIGPTTokenizer(object):
+class OpenAIGPTTokenizer(PreTrainedTokenizer):
     """
     BPE tokenizer. Peculiarities:
         - lower case all inputs

@@ -79,65 +91,11 @@ class OpenAIGPTTokenizer(object):
         - argument special_tokens and function set_special_tokens:
             can be used to add additional symbols (ex: "__classify__") to a vocabulary.
     """
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
-        Download and cache the pre-trained model file if needed.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
-            special_tokens_file = None
-        else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
-            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
-            if not os.path.exists(special_tokens_file):
-                special_tokens_file = None
-            else:
-                logger.info("loading special tokens file {}".format(special_tokens_file))
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download vocabulary.".format(vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path,
-                        vocab_file, merges_file))
-            return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
-        # Instantiate tokenizer.
-        if special_tokens_file and 'special_tokens' not in kwargs:
-            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
-        else:
-            special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
-        return tokenizer
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
+    def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None):
         try:
             import ftfy
             import spacy

@@ -156,9 +114,17 @@ class OpenAIGPTTokenizer(object):
         merges = [tuple(merge.split()) for merge in merges]
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}

+        all_special_tokens = []
+        if special_tokens_file is not None:
+            special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+            all_special_tokens.extend(special_tokens_to_add)
+        if special_tokens is not None and special_tokens:
+            all_special_tokens.extend(special_tokens)
+
         self.special_tokens = {}
         self.special_tokens_decoder = {}
-        self.set_special_tokens(special_tokens)
+        self.set_special_tokens(all_special_tokens)

     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)

@@ -286,9 +252,9 @@ class OpenAIGPTTokenizer(object):
         if not os.path.isdir(vocab_path):
             logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
             return
-        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
-        merge_file = os.path.join(vocab_path, MERGES_NAME)
-        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+        vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
+        merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
+        special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])

         with open(vocab_file, 'w', encoding='utf-8') as f:
             f.write(json.dumps(self.encoder, ensure_ascii=False))
pytorch_transformers/tokenization_transfo_xl.py

@@ -31,7 +31,7 @@ import torch
 import numpy as np

 from .file_utils import cached_path
-from .model_utils import clean_up_tokenization
+from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

 if sys.version_info[0] == 2:
     import cPickle as pickle

@@ -41,66 +41,35 @@ else:
 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
+VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin'}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'pretrained_vocab_file':
+    {
+        'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
+    }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    'transfo-xl-wt103': 512,
 }
-VOCAB_NAME = 'vocab.bin'

 PRETRAINED_CORPUS_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
 }
 CORPUS_NAME = 'corpus.bin'

-class TransfoXLTokenizer(object):
+class TransfoXLTokenizer(PreTrainedTokenizer):
     """
     Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
     """
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a TransfoXLTokenizer.
-        The TransfoXLTokenizer.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-        else:
-            if os.path.isdir(pretrained_model_name_or_path):
-                vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            else:
-                vocab_file = pretrained_model_name_or_path
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error("Couldn't reach server at '{}' to download vocabulary.".format(vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path,
-                        vocab_file))
-            return None
-        if resolved_vocab_file == vocab_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-
-        # Instantiate tokenizer.
-        tokenizer = cls(*inputs, **kwargs)
-        vocab_dict = torch.load(resolved_vocab_file)
-        for key, value in vocab_dict.items():
-            tokenizer.__dict__[key] = value
-        return tokenizer
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

     def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
-                 delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
+                 delimiter=None, vocab_file=None, pretrained_vocab_file=None,
+                 never_split=("<unk>", "<eos>", "<formula>")):
         self.counter = Counter()
         self.special = special
         self.min_freq = min_freq

@@ -110,6 +79,13 @@ class TransfoXLTokenizer(object):
         self.vocab_file = vocab_file
         self.never_split = never_split

+        if pretrained_vocab_file is not None:
+            # Hack because, honestly this tokenizer was not made to be used
+            # in a library like ours, at all.
+            vocab_dict = torch.load(pretrained_vocab_file)
+            for key, value in vocab_dict.items():
+                self.__dict__[key] = value
+
         if vocab_file is not None:
             self.build_vocab()

@@ -157,7 +133,7 @@ class TransfoXLTokenizer(object):
         """Save the tokenizer vocabulary to a directory or file."""
         index = 0
         if os.path.isdir(vocab_path):
-            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['pretrained_vocab_file'])
         torch.save(self.__dict__, vocab_file)
         return (vocab_file,)

@@ -484,7 +460,7 @@ class TransfoXLCorpus(object):
                     "We assumed '{}' was a path or url but couldn't find files {} "
                     "at this path or url.".format(
                         pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        ', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
                         pretrained_model_name_or_path,
                         corpus_file))
             return None
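For Transformer-XL, the behaviour that used to live in from_pretrained (loading the pickled vocabulary and copying its attributes onto the instance) now happens in __init__ through the new pretrained_vocab_file argument. A rough sketch of the equivalent manual path, assuming vocab.bin has already been downloaded to a hypothetical local path:

from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer

# The resolved 'pretrained_vocab_file' is what _from_pretrained() passes in for
# 'transfo-xl-wt103'; the constructor then torch.load()s it and copies the saved
# attributes onto self.__dict__.
tokenizer = TransfoXLTokenizer(pretrained_vocab_file='/path/to/vocab.bin')  # hypothetical path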
pytorch_transformers/tokenization_utils.py (new file, mode 100644)

# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import sys
import json
import logging
import os
import regex as re
from io import open

try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on python2
    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
    def lru_cache():
        return lambda func: func

from .file_utils import cached_path

logger = logging.getLogger(__name__)


class PreTrainedTokenizer(object):
    """ An abstract class to handle dowloading and loading pretrained tokenizers.
    """
    vocab_files_names = {}
    pretrained_vocab_files_map = {}
    max_model_input_sizes = {}

    @classmethod
    def from_pretrained(cls, *inputs, **kwargs):
        return cls._from_pretrained(*inputs, **kwargs)

    @classmethod
    def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
        Download and cache the vocabulary files if needed.
        """
        s3_models = list(cls.max_model_input_sizes.keys())
        vocab_files = {}
        if pretrained_model_name_or_path in s3_models:
            for file_id, map_list in cls.pretrained_vocab_files_map.items():
                vocab_files[file_id] = map_list[pretrained_model_name_or_path]
        else:
            for file_id, file_name in cls.vocab_files_names.items():
                if os.path.isdir(pretrained_model_name_or_path):
                    full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
                else:
                    full_file_name = pretrained_model_name_or_path
                if not os.path.exists(full_file_name):
                    logger.info("Didn't find file {}. We don't load it.".format(full_file_name))
                    full_file_name = None
                vocab_files[file_id] = full_file_name

        # redirect to the cache, if necessary
        try:
            resolved_vocab_files = {}
            for file_id, file_path in vocab_files.items():
                if file_path is None:
                    resolved_vocab_files[file_id] = None
                else:
                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in s3_models:
                logger.error("Couldn't reach server to download vocabulary.")
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find files {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ', '.join(s3_models),
                        pretrained_model_name_or_path, str(vocab_files.keys())))
            return None

        for file_id, file_path in vocab_files.items():
            if file_path == resolved_vocab_files[file_id]:
                logger.info("loading file {}".format(file_path))
            else:
                logger.info("loading file {} from cache at {}".format(
                    file_path, resolved_vocab_files[file_id]))

        if pretrained_model_name_or_path in cls.max_model_input_sizes:
            # if we're using a pretrained model, ensure the tokenizer
            # wont index sequences longer than the number of positional embeddings
            max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)

        # Instantiate tokenizer.
        tokenizer = cls(*inputs, **resolved_vocab_files, **kwargs)
        return tokenizer


def clean_up_tokenization(out_string):
    out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
    return out_string
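To make the contract of the new base class concrete, here is a minimal sketch of a subclass. ToyTokenizer, its shortcut name and its URL are invented for illustration and are not part of the library:

from pytorch_transformers.tokenization_utils import PreTrainedTokenizer

class ToyTokenizer(PreTrainedTokenizer):
    # file_id -> file name expected inside a saved vocabulary directory
    vocab_files_names = {'vocab_file': 'vocab.txt'}
    # file_id -> {shortcut name -> url}; resolved and cached by _from_pretrained()
    pretrained_vocab_files_map = {
        'vocab_file': {'toy-base': "https://example.com/toy-base-vocab.txt"}}
    # shortcut name -> positional embedding size, clamped into kwargs['max_len']
    max_model_input_sizes = {'toy-base': 512}

    def __init__(self, vocab_file, max_len=None):
        # _from_pretrained() passes each resolved path under its file_id keyword
        self.max_len = max_len if max_len is not None else int(1e12)
        with open(vocab_file, encoding='utf-8') as f:
            self.vocab = [line.rstrip('\n') for line in f]

# ToyTokenizer.from_pretrained('toy-base') would download and cache the vocabulary,
# clamp max_len to 512, and call ToyTokenizer(vocab_file=<cached path>, max_len=512).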
pytorch_transformers/tokenization_xlm.py

@@ -26,30 +26,42 @@ from io import open
 from tqdm import tqdm

 from .file_utils import cached_path
-from .model_utils import clean_up_tokenization
+from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
 from .tokenization_bert import BasicTokenizer

 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
-}
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
-}
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+VOCAB_FILES_NAMES = {
+    'vocab_file': 'vocab.json',
+    'merges_file': 'merges.txt',
+    'special_tokens_file': 'special_tokens.txt'}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
+    {
+        'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
+    },
+    'merges_file':
+    {
+        'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
+    },
+    'special_tokens_file':
+    {
+        'xlm-mlm-en-2048': None,
+    }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'xlm-mlm-en-2048': 512,
 }
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
-SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 INDEX = {
     "bos_index": 0,
     "eos_index": 1,
     "pad_index": 2,
     "unk_index": 3,
     "mask_index": 5
 }

 def get_pairs(word):

@@ -79,7 +91,7 @@ def text_standardize(text):
     text = re.sub(r'[^\S\n]+', ' ', text)
     return text.strip()

-class XLMTokenizer(object):
+class XLMTokenizer(PreTrainedTokenizer):
     """
     BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
         - lower case all inputs

@@ -87,65 +99,11 @@ class XLMTokenizer(object):
         - argument special_tokens and function set_special_tokens:
             can be used to add additional symbols (ex: "__classify__") to a vocabulary.
     """
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
-        Download and cache the pre-trained model file if needed.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
-            special_tokens_file = None
-        else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
-            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
-            if not os.path.exists(special_tokens_file):
-                special_tokens_file = None
-            else:
-                logger.info("loading special tokens file {}".format(special_tokens_file))
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download vocabulary.".format(vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path,
-                        vocab_file, merges_file))
-            return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
-        # Instantiate tokenizer.
-        if special_tokens_file and 'special_tokens' not in kwargs:
-            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
-        else:
-            special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
-        return tokenizer
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
+    def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None):
         try:
             import ftfy
             import spacy

@@ -164,9 +122,17 @@ class XLMTokenizer(object):
         merges = [tuple(merge.split()[:2]) for merge in merges]
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}

+        all_special_tokens = []
+        if special_tokens_file is not None:
+            special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+            all_special_tokens.extend(special_tokens_to_add)
+        if special_tokens is not None and special_tokens:
+            all_special_tokens.extend(special_tokens)
+
         self.special_tokens = {}
         self.special_tokens_decoder = {}
-        self.set_special_tokens(special_tokens)
+        self.set_special_tokens(all_special_tokens)

     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)

@@ -294,9 +260,9 @@ class XLMTokenizer(object):
         if not os.path.isdir(vocab_path):
             logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
             return
-        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
-        merge_file = os.path.join(vocab_path, MERGES_NAME)
-        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+        vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
+        merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
+        special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])

         with open(vocab_file, 'w', encoding='utf-8') as f:
             f.write(json.dumps(self.encoder, ensure_ascii=False))
pytorch_transformers/tokenization_xlnet.py

@@ -27,15 +27,24 @@ import unicodedata
 import six

 from .file_utils import cached_path
-from .model_utils import clean_up_tokenization
+from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
+VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
+    {
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
+    }
 }

+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    'xlnet-large-cased': 512,
+}
+
 VOCAB_NAME = 'spiece.model'
-SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 SPIECE_UNDERLINE = u'▁'

@@ -46,7 +55,7 @@ SEG_ID_CLS = 2
 SEG_ID_SEP = 3
 SEG_ID_PAD = 4

-class XLNetTokenizer(object):
+class XLNetTokenizer(PreTrainedTokenizer):
     """
     SentencePiece based tokenizer. Peculiarities:
         - requires SentencePiece: https://github.com/google/sentencepiece

@@ -63,64 +72,11 @@ class XLNetTokenizer(object):
         "<eod>": 7,
         "<eop>": 8,
     }
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
-        Download and cache the pre-trained model file if needed.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            special_tokens_file = None
-            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
-                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
-                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
-                               "you may want to check this behavior.")
-                kwargs['do_lower_case'] = False
-            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
-                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
-                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
-                               "but you may want to check this behavior.")
-                kwargs['do_lower_case'] = True
-        else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
-            if not os.path.exists(special_tokens_file):
-                special_tokens_file = None
-            else:
-                logger.info("loading special tokens file {}".format(special_tokens_file))
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error("Couldn't reach server at '{}' to download vocabulary.".format(vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {}"
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path,
-                        vocab_file))
-            return None
-        if resolved_vocab_file == vocab_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-        # Instantiate tokenizer.
-        if special_tokens_file and 'special_tokens' not in kwargs:
-            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
-        else:
-            special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, special_tokens=special_tokens, *inputs, **kwargs)
-        return tokenizer
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, special_tokens=None, max_len=None,
+    def __init__(self, vocab_file, max_len=None,
                  do_lower_case=False, remove_space=True, keep_accents=False):
         try:
             import sentencepiece as spm

@@ -136,9 +92,6 @@ class XLNetTokenizer(object):
         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(vocab_file)

-        self.special_tokens = {}
-        self.special_tokens_decoder = {}
-        self.set_special_tokens(special_tokens)

     @property
     def UNK_TOKEN(self):

@@ -181,7 +134,7 @@ class XLNetTokenizer(object):
         return self.special_symbols["<mask>"]

     def __len__(self):
-        return len(self.encoder) + len(self.special_tokens)
+        return len(self.sp_model)

     def __getstate__(self):
         state = self.__dict__.copy()

@@ -198,19 +151,6 @@ class XLNetTokenizer(object):
         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(self.vocab_file)

-    def set_special_tokens(self, special_tokens):
-        """ Add a list of additional tokens to the encoder.
-            The additional tokens are indexed starting from the last index of the
-            current vocabulary in the order of the `special_tokens` list.
-        """
-        if not special_tokens:
-            self.special_tokens = {}
-            self.special_tokens_decoder = {}
-            return
-        self.special_tokens = dict((tok, len(self.sp_model) + i) for i, tok in enumerate(special_tokens))
-        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
-        logger.info("Special tokens: %s", str(self.special_tokens))
-
     def preprocess_text(self, inputs):
         if self.remove_space:
             outputs = ' '.join(inputs.strip().split())

@@ -272,15 +212,9 @@ class XLNetTokenizer(object):
         """ Converts a sequence of tokens into ids using the vocab. """
         ids = []
         if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
-            if tokens in self.special_tokens:
-                return self.special_tokens[tokens]
-            else:
-                return self.sp_model.PieceToId(tokens)
+            return self.sp_model.PieceToId(tokens)
         for token in tokens:
-            if token in self.special_tokens:
-                ids.append(self.special_tokens[token])
-            else:
-                ids.append(self.sp_model.PieceToId(token))
+            ids.append(self.sp_model.PieceToId(token))
         if len(ids) > self.max_len:
             logger.warning(
                 "Token indices sequence length is longer than the specified maximum "

@@ -289,15 +223,11 @@ class XLNetTokenizer(object):
             )
         return ids

-    def convert_ids_to_tokens(self, ids, return_unicode=True, skip_special_tokens=False):
+    def convert_ids_to_tokens(self, ids, return_unicode=True):
         """Converts a sequence of ids in tokens."""
         tokens = []
         for i in ids:
-            if i in self.special_tokens_decoder:
-                if not skip_special_tokens:
-                    tokens.append(self.special_tokens_decoder[i])
-            else:
-                tokens.append(self.sp_model.IdToPiece(i))
+            tokens.append(self.sp_model.IdToPiece(i))
         if six.PY2 and return_unicode:
             ret_pieces = []

@@ -311,9 +241,9 @@ class XLNetTokenizer(object):
     def encode(self, text, sample=False):
         return self.convert_tokens_to_ids(self.tokenize(text, sample=sample))

-    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+    def decode(self, ids, clean_up_tokenization_spaces=True):
         """Converts a sequence of ids in a string."""
-        tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
+        tokens = self.convert_ids_to_tokens(ids)
         out_string = ''.join(tokens)
         if clean_up_tokenization_spaces:
             out_string = out_string.strip().replace('<unk>', '')

@@ -328,18 +258,7 @@ class XLNetTokenizer(object):
             logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
             return
         out_vocab_file = os.path.join(vocab_path, VOCAB_NAME)
-        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)

         copyfile(self.vocab_file, out_vocab_file)
-        index = len(self.sp_model)
-        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
-            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
-                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
-                    index = token_index
-                writer.write(token + u'\n')
-                index += 1

-        return out_vocab_file, special_tokens_file
+        return (out_vocab_file,)