chenpangpang / transformers · Commits

Commit 34ccc8eb authored Apr 21, 2019 by lukovnikov

    Merge remote-tracking branch 'upstream/master'

Parents: fc7693ad, 68a889ee
Changes: 31
Showing 11 changed files with 325 additions and 155 deletions (+325 −155; page 1 of 2)
pytorch_pretrained_bert/tokenization_openai.py      +55   −5
pytorch_pretrained_bert/tokenization_transfo_xl.py  +23   −109
tests/conftest.py                                   +19   −0
tests/modeling_gpt2_test.py                         +20   −1
tests/modeling_openai_test.py                       +20   −1
tests/modeling_test.py                              +20   −0
tests/modeling_transfo_xl_test.py                   +20   −1
tests/tokenization_gpt2_test.py                     +77   −0
tests/tokenization_openai_test.py                   +26   −3
tests/tokenization_test.py                          +21   −1
tests/tokenization_transfo_xl_test.py               +24   −34
pytorch_pretrained_bert/tokenization_openai.py · view file @ 34ccc8eb

@@ -41,6 +41,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'

def get_pairs(word):
    """

@@ -86,9 +87,15 @@ class OpenAIGPTTokenizer(object):
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
+           special_tokens_file = None
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+           special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+           if not os.path.exists(special_tokens_file):
+               special_tokens_file = None
+           else:
+               logger.info("loading special tokens file {}".format(special_tokens_file))
        # redirect to the cache, if necessary
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)

@@ -117,7 +124,11 @@ class OpenAIGPTTokenizer(object):
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
-       tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+       if special_tokens_file and 'special_tokens' not in kwargs:
+           special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+       else:
+           special_tokens = kwargs.pop('special_tokens', [])
+       tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
        return tokenizer

    def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):

@@ -139,6 +150,8 @@ class OpenAIGPTTokenizer(object):
        merges = [tuple(merge.split()) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}
+       self.special_tokens = {}
+       self.special_tokens_decoder = {}
        self.set_special_tokens(special_tokens)

    def __len__(self):

@@ -250,14 +263,51 @@ class OpenAIGPTTokenizer(object):
                tokens.append(self.decoder[i])
        return tokens

-   def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
+   def encode(self, text):
+       return self.convert_tokens_to_ids(self.tokenize(text))
+
+   def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """Converts a sequence of ids in a string."""
        tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
        out_string = ''.join(tokens).replace('</w>', ' ').strip()
        if clean_up_tokenization_spaces:
            out_string = out_string.replace('<unk>', '')
-           out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
-               ).replace(" n't", "n't").replace(" 'm", "'m").replace(" 're", "'re").replace(" do not", " don't"
-               ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m ").replace(" 've", "'ve")
+           out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
+               ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
+               ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
        return out_string

+   def save_vocabulary(self, vocab_path):
+       """Save the tokenizer vocabulary and merge files to a directory."""
+       if not os.path.isdir(vocab_path):
+           logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+           return
+       vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+       merge_file = os.path.join(vocab_path, MERGES_NAME)
+       special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+       with open(vocab_file, 'w', encoding='utf-8') as f:
+           f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+       index = 0
+       with open(merge_file, "w", encoding="utf-8") as writer:
+           writer.write(u'#version: 0.2\n')
+           for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+               if index != token_index:
+                   logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                  " Please check that the tokenizer is not corrupted!".format(merge_file))
+                   index = token_index
+               writer.write(' '.join(bpe_tokens) + u'\n')
+               index += 1
+
+       index = len(self.encoder)
+       with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+           for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
+               if index != token_index:
+                   logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
+                                  " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
+                   index = token_index
+               writer.write(token + u'\n')
+               index += 1
+
+       return vocab_file, merge_file, special_tokens_file
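Note: the user-visible addition in this file is that OpenAIGPTTokenizer can now persist and reload its vocabulary, merges, and special tokens. The following is a minimal round-trip sketch based on the methods added above; the directory path is illustrative, from_pretrained('openai-gpt') downloads the published vocabulary files, and the special tokens chosen here are arbitrary examples.

    import os
    from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer

    tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt', special_tokens=['<bos>', '<eos>'])
    os.makedirs('/tmp/openai_vocab/', exist_ok=True)  # save_vocabulary expects an existing directory
    vocab_file, merge_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path='/tmp/openai_vocab/')
    reloaded = OpenAIGPTTokenizer.from_pretrained('/tmp/openai_vocab/')  # also picks up special_tokens.txt
    assert reloaded.encoder == tokenizer.encoder
    assert reloaded.special_tokens == tokenizer.special_tokens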
pytorch_pretrained_bert/tokenization_transfo_xl.py · view file @ 34ccc8eb

@@ -63,7 +63,10 @@ class TransfoXLTokenizer(object):
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
-           vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+           if os.path.isdir(pretrained_model_name_or_path):
+               vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+           else:
+               vocab_file = pretrained_model_name_or_path
        # redirect to the cache, if necessary
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)

@@ -141,6 +144,14 @@ class TransfoXLTokenizer(object):
        else:
            raise ValueError('No <unkown> token in vocabulary')

+   def save_vocabulary(self, vocab_path):
+       """Save the tokenizer vocabulary to a directory or file."""
+       index = 0
+       if os.path.isdir(vocab_path):
+           vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+       torch.save(self.__dict__, vocab_file)
+       return vocab_file
+
    def build_vocab(self):
        if self.vocab_file:
            print('building vocab from {}'.format(self.vocab_file))

@@ -245,82 +256,24 @@ class TransfoXLTokenizer(object):
    def __len__(self):
        return len(self.idx2sym)

-   def _run_split_on_punc(self, text):
-       """Splits punctuation on a piece of text."""
-       if text in self.never_split:
-           return [text]
-       chars = list(text)
-       i = 0
-       start_new_word = True
-       output = []
-       while i < len(chars):
-           char = chars[i]
-           if _is_punctuation(char):
-               output.append([char])
-               start_new_word = True
-           else:
-               if start_new_word:
-                   output.append([])
-               start_new_word = False
-               output[-1].append(char)
-           i += 1
-       return ["".join(x) for x in output]
-
-   def _run_strip_accents(self, text):
-       """Strips accents from a piece of text."""
-       text = unicodedata.normalize("NFD", text)
-       output = []
-       for char in text:
-           cat = unicodedata.category(char)
-           if cat == "Mn":
-               continue
-           output.append(char)
-       return "".join(output)
-
-   def _clean_text(self, text):
-       """Performs invalid character removal and whitespace cleanup on text."""
-       output = []
-       for char in text:
-           cp = ord(char)
-           if cp == 0 or cp == 0xfffd or _is_control(char):
-               continue
-           if _is_whitespace(char):
-               output.append(" ")
-           else:
-               output.append(char)
-       return "".join(output)
-
-   def whitespace_tokenize(self, text):
-       """Runs basic whitespace cleaning and splitting on a piece of text."""
-       text = text.strip()
-       if not text:
-           return []
-       if self.delimiter == '':
-           tokens = text
-       else:
-           tokens = text.split(self.delimiter)
-       return tokens
-
    def tokenize(self, line, add_eos=False, add_double_eos=False):
-       line = self._clean_text(line)
        line = line.strip()
        # convert to lower case
        if self.lower_case:
            line = line.lower()

-       symbols = self.whitespace_tokenize(line)
-       split_symbols = []
-       for symbol in symbols:
-           if self.lower_case and symbol not in self.never_split:
-               symbol = symbol.lower()
-               symbol = self._run_strip_accents(symbol)
-           split_symbols.extend(self._run_split_on_punc(symbol))
+       # empty delimiter '' will evaluate False
+       if self.delimiter == '':
+           symbols = line
+       else:
+           symbols = line.split(self.delimiter)

        if add_double_eos: # lm1b
-           return ['<S>'] + split_symbols + ['<S>']
+           return ['<S>'] + symbols + ['<S>']
        elif add_eos:
-           return split_symbols + ['<eos>']
+           return symbols + ['<eos>']
        else:
-           return split_symbols
+           return symbols

class LMOrderedIterator(object):

@@ -631,42 +584,3 @@ def get_lm_corpus(datadir, dataset):
        torch.save(corpus, fn)

    return corpus
-
-def _is_whitespace(char):
-   """Checks whether `chars` is a whitespace character."""
-   # \t, \n, and \r are technically contorl characters but we treat them
-   # as whitespace since they are generally considered as such.
-   if char == " " or char == "\t" or char == "\n" or char == "\r":
-       return True
-   cat = unicodedata.category(char)
-   if cat == "Zs":
-       return True
-   return False
-
-def _is_control(char):
-   """Checks whether `chars` is a control character."""
-   # These are technically control characters but we count them as whitespace
-   # characters.
-   if char == "\t" or char == "\n" or char == "\r":
-       return False
-   cat = unicodedata.category(char)
-   if cat.startswith("C"):
-       return True
-   return False
-
-def _is_punctuation(char):
-   """Checks whether `chars` is a punctuation character."""
-   cp = ord(char)
-   # We treat all non-letter/number ASCII as punctuation.
-   # Characters such as "^", "$", and "`" are not in the Unicode
-   # Punctuation class but we treat them as punctuation anyways, for
-   # consistency.
-   if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
-           (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
-       return True
-   cat = unicodedata.category(char)
-   if cat.startswith("P"):
-       return True
-   return False
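Note: as with the OpenAI GPT tokenizer, the practical effect is a save/reload round trip, except that TransfoXLTokenizer pickles its whole state with torch.save and returns a single file. A hedged sketch under those assumptions; the paths are illustrative and 'transfo-xl-wt103' is the published checkpoint key (downloading it is slow).

    import os
    from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer

    tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
    os.makedirs('/tmp/transfo_xl_vocab/', exist_ok=True)
    vocab_file = tokenizer.save_vocabulary(vocab_path='/tmp/transfo_xl_vocab/')  # torch.save of self.__dict__
    reloaded = TransfoXLTokenizer.from_pretrained('/tmp/transfo_xl_vocab/')
    assert reloaded.idx2sym == tokenizer.idx2sym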
tests/conftest.py · new file (0 → 100644) · view @ 34ccc8eb

# content of conftest.py

import pytest


def pytest_addoption(parser):
    parser.addoption("--runslow", action="store_true",
                     default=False, help="run slow tests")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--runslow"):
        # --runslow given in cli: do not skip slow tests
        return
    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
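Note: this hook only affects tests that carry the slow marker, which is how the updated test files below opt in. A minimal usage sketch (the test name here is illustrative, not from this commit):

    import pytest

    @pytest.mark.slow
    def test_something_network_heavy():
        ...  # skipped by default; runs only when pytest is invoked with --runslow

    # pytest tests/            -> slow tests are skipped
    # pytest tests/ --runslow  -> slow tests run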
tests/modeling_gpt2_test.py · view file @ 34ccc8eb

@@ -16,15 +16,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
                                     GPT2LMHeadModel, GPT2DoubleHeadsModel)
from pytorch_pretrained_bert.modeling_gpt2 import PRETRAINED_MODEL_ARCHIVE_MAP


class GPT2ModelTest(unittest.TestCase):
    class GPT2ModelTester(object):

@@ -176,6 +179,22 @@ class GPT2ModelTest(unittest.TestCase):
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["n_embd"], 37)

+   def test_config_to_json_file(self):
+       config_first = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
+       json_file_path = "/tmp/config.json"
+       config_first.to_json_file(json_file_path)
+       config_second = GPT2Config.from_json_file(json_file_path)
+       os.remove(json_file_path)
+       self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
+   @pytest.mark.slow
+   def test_model_from_pretrained(self):
+       cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+       for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+           model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
+           shutil.rmtree(cache_dir)
+           self.assertIsNotNone(model)
+
    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()
        output_result = tester.create_gpt2_model(*config_and_inputs)
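Note: the new test_config_to_json_file tests added in the modeling test files boil down to the round trip below; the values mirror the test and the path is illustrative.

    from pytorch_pretrained_bert import GPT2Config

    config_first = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
    config_first.to_json_file("/tmp/config.json")                   # serialize the config to JSON
    config_second = GPT2Config.from_json_file("/tmp/config.json")   # rebuild it from the file
    assert config_second.to_dict() == config_first.to_dict()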
tests/modeling_openai_test.py · view file @ 34ccc8eb

@@ -16,15 +16,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
                                     OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
from pytorch_pretrained_bert.modeling_openai import PRETRAINED_MODEL_ARCHIVE_MAP


class OpenAIGPTModelTest(unittest.TestCase):
    class OpenAIGPTModelTester(object):

@@ -188,6 +191,22 @@ class OpenAIGPTModelTest(unittest.TestCase):
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["n_embd"], 37)

+   def test_config_to_json_file(self):
+       config_first = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37)
+       json_file_path = "/tmp/config.json"
+       config_first.to_json_file(json_file_path)
+       config_second = OpenAIGPTConfig.from_json_file(json_file_path)
+       os.remove(json_file_path)
+       self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
+   @pytest.mark.slow
+   def test_model_from_pretrained(self):
+       cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+       for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+           model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
+           shutil.rmtree(cache_dir)
+           self.assertIsNotNone(model)
+
    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()
        output_result = tester.create_openai_model(*config_and_inputs)
tests/modeling_test.py · view file @ 34ccc8eb

@@ -16,9 +16,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

@@ -26,6 +29,7 @@ from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                     BertForNextSentencePrediction, BertForPreTraining,
                                     BertForQuestionAnswering, BertForSequenceClassification,
                                     BertForTokenClassification)
+from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP


class BertModelTest(unittest.TestCase):

@@ -251,6 +255,22 @@ class BertModelTest(unittest.TestCase):
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["hidden_size"], 37)

+   def test_config_to_json_file(self):
+       config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
+       json_file_path = "/tmp/config.json"
+       config_first.to_json_file(json_file_path)
+       config_second = BertConfig.from_json_file(json_file_path)
+       os.remove(json_file_path)
+       self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
+   @pytest.mark.slow
+   def test_model_from_pretrained(self):
+       cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+       for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+           model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
+           shutil.rmtree(cache_dir)
+           self.assertIsNotNone(model)
+
    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()
        output_result = tester.create_bert_model(*config_and_inputs)
tests/modeling_transfo_xl_test.py · view file @ 34ccc8eb

@@ -16,14 +16,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP


class TransfoXLModelTest(unittest.TestCase):
    class TransfoXLModelTester(object):

@@ -186,6 +189,22 @@ class TransfoXLModelTest(unittest.TestCase):
        self.assertEqual(obj["n_token"], 96)
        self.assertEqual(obj["d_embed"], 37)

+   def test_config_to_json_file(self):
+       config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
+       json_file_path = "/tmp/config.json"
+       config_first.to_json_file(json_file_path)
+       config_second = TransfoXLConfig.from_json_file(json_file_path)
+       os.remove(json_file_path)
+       self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
+   @pytest.mark.slow
+   def test_model_from_pretrained(self):
+       cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+       for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+           model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
+           shutil.rmtree(cache_dir)
+           self.assertIsNotNone(model)
+
    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()
tests/tokenization_gpt2_test.py · new file (0 → 100644) · view @ 34ccc8eb

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest
import json
import shutil
import pytest

from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP


class GPT2TokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "lo", "low", "er",
                 "low", "lowest", "newer", "wider"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
            fp.write(json.dumps(vocab_tokens))
            vocab_file = fp.name
        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
            fp.write("\n".join(merges))
            merges_file = fp.name

        tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
        os.remove(vocab_file)
        os.remove(merges_file)

        text = "lower"
        bpe_tokens = ["low", "er"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + ["<unk>"]
        input_bpe_tokens = [13, 12, 16]
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/")
        os.remove(vocab_file)
        os.remove(merges_file)
        os.remove(special_tokens_file)

        self.assertListEqual(
            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])

    # @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
            tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(tokenizer)


if __name__ == '__main__':
    unittest.main()
tests/tokenization_openai_test.py · view file @ 34ccc8eb

@@ -17,8 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
import json
+import shutil
+import pytest

-from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer
+from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP


class OpenAIGPTTokenizationTest(unittest.TestCase):

@@ -32,13 +34,13 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-           json.dump(vocab_tokens, fp)
+           fp.write(json.dumps(vocab_tokens))
            vocab_file = fp.name
        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
            fp.write("\n".join(merges))
            merges_file = fp.name

-       tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>"])
+       tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
        os.remove(vocab_file)
        os.remove(merges_file)

@@ -52,5 +54,26 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

+       vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+       tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/")
+       os.remove(vocab_file)
+       os.remove(merges_file)
+       os.remove(special_tokens_file)
+
+       self.assertListEqual(
+           [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
+            tokenizer.special_tokens, tokenizer.special_tokens_decoder],
+           [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
+            tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
+
+   @pytest.mark.slow
+   def test_tokenizer_from_pretrained(self):
+       cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+       for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
+           tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+           shutil.rmtree(cache_dir)
+           self.assertIsNotNone(tokenizer)
+

if __name__ == '__main__':
    unittest.main()
tests/tokenization_test.py · view file @ 34ccc8eb

@@ -17,12 +17,14 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
from io import open
+import shutil
+import pytest

from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                  BertTokenizer,
                                                  WordpieceTokenizer,
                                                  _is_control, _is_punctuation,
-                                                 _is_whitespace)
+                                                 _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)


class TokenizationTest(unittest.TestCase):

@@ -46,6 +48,24 @@ class TokenizationTest(unittest.TestCase):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

+       vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+       tokenizer.from_pretrained(vocab_file)
+       os.remove(vocab_file)
+
+       tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+       self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+
+       self.assertListEqual(
+           tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+
+   @pytest.mark.slow
+   def test_tokenizer_from_pretrained(self):
+       cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+       for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
+           tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+           shutil.rmtree(cache_dir)
+           self.assertIsNotNone(tokenizer)
+
    def test_chinese(self):
        tokenizer = BasicTokenizer()
tests/tokenization_transfo_xl_test.py · view file @ 34ccc8eb

@@ -17,10 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
from io import open
+import shutil
+import pytest

-from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
-                                                             _is_control, _is_punctuation,
-                                                             _is_whitespace)
+from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP


class TransfoXLTokenizationTest(unittest.TestCase):

@@ -37,54 +37,44 @@ class TransfoXLTokenizationTest(unittest.TestCase):
        tokenizer.build_vocab()
        os.remove(vocab_file)

-       tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
+       tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

+       vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+       tokenizer.from_pretrained(vocab_file)
+       os.remove(vocab_file)
+
+       tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
+       self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
+       self.assertListEqual(
+           tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
+
    def test_full_tokenizer_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=True)

        self.assertListEqual(
-           tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
+           tokenizer.tokenize(u" \tHeLLo ! how  \n Are yoU ?  "),
            ["hello", "!", "how", "are", "you", "?"])
-       self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

    def test_full_tokenizer_no_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=False)

        self.assertListEqual(
-           tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
+           tokenizer.tokenize(u" \tHeLLo ! how  \n Are yoU ?  "),
            ["HeLLo", "!", "how", "Are", "yoU", "?"])

-   def test_is_whitespace(self):
-       self.assertTrue(_is_whitespace(u" "))
-       self.assertTrue(_is_whitespace(u"\t"))
-       self.assertTrue(_is_whitespace(u"\r"))
-       self.assertTrue(_is_whitespace(u"\n"))
-       self.assertTrue(_is_whitespace(u"\u00A0"))
-
-       self.assertFalse(_is_whitespace(u"A"))
-       self.assertFalse(_is_whitespace(u"-"))
-
-   def test_is_control(self):
-       self.assertTrue(_is_control(u"\u0005"))
-
-       self.assertFalse(_is_control(u"A"))
-       self.assertFalse(_is_control(u" "))
-       self.assertFalse(_is_control(u"\t"))
-       self.assertFalse(_is_control(u"\r"))
-
-   def test_is_punctuation(self):
-       self.assertTrue(_is_punctuation(u"-"))
-       self.assertTrue(_is_punctuation(u"$"))
-       self.assertTrue(_is_punctuation(u"`"))
-       self.assertTrue(_is_punctuation(u"."))
-
-       self.assertFalse(_is_punctuation(u"A"))
-       self.assertFalse(_is_punctuation(u" "))
-
+   @pytest.mark.slow
+   def test_tokenizer_from_pretrained(self):
+       cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+       for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
+           tokenizer = TransfoXLTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+           shutil.rmtree(cache_dir)
+           self.assertIsNotNone(tokenizer)
+

if __name__ == '__main__':
    unittest.main()