chenpangpang / transformers / Commits

Commit 175fce0a (unverified), authored Jul 05, 2019 by Thomas Wolf, committed via GitHub on Jul 05, 2019

Merge pull request #758 from huggingface/doc

Release 0.7 - Add tokenizer API + tests

Parents: cf86d23e, e75c3f70
Changes: 16 changed files, with 152 additions and 107 deletions (+152, -107)
docs/index.rst                                                   +2    -0
pytorch_pretrained_bert/model_utils.py                           +6    -0
pytorch_pretrained_bert/modeling_bert.py                         +1    -0
pytorch_pretrained_bert/tests/tokenization_bert_test.py          +4   -13
pytorch_pretrained_bert/tests/tokenization_gpt2_test.py          +3   -12
pytorch_pretrained_bert/tests/tokenization_openai_test.py        +4   -12
pytorch_pretrained_bert/tests/tokenization_tests_commons.py     +81    -0
pytorch_pretrained_bert/tests/tokenization_transfo_xl_test.py    +3   -12
pytorch_pretrained_bert/tests/tokenization_xlm_test.py           +3   -12
pytorch_pretrained_bert/tests/tokenization_xlnet_test.py         +5   -29
pytorch_pretrained_bert/tokenization_bert.py                    +15    -1
pytorch_pretrained_bert/tokenization_gpt2.py                     +3    -3
pytorch_pretrained_bert/tokenization_openai.py                   +2    -3
pytorch_pretrained_bert/tokenization_transfo_xl.py              +16    -4
pytorch_pretrained_bert/tokenization_xlm.py                      +2    -3
pytorch_pretrained_bert/tokenization_xlnet.py                    +2    -3
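The heart of the release is a uniform tokenizer API shared by every tokenizer class: encode(text) (tokenize plus convert_tokens_to_ids), decode(ids, clean_up_tokenization_spaces=True), and save_vocabulary(vocab_path) returning the file(s) it wrote, so that from_pretrained() can reload from the same directory. A minimal sketch of the round trip the new common tests exercise, assuming a WordPiece vocab file at a hypothetical path:

from pytorch_pretrained_bert.tokenization_bert import BertTokenizer

# Hypothetical vocab file; the shared tests below build a tiny one on the fly.
tokenizer = BertTokenizer("/tmp/bert_vocab.txt")

ids = tokenizer.encode(u"UNwanted, running")                   # tokenize + convert_tokens_to_ids
output_files = tokenizer.save_vocabulary(vocab_path="/tmp/")   # tuple of files written to disk
reloaded = BertTokenizer.from_pretrained("/tmp/")

# This equality is what create_and_check_save_and_load_tokenizer asserts.
assert reloaded.encode(u"UNwanted, running") == ids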
docs/index.rst (new file, 0 → 100644)

Home
====
pytorch_pretrained_bert/model_utils.py

@@ -598,3 +598,9 @@ def prune_layer(layer, index, dim=None):
        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
    else:
        raise ValueError("Can't prune layer of class {}".format(layer.__class__))

def clean_up_tokenization(out_string):
    out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
    return out_string
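clean_up_tokenization reverses the extra spaces that naive whitespace detokenization leaves around punctuation and English contractions; the per-tokenizer decode() methods changed below all delegate to it. An illustrative call (the input string is chosen for this example):

from pytorch_pretrained_bert.model_utils import clean_up_tokenization

# Spaces before punctuation and split contractions are merged back together.
print(clean_up_tokenization("he is n't here , is he ?"))
# -> he isn't here, is he?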
pytorch_pretrained_bert/modeling_bert.py

@@ -48,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
...
pytorch_pretrained_bert/tests/tokenization_bert_test.py

@@ -26,6 +26,7 @@ from pytorch_pretrained_bert.tokenization_bert import (BasicTokenizer,
                                                        _is_control, _is_punctuation,
                                                        _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)
from .tokenization_tests_commons import create_and_check_tokenizer_commons

class TokenizationTest(unittest.TestCase):
...
@@ -36,28 +37,18 @@ class TokenizationTest(unittest.TestCase):
        ]
        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
            vocab_file = vocab_writer.name

        create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)

        tokenizer = BertTokenizer(vocab_file)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer = tokenizer.from_pretrained(vocab_file)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
...
pytorch_pretrained_bert/tests/tokenization_gpt2_test.py

@@ -22,6 +22,7 @@ import pytest
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_tests_commons import create_and_check_tokenizer_commons

class GPT2TokenizationTest(unittest.TestCase):
...
@@ -39,10 +40,9 @@ class GPT2TokenizationTest(unittest.TestCase):
            fp.write("\n".join(merges))
            merges_file = fp.name

        tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
        os.remove(vocab_file)
        os.remove(merges_file)

        create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])

        tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])

        text = "lower"
        bpe_tokens = ["low", "er"]
        tokens = tokenizer.tokenize(text)
...
@@ -53,17 +53,8 @@ class GPT2TokenizationTest(unittest.TestCase):
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/")
        os.remove(vocab_file)
        os.remove(merges_file)
        os.remove(special_tokens_file)

        self.assertListEqual(
            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])

    # @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
...
pytorch_pretrained_bert/tests/tokenization_openai_test.py

@@ -22,6 +22,8 @@ import pytest
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_tests_commons import create_and_check_tokenizer_commons

class OpenAIGPTTokenizationTest(unittest.TestCase):
...
@@ -40,6 +42,8 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
            fp.write("\n".join(merges))
            merges_file = fp.name

        create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])

        tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
        os.remove(vocab_file)
        os.remove(merges_file)
...
@@ -54,18 +58,6 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/")
        os.remove(vocab_file)
        os.remove(merges_file)
        os.remove(special_tokens_file)

        self.assertListEqual(
            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])

    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
...
pytorch_pretrained_bert/tests/tokenization_tests_commons.py (new file, 0 → 100644)

# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from io import open

if sys.version_info[0] == 3:
    unicode = str

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    tokenizer = tokenizer_class(*inputs, **kwargs)

    before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")

    vocab_path = "/tmp/"
    output_files = tokenizer.save_vocabulary(vocab_path=vocab_path)
    tokenizer = tokenizer.from_pretrained(vocab_path)
    for f in output_files:
        os.remove(f)

    after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
    tester.assertListEqual(before_tokens, after_tokens)


def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    tokenizer = tokenizer_class(*inputs, **kwargs)

    text = "Munich and Berlin are nice cities"
    filename = u"/tmp/tokenizer.bin"

    subwords = tokenizer.tokenize(text)

    pickle.dump(tokenizer, open(filename, "wb"))
    tokenizer_new = pickle.load(open(filename, "rb"))

    subwords_loaded = tokenizer_new.tokenize(text)

    tester.assertListEqual(subwords, subwords_loaded)


def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    tokenizer = tokenizer_class(*inputs, **kwargs)

    text = u"He is very happy, UNwant\u00E9d,running"
    tokens = tokenizer.tokenize(text)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    ids_2 = tokenizer.encode(text)
    tester.assertListEqual(ids, ids_2)

    tokens_2 = tokenizer.convert_ids_to_tokens(ids)
    text_2 = tokenizer.decode(ids)

    tester.assertNotEqual(len(tokens_2), 0)
    tester.assertIsInstance(text_2, (str, unicode))


def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
    create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
    create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
    create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
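Each per-tokenizer test file now writes its fixture files and hands them, together with any constructor keyword arguments, to create_and_check_tokenizer_commons, which instantiates the tokenizer itself for each check. A minimal sketch of such a test, modeled on the BERT test above (the tiny vocab and the /tmp path are illustrative, and the file is assumed to live next to the other test modules):

import os
import unittest
from io import open

from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
from .tokenization_tests_commons import create_and_check_tokenizer_commons

class ToyBertTokenizationTest(unittest.TestCase):

    def test_commons(self):
        # Tiny illustrative WordPiece vocab, one token per line.
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "un", "##want", "##ed", ",", "runn", "##ing"]
        with open("/tmp/toy_vocab.txt", "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join(x + "\n" for x in vocab_tokens))
            vocab_file = vocab_writer.name

        # One call runs the encode/decode, save/reload and pickle checks.
        create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
        os.remove(vocab_file)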
pytorch_pretrained_bert/tests/tokenization_transfo_xl_test.py

@@ -22,6 +22,7 @@ import pytest
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_tests_commons import create_and_check_tokenizer_commons

class TransfoXLTokenizationTest(unittest.TestCase):
...
@@ -33,18 +34,9 @@ class TransfoXLTokenizationTest(unittest.TestCase):
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
            vocab_file = vocab_writer.name

        tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
        tokenizer.build_vocab()
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

        create_and_check_tokenizer_commons(self, TransfoXLTokenizer, vocab_file=vocab_file, lower_case=True)

        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer = tokenizer.from_pretrained(vocab_file)
        tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
...
@@ -53,7 +45,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

    def test_full_tokenizer_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=True)
...
pytorch_pretrained_bert/tests/tokenization_xlm_test.py

@@ -22,6 +22,7 @@ import pytest
from pytorch_pretrained_bert.tokenization_xlm import XLMTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_tests_commons import create_and_check_tokenizer_commons

class XLMTokenizationTest(unittest.TestCase):
...
@@ -40,6 +41,8 @@ class XLMTokenizationTest(unittest.TestCase):
            fp.write("\n".join(merges))
            merges_file = fp.name

        create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])

        tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
        os.remove(vocab_file)
        os.remove(merges_file)
...
@@ -54,18 +57,6 @@ class XLMTokenizationTest(unittest.TestCase):
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer_2 = XLMTokenizer.from_pretrained("/tmp/")
        os.remove(vocab_file)
        os.remove(merges_file)
        os.remove(special_tokens_file)

        self.assertListEqual(
            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])

    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
...
pytorch_pretrained_bert/tests/tokenization_xlnet_test.py

@@ -15,28 +15,25 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import sys
import unittest
from io import open
import shutil
import pytest

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP, SPIECE_UNDERLINE)
from .tokenization_tests_commons import create_and_check_tokenizer_commons

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fixtures/test_sentencepiece.model')

class XLNetTokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB)

        create_and_check_tokenizer_commons(self, XLNetTokenizer, SAMPLE_VOCAB)

        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize(u'This is a test')
        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
...
@@ -44,11 +41,6 @@ class XLNetTokenizationTest(unittest.TestCase):
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

        vocab_path = u"/tmp/"
        vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
        tokenizer = tokenizer.from_pretrained(vocab_path, keep_accents=True)

        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
...
@@ -68,22 +60,6 @@ class XLNetTokenizationTest(unittest.TestCase):
                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'<unk>', u'.'])

        text = "Munich and Berlin are nice cities"
        filename = u"/tmp/tokenizer.bin"

        subwords = tokenizer.tokenize(text)
        pickle.dump(tokenizer, open(filename, "wb"))
        tokenizer_new = pickle.load(open(filename, "rb"))
        subwords_loaded = tokenizer_new.tokenize(text)

        self.assertListEqual(subwords, subwords_loaded)

        os.remove(filename)
        os.remove(vocab_file)
        os.remove(special_tokens_file)

    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
...
pytorch_pretrained_bert/tokenization_bert.py

@@ -23,6 +23,7 @@ import unicodedata
from io import open

from .file_utils import cached_path
from .model_utils import clean_up_tokenization

logger = logging.getLogger(__name__)
...
@@ -185,6 +186,19 @@ class BertTokenizer(object):
            tokens.append(self.ids_to_tokens[i])
        return tokens

    def encode(self, text):
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, token_ids, clean_up_tokenization_spaces=True):
        """Converts a sequence of ids in a string."""
        tokens = self.convert_ids_to_tokens(token_ids)
        out_string = ''.join(tokens).replace(' ##', '').strip()
        if clean_up_tokenization_spaces:
            for special_tok in (self.UNK_TOKEN, self.SEP_TOKEN, self.PAD_TOKEN,
                                self.CLS_TOKEN, self.MASK_TOKEN):
                out_string = out_string.replace(special_tok, '')
            out_string = clean_up_tokenization(out_string)
        return out_string

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0
...
@@ -198,7 +212,7 @@ class BertTokenizer(object):
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return vocab_file
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
...
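With encode and decode, BertTokenizer now offers a symmetric string-to-ids interface: encode is tokenize followed by convert_tokens_to_ids, and decode converts ids back to tokens, drops the special tokens, and (by default) runs clean_up_tokenization on the result. A hedged sketch, assuming a hypothetical WordPiece vocab file on disk:

from pytorch_pretrained_bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer("/tmp/bert_vocab.txt")   # hypothetical vocab file

ids = tokenizer.encode(u"UNwanted, running")
text = tokenizer.decode(ids)                                      # special tokens removed, clean_up_tokenization applied
raw = tokenizer.decode(ids, clean_up_tokenization_spaces=False)   # skip the cleanup step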
pytorch_pretrained_bert/tokenization_gpt2.py

@@ -23,6 +23,8 @@ import os
import regex as re
from io import open

from .model_utils import clean_up_tokenization

try:
    from functools import lru_cache
except ImportError:
...
@@ -275,9 +277,7 @@ class GPT2Tokenizer(object):
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        if clean_up_tokenization_spaces:
            text = text.replace('<unk>', '')
            text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
            text = clean_up_tokenization(text)
        return text

    def save_vocabulary(self, vocab_path):
...
pytorch_pretrained_bert/tokenization_openai.py

@@ -26,6 +26,7 @@ from io import open
from tqdm import tqdm

from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_bert import BasicTokenizer

logger = logging.getLogger(__name__)
...
@@ -277,9 +278,7 @@ class OpenAIGPTTokenizer(object):
        out_string = ''.join(tokens).replace('</w>', ' ').strip()
        if clean_up_tokenization_spaces:
            out_string = out_string.replace('<unk>', '')
            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
            out_string = clean_up_tokenization(out_string)
        return out_string

    def save_vocabulary(self, vocab_path):
...
pytorch_pretrained_bert/tokenization_transfo_xl.py

@@ -31,6 +31,7 @@ import torch
import numpy as np

from .file_utils import cached_path
from .model_utils import clean_up_tokenization

if sys.version_info[0] == 2:
    import cPickle as pickle
...
@@ -109,6 +110,9 @@ class TransfoXLTokenizer(object):
        self.vocab_file = vocab_file
        self.never_split = never_split

        if vocab_file is not None:
            self.build_vocab()

    def count_file(self, path, verbose=False, add_eos=False):
        if verbose: print('counting file {} ...'.format(path))
        assert os.path.exists(path)
...
@@ -155,7 +159,7 @@ class TransfoXLTokenizer(object):
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        torch.save(self.__dict__, vocab_file)
        return vocab_file
        return (vocab_file,)

    def build_vocab(self):
        if self.vocab_file:
...
@@ -251,12 +255,20 @@ class TransfoXLTokenizer(object):
    def convert_to_tensor(self, symbols):
        return torch.LongTensor(self.convert_tokens_to_ids(symbols))

    def decode(self, indices, exclude=None):
    def encode(self, text):
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, indices, exclude=None, clean_up_tokenization_spaces=True):
        """Converts a sequence of indices in a string."""
        if exclude is None:
            return ' '.join([self.get_sym(idx) for idx in indices])
            out_string = ' '.join([self.get_sym(idx) for idx in indices])
        else:
            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
            out_string = ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
        if clean_up_tokenization_spaces:
            out_string = clean_up_tokenization(out_string)
        return out_string

    def __len__(self):
        return len(self.idx2sym)
...
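Two behavioral changes for TransfoXLTokenizer are worth calling out: passing vocab_file to the constructor now builds the vocabulary automatically, and decode() joins symbols with spaces and then applies clean_up_tokenization. A hedged sketch, assuming a hypothetical word-level vocab file that contains the symbols used below (including '<unk>'):

from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer

# Hypothetical vocab file with one symbol per line.
tokenizer = TransfoXLTokenizer(vocab_file="/tmp/transfo_xl_vocab.txt", lower_case=True)
# No explicit tokenizer.build_vocab() call needed any more: __init__ runs it when vocab_file is given.

ids = tokenizer.encode(u"<unk> unwanted , running")   # new encode() helper
print(tokenizer.decode(ids))                          # e.g. "<unk> unwanted, running" (space before ',' cleaned up)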
pytorch_pretrained_bert/tokenization_xlm.py

@@ -26,6 +26,7 @@ from io import open
from tqdm import tqdm

from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_bert import BasicTokenizer

logger = logging.getLogger(__name__)
...
@@ -285,9 +286,7 @@ class XLMTokenizer(object):
        out_string = ''.join(tokens).replace('</w>', ' ').strip()
        if clean_up_tokenization_spaces:
            out_string = out_string.replace('<unk>', '')
            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
            out_string = clean_up_tokenization(out_string)
        return out_string

    def save_vocabulary(self, vocab_path):
...
pytorch_pretrained_bert/tokenization_xlnet.py

@@ -27,6 +27,7 @@ import unicodedata
import six

from .file_utils import cached_path
from .model_utils import clean_up_tokenization

logger = logging.getLogger(__name__)
...
@@ -316,9 +317,7 @@ class XLNetTokenizer(object):
        out_string = ''.join(tokens)
        if clean_up_tokenization_spaces:
            out_string = out_string.strip().replace('<unk>', '')
            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
            out_string = clean_up_tokenization(out_string)
        return out_string

    def save_vocabulary(self, vocab_path):
...