Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3e65f255
Commit
3e65f255
authored
Apr 15, 2019
by
thomwolf
Browse files
add serialization semantics to tokenizers - fix transfo-xl tokenizer
parent
61674333
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
67 additions
and
110 deletions
+67
-110
examples/run_transfo_xl.py
examples/run_transfo_xl.py
+2
-1
pytorch_pretrained_bert/tokenization.py
pytorch_pretrained_bert/tokenization.py
+13
-0
pytorch_pretrained_bert/tokenization_gpt2.py
pytorch_pretrained_bert/tokenization_gpt2.py
+16
-0
pytorch_pretrained_bert/tokenization_openai.py
pytorch_pretrained_bert/tokenization_openai.py
+16
-0
pytorch_pretrained_bert/tokenization_transfo_xl.py
pytorch_pretrained_bert/tokenization_transfo_xl.py
+20
-109
No files found.
examples/run_transfo_xl.py
View file @
3e65f255
...
...
@@ -28,7 +28,7 @@ import math
import
torch
from
pytorch_pretrained_bert
import
TransfoXLLMHeadModel
,
TransfoXLCorpus
from
pytorch_pretrained_bert
import
TransfoXLLMHeadModel
,
TransfoXLCorpus
,
TransfoXLTokenizer
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
...
...
@@ -80,6 +80,7 @@ def main():
# The pre-processing involves computing word frequencies to prepare the Adaptive input and SoftMax
# and tokenizing the dataset
# The pre-processed corpus is a conversion (using the conversion script)
tokenizer
=
TransfoXLTokenizer
.
from_pretrained
(
args
.
model_name
)
corpus
=
TransfoXLCorpus
.
from_pretrained
(
args
.
model_name
)
ntokens
=
len
(
corpus
.
vocab
)
...
...
pytorch_pretrained_bert/tokenization.py
View file @
3e65f255
...
...
@@ -134,6 +134,19 @@ class BertTokenizer(object):
tokens
.
append
(
self
.
ids_to_tokens
[
i
])
return
tokens
def save_vocabulary(self, vocab_path):
    """Write the vocabulary to the ``VOCAB_NAME`` file inside ``vocab_path``.

    Tokens are written one per line, ordered by their index in ``self.vocab``.
    A warning is logged each time a token's index is not the one expected
    next in the sequence.
    """
    vocab_file = os.path.join(vocab_path, VOCAB_NAME)
    expected_index = 0
    with open(vocab_file, "w", encoding="utf-8") as writer:
        entries_by_index = sorted(self.vocab.items(), key=lambda kv: kv[1])
        for token, token_index in entries_by_index:
            if token_index != expected_index:
                logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
                               " Please check that the vocabulary is not corrupted!".format(vocab_file))
            writer.write(token + u'\n')
            # Resynchronise on the index actually seen so one gap warns only once.
            expected_index = token_index + 1
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
cache_dir
=
None
,
*
inputs
,
**
kwargs
):
"""
...
...
pytorch_pretrained_bert/tokenization_gpt2.py
View file @
3e65f255
...
...
@@ -187,6 +187,22 @@ class GPT2Tokenizer(object):
self
.
cache
[
token
]
=
word
return
word
def save_vocabulary(self, vocab_path):
    """Save the tokenizer vocabulary and BPE merges under ``vocab_path``.

    ``self.encoder`` is written as JSON to ``VOCAB_NAME``; the entries of
    ``self.bpe_ranks`` are written one per line, in rank order, to
    ``MERGES_NAME`` after a ``#version: 0.2`` header line. A warning is
    logged when the merge ranks are not consecutive.
    """
    vocab_file = os.path.join(vocab_path, VOCAB_NAME)
    merge_file = os.path.join(vocab_path, MERGES_NAME)

    # json.dump() needs a writable file object; handing it the path string
    # raised AttributeError. Open the file and write the serialized encoder.
    with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
        vocab_writer.write(json.dumps(self.encoder, ensure_ascii=False))

    index = 0
    with open(merge_file, "w", encoding="utf-8") as writer:
        writer.write(u'#version: 0.2\n')
        for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
            if index != token_index:
                logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                               " Please check that the tokenizer is not corrupted!".format(merge_file))
                index = token_index
            # NOTE(review): bpe_ranks keys are presumably merge pairs (tuples),
            # which cannot be concatenated with a string; join them with a
            # space. Plain-string keys are written unchanged. Confirm against
            # the code that builds bpe_ranks.
            if isinstance(bpe_tokens, (tuple, list)):
                bpe_tokens = u' '.join(bpe_tokens)
            writer.write(bpe_tokens + u'\n')
            index += 1
def
encode
(
self
,
text
):
bpe_tokens
=
[]
for
token
in
re
.
findall
(
self
.
pat
,
text
):
...
...
pytorch_pretrained_bert/tokenization_openai.py
View file @
3e65f255
...
...
@@ -261,3 +261,19 @@ class OpenAIGPTTokenizer(object):
).
replace
(
" 's"
,
"'s"
).
replace
(
" t "
,
"'t "
).
replace
(
" s "
,
"'s "
).
replace
(
" m "
,
"'m "
).
replace
(
" 've"
,
"'ve"
)
return
out_string
def save_vocabulary(self, vocab_path):
    """Save the tokenizer vocabulary and BPE merges under ``vocab_path``.

    ``self.encoder`` is written as JSON to ``VOCAB_NAME``; the entries of
    ``self.bpe_ranks`` are written one per line, in rank order, to
    ``MERGES_NAME`` after a ``#version: 0.2`` header line. A warning is
    logged when the merge ranks are not consecutive.
    """
    vocab_file = os.path.join(vocab_path, VOCAB_NAME)
    merge_file = os.path.join(vocab_path, MERGES_NAME)

    # json.dump() needs a writable file object; handing it the path string
    # raised AttributeError. Open the file and write the serialized encoder.
    with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
        vocab_writer.write(json.dumps(self.encoder, ensure_ascii=False))

    index = 0
    with open(merge_file, "w", encoding="utf-8") as writer:
        writer.write(u'#version: 0.2\n')
        for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
            if index != token_index:
                logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                               " Please check that the tokenizer is not corrupted!".format(merge_file))
                index = token_index
            # NOTE(review): bpe_ranks keys are presumably merge pairs (tuples),
            # which cannot be concatenated with a string; join them with a
            # space. Plain-string keys are written unchanged. Confirm against
            # the code that builds bpe_ranks.
            if isinstance(bpe_tokens, (tuple, list)):
                bpe_tokens = u' '.join(bpe_tokens)
            writer.write(bpe_tokens + u'\n')
            index += 1
pytorch_pretrained_bert/tokenization_transfo_xl.py
View file @
3e65f255
...
...
@@ -63,7 +63,10 @@ class TransfoXLTokenizer(object):
if
pretrained_model_name_or_path
in
PRETRAINED_VOCAB_ARCHIVE_MAP
:
vocab_file
=
PRETRAINED_VOCAB_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
vocab_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
VOCAB_NAME
)
else
:
vocab_file
=
pretrained_model_name_or_path
# redirect to the cache, if necessary
try
:
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
...
...
@@ -141,6 +144,11 @@ class TransfoXLTokenizer(object):
else
:
raise
ValueError
(
'No <unkown> token in vocabulary'
)
def save_vocabulary(self, vocab_path):
    """Save the tokenizer vocabulary to a path.

    The tokenizer's whole ``__dict__`` is serialized with ``torch.save`` into
    the ``VOCAB_NAME`` file inside ``vocab_path``.
    """
    vocab_file = os.path.join(vocab_path, VOCAB_NAME)
    torch.save(self.__dict__, vocab_file)
def
build_vocab
(
self
):
if
self
.
vocab_file
:
print
(
'building vocab from {}'
.
format
(
self
.
vocab_file
))
...
...
@@ -245,82 +253,24 @@ class TransfoXLTokenizer(object):
def __len__(self):
    """Return the length of ``self.idx2sym``."""
    symbols = self.idx2sym
    return len(symbols)
def _run_split_on_punc(self, text):
    """Split `text` so that every punctuation character becomes its own piece.

    Text listed in ``self.never_split`` is returned unsplit as a one-item list.
    Runs of non-punctuation characters are kept together.
    """
    if text in self.never_split:
        return [text]

    groups = []
    # True while the last group is a run of non-punctuation characters that
    # the next non-punctuation character should extend.
    run_open = False
    for ch in text:
        if _is_punctuation(ch):
            groups.append([ch])
            run_open = False
            continue
        if not run_open:
            groups.append([])
            run_open = True
        groups[-1].append(ch)

    return ["".join(group) for group in groups]
def _run_strip_accents(self, text):
    """Return `text` with accents removed.

    The text is NFD-normalized and every character in Unicode category
    ``Mn`` is dropped.
    """
    decomposed = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")
def _clean_text(self, text):
    """Remove invalid characters from `text` and normalize its whitespace.

    NUL (0), U+FFFD and characters flagged by ``_is_control`` are dropped;
    characters flagged by ``_is_whitespace`` are replaced by a single space.
    """
    kept = []
    for ch in text:
        code_point = ord(ch)
        if code_point == 0 or code_point == 0xfffd or _is_control(ch):
            continue
        kept.append(" " if _is_whitespace(ch) else ch)
    return "".join(kept)
def whitespace_tokenize(self, text):
    """Strip `text` and split it on ``self.delimiter``.

    Returns an empty list for blank input. When the delimiter is the empty
    string the stripped text itself is returned instead of a list.
    """
    stripped = text.strip()
    if not stripped:
        return []
    if self.delimiter == '':
        return stripped
    return stripped.split(self.delimiter)
def
tokenize
(
self
,
line
,
add_eos
=
False
,
add_double_eos
=
False
):
line
=
self
.
_clean_text
(
line
)
line
=
line
.
strip
()
# convert to lower case
if
self
.
lower_case
:
line
=
line
.
lower
()
symbols
=
self
.
whitespace_tokenize
(
line
)
split_symbols
=
[]
for
symbol
in
symbols
:
if
self
.
lower_case
and
symbol
not
in
self
.
never_split
:
symbol
=
symbol
.
lower
()
symbol
=
self
.
_run_strip_accents
(
symbol
)
split_symbols
.
extend
(
self
.
_run_split_on_punc
(
symbol
))
# empty delimiter '' will evaluate False
if
self
.
delimiter
==
''
:
symbols
=
line
else
:
symbols
=
line
.
split
(
self
.
delimiter
)
if
add_double_eos
:
# lm1b
return
[
'<S>'
]
+
split_
symbols
+
[
'<S>'
]
return
[
'<S>'
]
+
symbols
+
[
'<S>'
]
elif
add_eos
:
return
split_
symbols
+
[
'<eos>'
]
return
symbols
+
[
'<eos>'
]
else
:
return
split_
symbols
return
symbols
class
LMOrderedIterator
(
object
):
...
...
@@ -631,42 +581,3 @@ def get_lm_corpus(datadir, dataset):
torch
.
save
(
corpus
,
fn
)
return
corpus
def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char in (" ", "\t", "\n", "\r"):
        return True
    # Otherwise only Unicode space separators count as whitespace.
    return unicodedata.category(char) == "Zs"
def _is_control(char):
    """Checks whether `char` is a control character."""
    # \t, \n and \r are technically control characters, but they are counted
    # as whitespace, so they are excluded here.
    if char in ("\t", "\n", "\r"):
        return False
    # Any Unicode category starting with "C" is treated as control.
    return unicodedata.category(char).startswith("C")
def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    code_point = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    in_ascii_punct_range = (
        33 <= code_point <= 47
        or 58 <= code_point <= 64
        or 91 <= code_point <= 96
        or 123 <= code_point <= 126
    )
    if in_ascii_punct_range:
        return True
    return unicodedata.category(char).startswith("P")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment