chenpangpang / transformers · Commits

Commit 3e65f255, authored Apr 15, 2019 by thomwolf (parent 61674333)

    add serialization semantics to tokenizers - fix transfo-xl tokenizer

Showing 5 changed files with 67 additions and 110 deletions:

    examples/run_transfo_xl.py                           +2    -1
    pytorch_pretrained_bert/tokenization.py              +13   -0
    pytorch_pretrained_bert/tokenization_gpt2.py         +16   -0
    pytorch_pretrained_bert/tokenization_openai.py       +16   -0
    pytorch_pretrained_bert/tokenization_transfo_xl.py   +20   -109
examples/run_transfo_xl.py  (+2, -1)

@@ -28,7 +28,7 @@ import math
 import torch

-from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus
+from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer

 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',

@@ -80,6 +80,7 @@ def main():
     # The pre-processing involve computing word frequencies to prepare the Adaptive input and SoftMax
     # and tokenizing the dataset
     # The pre-processed corpus is a convertion (using the conversion script )
+    tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name)
     corpus = TransfoXLCorpus.from_pretrained(args.model_name)
     ntokens = len(corpus.vocab)
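The tokenizer loaded here exposes the same encode path as the other pytorch_pretrained_bert tokenizers. A rough usage sketch, purely illustrative (the sample sentence is not part of the script, and it assumes convert_tokens_to_ids is available on TransfoXLTokenizer as on the other tokenizer classes):

from pytorch_pretrained_bert import TransfoXLTokenizer

# Illustrative only: encode a prompt with the pretrained Transformer-XL vocabulary.
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
tokens = tokenizer.tokenize("the quick brown fox jumps over the lazy dog")
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)
print(ids)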
pytorch_pretrained_bert/tokenization.py  (+13, -0)

@@ -134,6 +134,19 @@ class BertTokenizer(object):
             tokens.append(self.ids_to_tokens[i])
         return tokens

+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a path."""
+        index = 0
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        with open(vocab_file, "w", encoding="utf-8") as writer:
+            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
+                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
+                    index = token_index
+                writer.write(token + u'\n')
+                index += 1
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
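A minimal round-trip sketch for the new BertTokenizer.save_vocabulary (the scratch directory is illustrative, and it assumes VOCAB_NAME resolves to 'vocab.txt' as used elsewhere in this module):

import os
import tempfile

from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# save_vocabulary writes one token per line, ordered by token id.
save_dir = tempfile.mkdtemp()
tokenizer.save_vocabulary(save_dir)

# The saved file can be handed straight back to the constructor.
reloaded = BertTokenizer(os.path.join(save_dir, 'vocab.txt'), do_lower_case=True)
assert len(reloaded.vocab) == len(tokenizer.vocab)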
pytorch_pretrained_bert/tokenization_gpt2.py  (+16, -0)

@@ -187,6 +187,22 @@ class GPT2Tokenizer(object):
         self.cache[token] = word
         return word

+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a path."""
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        json.dump(self.encoder, vocab_file)
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(bpe_tokens + u'\n')
+                index += 1
+
     def encode(self, text):
         bpe_tokens = []
         for token in re.findall(self.pat, text):
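One caveat worth flagging: json.dump expects a writable file object as its second argument, but save_vocabulary above passes the vocab_file path string, so the encoder would not be written as shown. A hedged sketch of the presumably intended behaviour (not what this commit ships; save_encoder is a hypothetical helper) would open the file first:

import json

def save_encoder(encoder, vocab_file):
    # Sketch only: serialize the GPT-2 token-to-id mapping through an explicit
    # file handle, which is what json.dump requires.
    with open(vocab_file, "w", encoding="utf-8") as f:
        json.dump(encoder, f, ensure_ascii=False)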
pytorch_pretrained_bert/tokenization_openai.py  (+16, -0)

@@ -261,3 +261,19 @@ class OpenAIGPTTokenizer(object):
                 ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
                 ).replace(" 've", "'ve")
         return out_string
+
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a path."""
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        json.dump(self.encoder, vocab_file)
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(bpe_tokens + u'\n')
+                index += 1
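The OpenAIGPTTokenizer.save_vocabulary above is line-for-line the same as the GPT-2 version, so the json.dump caveat noted for tokenization_gpt2.py applies here as well.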
pytorch_pretrained_bert/tokenization_transfo_xl.py  (+20, -109)

@@ -63,7 +63,10 @@ class TransfoXLTokenizer(object):
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            if os.path.isdir(pretrained_model_name_or_path):
+                vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            else:
+                vocab_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)

@@ -141,6 +144,11 @@ class TransfoXLTokenizer(object):
         else:
             raise ValueError('No <unkown> token in vocabulary')

+    def save_vocabulary(self, vocab_path):
+        index = 0
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        torch.save(self.__dict__, vocab_file)
+
     def build_vocab(self):
         if self.vocab_file:
             print('building vocab from {}'.format(self.vocab_file))
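Unlike the plain-text vocabularies above, the Transformer-XL tokenizer serializes its entire state with torch.save. A hedged round-trip sketch (it assumes from_pretrained restores the pickled attributes with torch.load, which is how this module's existing loading path works, and that VOCAB_NAME matches the file written here):

import tempfile

from pytorch_pretrained_bert import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')

# save_vocabulary pickles the tokenizer's attributes (idx2sym and friends) into vocab_path/VOCAB_NAME.
save_dir = tempfile.mkdtemp()
tokenizer.save_vocabulary(save_dir)

# With the from_pretrained fix above, a local directory (or the file itself) can be passed back in.
reloaded = TransfoXLTokenizer.from_pretrained(save_dir)
assert len(reloaded) == len(tokenizer)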
@@ -245,82 +253,24 @@ class TransfoXLTokenizer(object):
     def __len__(self):
         return len(self.idx2sym)

-    def _run_split_on_punc(self, text):
-        """Splits punctuation on a piece of text."""
-        if text in self.never_split:
-            return [text]
-        chars = list(text)
-        i = 0
-        start_new_word = True
-        output = []
-        while i < len(chars):
-            char = chars[i]
-            if _is_punctuation(char):
-                output.append([char])
-                start_new_word = True
-            else:
-                if start_new_word:
-                    output.append([])
-                start_new_word = False
-                output[-1].append(char)
-            i += 1
-        return ["".join(x) for x in output]
-
-    def _run_strip_accents(self, text):
-        """Strips accents from a piece of text."""
-        text = unicodedata.normalize("NFD", text)
-        output = []
-        for char in text:
-            cat = unicodedata.category(char)
-            if cat == "Mn":
-                continue
-            output.append(char)
-        return "".join(output)
-
-    def _clean_text(self, text):
-        """Performs invalid character removal and whitespace cleanup on text."""
-        output = []
-        for char in text:
-            cp = ord(char)
-            if cp == 0 or cp == 0xfffd or _is_control(char):
-                continue
-            if _is_whitespace(char):
-                output.append(" ")
-            else:
-                output.append(char)
-        return "".join(output)
-
-    def whitespace_tokenize(self, text):
-        """Runs basic whitespace cleaning and splitting on a piece of text."""
-        text = text.strip()
-        if not text:
-            return []
-        if self.delimiter == '':
-            tokens = text
-        else:
-            tokens = text.split(self.delimiter)
-        return tokens
-
     def tokenize(self, line, add_eos=False, add_double_eos=False):
-        line = self._clean_text(line)
         line = line.strip()
-        symbols = self.whitespace_tokenize(line)
-        split_symbols = []
-        for symbol in symbols:
-            if self.lower_case and symbol not in self.never_split:
-                symbol = symbol.lower()
-            symbol = self._run_strip_accents(symbol)
-            split_symbols.extend(self._run_split_on_punc(symbol))
+        # convert to lower case
+        if self.lower_case:
+            line = line.lower()
+        # empty delimiter '' will evaluate False
+        if self.delimiter == '':
+            symbols = line
+        else:
+            symbols = line.split(self.delimiter)

         if add_double_eos: # lm1b
-            return ['<S>'] + split_symbols + ['<S>']
+            return ['<S>'] + symbols + ['<S>']
         elif add_eos:
-            return split_symbols + ['<eos>']
+            return symbols + ['<eos>']
         else:
-            return split_symbols
+            return symbols

 class LMOrderedIterator(object):

@@ -631,42 +581,3 @@ def get_lm_corpus(datadir, dataset):
         torch.save(corpus, fn)

     return corpus

-def _is_whitespace(char):
-    """Checks whether `chars` is a whitespace character."""
-    # \t, \n, and \r are technically contorl characters but we treat them
-    # as whitespace since they are generally considered as such.
-    if char == " " or char == "\t" or char == "\n" or char == "\r":
-        return True
-    cat = unicodedata.category(char)
-    if cat == "Zs":
-        return True
-    return False
-
-def _is_control(char):
-    """Checks whether `chars` is a control character."""
-    # These are technically control characters but we count them as whitespace
-    # characters.
-    if char == "\t" or char == "\n" or char == "\r":
-        return False
-    cat = unicodedata.category(char)
-    if cat.startswith("C"):
-        return True
-    return False
-
-def _is_punctuation(char):
-    """Checks whether `chars` is a punctuation character."""
-    cp = ord(char)
-    # We treat all non-letter/number ASCII as punctuation.
-    # Characters such as "^", "$", and "`" are not in the Unicode
-    # Punctuation class but we treat them as punctuation anyways, for
-    # consistency.
-    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
-            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
-        return True
-    cat = unicodedata.category(char)
-    if cat.startswith("P"):
-        return True
-    return False
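The net effect of the removals is that TransfoXLTokenizer.tokenize goes back to plain delimiter splitting (plus optional lower-casing and <eos>/<S> markers) instead of BERT-style accent stripping and punctuation splitting. A small illustration of the restored behaviour (it assumes the constructor accepts lower_case and that the default delimiter of None splits on whitespace):

from pytorch_pretrained_bert import TransfoXLTokenizer

# Build a bare tokenizer just to exercise tokenize(); no vocabulary is needed for splitting.
tok = TransfoXLTokenizer(lower_case=True)
print(tok.tokenize("Hello there General Kenobi", add_eos=True))
# With the simplified code above this yields:
# ['hello', 'there', 'general', 'kenobi', '<eos>']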