chenpangpang / transformers · commit e8568a3b

Commit e8568a3b, authored Apr 15, 2019 by thomwolf
Commit message: fixing tests
Parent: 870b734b
Changes: 4 files
Showing 4 changed files with 51 additions and 14 deletions.
  pytorch_pretrained_bert/tokenization_gpt2.py    +23  -4
  pytorch_pretrained_bert/tokenization_openai.py  +23  -4
  tests/tokenization_openai_test.py                +1  -1
  tests/tokenization_transfo_xl_test.py            +4  -5
pytorch_pretrained_bert/tokenization_gpt2.py  (+23 -4, view file @ e8568a3b)

@@ -45,6 +45,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
 }
 VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 @lru_cache()
 def bytes_to_unicode():
...
@@ -97,6 +98,11 @@ class GPT2Tokenizer(object):
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
             merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+            if not os.path.exists(special_tokens_file):
+                special_tokens_file = None
+            else:
+                logger.info("loading special tokens file {}".format(special_tokens_file))
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
...
@@ -125,7 +131,11 @@ class GPT2Tokenizer(object):
             max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        if special_tokens_file and 'special_tokens' not in kwargs:
+            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+        else:
+            special_tokens = kwargs.pop('special_tokens', [])
+        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
         return tokenizer

     def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
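For context, the loading branch above implies a plain-text layout for special_tokens.txt: one token per line, each line newline-terminated, read back with read().split('\n')[:-1]. A minimal stand-alone sketch of that round trip (the directory and token values here are illustrative, not taken from the commit):

import os
import tempfile

# Hypothetical special tokens, one per line, mirroring how the diff reads the file.
special_tokens = ['<bos>', '<del>', '<eos>']

tmp_dir = tempfile.mkdtemp()
special_tokens_file = os.path.join(tmp_dir, 'special_tokens.txt')

# Write: one token per line, each followed by '\n' (the same shape save_vocabulary
# produces further down in this diff).
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
    for token in special_tokens:
        writer.write(token + u'\n')

# Read: the same expression used in from_pretrained above; the empty string left
# by the final newline is dropped by the [:-1] slice.
loaded = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
assert loaded == special_tokens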
...
@@ -194,7 +204,11 @@ class GPT2Tokenizer(object):
             return
         vocab_file = os.path.join(vocab_path, VOCAB_NAME)
         merge_file = os.path.join(vocab_path, MERGES_NAME)
-        json.dump(self.encoder, vocab_file)
+        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
+
         index = 0
         with open(merge_file, "w", encoding="utf-8") as writer:
             writer.write(u'#version: 0.2\n')
...
@@ -203,9 +217,14 @@ class GPT2Tokenizer(object):
                 logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                " Please check that the tokenizer is not corrupted!".format(merge_file))
                 index = token_index
-            writer.write(bpe_tokens + u'\n')
+            writer.write(' '.join(bpe_tokens) + u'\n')
             index += 1
-        return vocab_file, merge_file
+
+        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+            for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
+                writer.write(token + u'\n')
+
+        return vocab_file, merge_file, special_tokens_file

     def encode(self, text):
         bpe_tokens = []
...
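The save_vocabulary hunk above also fixes how BPE merges are written: each key of bpe_ranks is a pair of strings, so the old tuple + u'\n' concatenation could not produce a valid merges.txt line, while ' '.join(bpe_tokens) yields the expected "first second" format. A small illustration with a made-up bpe_ranks table (not the real GPT-2 merge list):

# Toy stand-in for self.bpe_ranks: (first, second) pairs mapped to merge priority.
bpe_ranks = {('l', 'o'): 0, ('lo', 'w'): 1, ('e', 'r</w>'): 2}

lines = []
for bpe_tokens, token_index in sorted(bpe_ranks.items(), key=lambda kv: kv[1]):
    # Old code: bpe_tokens + u'\n'  -> fails, since bpe_tokens is a tuple, not a string.
    # New code: join the pair with a space, matching the merges.txt line format.
    lines.append(' '.join(bpe_tokens) + u'\n')

assert lines == ['l o\n', 'lo w\n', 'e r</w>\n']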
pytorch_pretrained_bert/tokenization_openai.py  (+23 -4, view file @ e8568a3b)

@@ -41,6 +41,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
 }
 VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 def get_pairs(word):
     """
...
@@ -89,6 +90,11 @@ class OpenAIGPTTokenizer(object):
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
             merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+            if not os.path.exists(special_tokens_file):
+                special_tokens_file = None
+            else:
+                logger.info("loading special tokens file {}".format(special_tokens_file))
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
...
@@ -117,7 +123,11 @@ class OpenAIGPTTokenizer(object):
             max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        if special_tokens_file and 'special_tokens' not in kwargs:
+            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+        else:
+            special_tokens = kwargs.pop('special_tokens', [])
+        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
         return tokenizer

     def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
...
@@ -269,7 +279,11 @@ class OpenAIGPTTokenizer(object):
             return
         vocab_file = os.path.join(vocab_path, VOCAB_NAME)
         merge_file = os.path.join(vocab_path, MERGES_NAME)
-        json.dump(self.encoder, vocab_file)
+        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
+
         index = 0
         with open(merge_file, "w", encoding="utf-8") as writer:
             writer.write(u'#version: 0.2\n')
...
@@ -278,6 +292,11 @@ class OpenAIGPTTokenizer(object):
                 logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                " Please check that the tokenizer is not corrupted!".format(merge_file))
                 index = token_index
-            writer.write(bpe_tokens + u'\n')
+            writer.write(' '.join(bpe_tokens) + u'\n')
             index += 1
-        return vocab_file, merge_file
+
+        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+            for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
+                writer.write(token + u'\n')
+
+        return vocab_file, merge_file, special_tokens_file
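Taken together, the identical changes to GPT2Tokenizer and OpenAIGPTTokenizer make special tokens part of the save/load round trip: from_pretrained pops a 'special_tokens' kwarg (or reads special_tokens.txt when present), and save_vocabulary now returns three paths. A hedged usage sketch, assuming this revision of pytorch_pretrained_bert is installed, network access to download the 'openai-gpt' vocabulary, and a writable /tmp/ directory; the token names are illustrative:

from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer

# Attach extra special tokens via the 'special_tokens' kwarg that from_pretrained
# now pops and forwards to __init__.
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt', special_tokens=['<bos>', '<eos>'])

# save_vocabulary now returns three paths: vocab.json, merges.txt and special_tokens.txt.
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path='/tmp/')

# Reloading from the directory should pick special_tokens.txt up automatically,
# since from_pretrained checks for SPECIAL_TOKENS_NAME next to the vocab files.
reloaded = OpenAIGPTTokenizer.from_pretrained('/tmp/')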
tests/tokenization_openai_test.py  (+1 -1, view file @ e8568a3b)

@@ -52,7 +52,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

-        vocab_file, merges_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
         tokenizer.from_pretrained("/tmp/")
         os.remove(vocab_file)
         os.remove(merges_file)
...
tests/tokenization_transfo_xl_test.py  (+4 -5, view file @ e8568a3b)

@@ -35,7 +35,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         tokenizer.build_vocab()
         os.remove(vocab_file)

-        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

         self.assertListEqual(
...
@@ -45,7 +45,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         tokenizer.from_pretrained(vocab_file)
         os.remove(vocab_file)

-        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

         self.assertListEqual(
...
@@ -56,15 +56,14 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         tokenizer = TransfoXLTokenizer(lower_case=True)

         self.assertListEqual(
-            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU? "),
+            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
             ["hello", "!", "how", "are", "you", "?"])
-        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

     def test_full_tokenizer_no_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=False)

         self.assertListEqual(
-            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU? "),
+            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
             ["HeLLo", "!", "how", "Are", "yoU", "?"])
...
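The test inputs above were rewritten so that punctuation is separated by whitespace, which suggests the TransfoXL tokenizer splits on whitespace rather than on punctuation or accented characters. A minimal stand-alone illustration using plain str.split() as a stand-in for that splitting step (the real TransfoXLTokenizer additionally lower-cases when lower_case=True and handles unknown symbols, which is not modelled here):

old_text = u"<unk> UNwant\u00E9d,running"
new_text = u"<unk> UNwanted , running"

# A whitespace-only split keeps the comma attached, so the old input cannot
# match an expected list that contains "," as its own token.
print(old_text.split())  # ['<unk>', 'UNwantéd,running']

# With explicit spaces around the comma, whitespace splitting yields the pieces
# the test expects (up to lower-casing, which the tokenizer applies separately).
print(new_text.split())  # ['<unk>', 'UNwanted', ',', 'running']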