chenpangpang/transformers, commit 288be7b7
Authored Jul 02, 2019 by thomwolf

Commit message: xlm

Parent: 70887795
Showing 3 changed files with 547 additions and 205 deletions (+547, -205)
pytorch_pretrained_bert/convert_xlm_checkpoint_to_pytorch.py  (+73, -0)
pytorch_pretrained_bert/modeling_xlm.py                       (+148, -205)
pytorch_pretrained_bert/tokenization_xlm.py                   (+326, -0)
pytorch_pretrained_bert/convert_xlm_checkpoint_to_pytorch.py (new file, mode 100755)
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert OpenAI GPT checkpoint."""
from __future__ import absolute_import, division, print_function

import argparse
import json
from io import open

import torch
import numpy

from pytorch_pretrained_bert.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME,
                                                  XLMConfig, XLMModel)
from pytorch_pretrained_bert.tokenization_xlm import MERGES_NAME, VOCAB_NAME
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
    # Load checkpoint
    chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')

    model = chkpt['model']

    config = chkpt['params']
    config = dict((n, v) for n, v in config.items()
                  if not isinstance(v, (torch.Tensor, numpy.ndarray)))
    vocab = chkpt['dico_word2id']
    vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i)
                 for s, i in vocab.items())
    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME

    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model, pytorch_weights_dump_path)

    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(config, indent=2) + "\n")
    print("Save vocab file to {}".format(pytorch_vocab_dump_path))
    with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(vocab, indent=2) + "\n")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--xlm_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the official PyTorch dump.")
    parser.add_argument("--pytorch_dump_folder_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path,
                                      args.pytorch_dump_folder_path)
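A minimal usage sketch for the converter above. The checkpoint path and output folder are placeholders, not part of this commit; the script is normally run through its argparse entry point, but the function can also be called directly:

# Hypothetical example -- paths are placeholders, not from the commit.
# Command-line equivalent:
#   python -m pytorch_pretrained_bert.convert_xlm_checkpoint_to_pytorch \
#       --xlm_checkpoint_path ./mlm_en_2048.pth \
#       --pytorch_dump_folder_path ./xlm-pytorch-dump
from pytorch_pretrained_bert.convert_xlm_checkpoint_to_pytorch import (
    convert_xlm_checkpoint_to_pytorch)

convert_xlm_checkpoint_to_pytorch('./mlm_en_2048.pth', './xlm-pytorch-dump')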
pytorch_pretrained_bert/modeling_xlm.py (diff collapsed; contents not shown)
pytorch_pretrained_bert/tokenization_xlm.py (new file, mode 100644)
# coding=utf-8
# Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import json
import logging
import os
import re
import sys
from io import open

from tqdm import tqdm

from .file_utils import cached_path
from .tokenization import BasicTokenizer

logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'xlm-mlm-en-2048': 512,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'

INDEX = {
    "bos_index": 0,
    "eos_index": 1,
    "pad_index": 2,
    "unk_index": 3,
    "mask_index": 5,
}
def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    word is represented as tuple of symbols (symbols being variable-length strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
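# Illustration (not part of the commit): for the symbol tuple
# ('h', 'e', 'l', 'l', 'o</w>'), get_pairs returns
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o</w>')} -- every pair of
# adjacent symbols, which bpe() below ranks against self.bpe_ranks.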
def text_standardize(text):
    """
    fixes some issues the spacy tokenizer had on books corpus
    also does some whitespace standardization
    """
    text = text.replace('—', '-')
    text = text.replace('–', '-')
    text = text.replace('―', '-')
    text = text.replace('…', '...')
    text = text.replace('´', "'")
    text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
    text = re.sub(r'\s*\n\s*', ' \n ', text)
    text = re.sub(r'[^\S\n]+', ' ', text)
    return text.strip()
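# Illustration (not part of the commit): text_standardize('some text?  and—more')
# returns 'some text ? and - more' -- the em dash becomes '-', punctuation matched
# by the first regex is padded with spaces, and runs of non-newline whitespace
# collapse to a single space before the final strip().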
class XLMTokenizer(object):
    """
    BPE tokenizer for XLM, adapted from the OpenAI BPE tokenizer. Peculiarities:
        - lower-cases all inputs
        - uses the SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed,
          falling back to BERT's BasicTokenizer if not
        - the `special_tokens` argument and the `set_special_tokens` method can be used
          to add additional symbols (e.g. "__classify__") to the vocabulary
    """
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate an XLMTokenizer from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
            special_tokens_file = None
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
            if not os.path.exists(special_tokens_file):
                special_tokens_file = None
            else:
                logger.info("loading special tokens file {}".format(special_tokens_file))
        # redirect to the cache, if necessary
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download vocabulary.".format(vocab_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                        pretrained_model_name_or_path,
                        vocab_file, merges_file))
            return None
        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
            logger.info("loading merges file {}".format(merges_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
            logger.info("loading merges file {} from cache at {}".format(
                merges_file, resolved_merges_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index
            # sequences longer than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        if special_tokens_file and 'special_tokens' not in kwargs:
            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
        else:
            special_tokens = kwargs.pop('special_tokens', [])
        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
        return tokenizer
    def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
        try:
            import ftfy
            import spacy
            self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
            self.fix_text = ftfy.fix_text
        except ImportError:
            logger.warning("ftfy or spacy is not installed; using BERT BasicTokenizer instead of SpaCy & ftfy.")
            self.nlp = BasicTokenizer(do_lower_case=True,
                                      never_split=special_tokens if special_tokens is not None else [])
            self.fix_text = None

        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file, encoding="utf-8"))
        self.decoder = {v: k for k, v in self.encoder.items()}
        merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1]
        merges = [tuple(merge.split()[:2]) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}
        self.special_tokens = {}
        self.special_tokens_decoder = {}
        self.set_special_tokens(special_tokens)
    def __len__(self):
        return len(self.encoder) + len(self.special_tokens)
    def set_special_tokens(self, special_tokens):
        """ Add a list of additional tokens to the encoder.
            The additional tokens are indexed starting from the last index of the
            current vocabulary in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
        if self.fix_text is None:
            # Using BERT's BasicTokenizer: we can update the tokenizer
            self.nlp.never_split = special_tokens
        logger.info("Special tokens {}".format(self.special_tokens))
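    # Illustration (not part of the commit): if len(self.encoder) were 40000,
    # set_special_tokens(['<cls>', '<sep>']) would map '<cls>' -> 40000 and
    # '<sep>' -> 40001, with special_tokens_decoder holding the reverse mapping.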
    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        if word == '\n  </w>':
            word = '\n</w>'
        self.cache[token] = word
        return word
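    # Illustration (not part of the commit), with a made-up merge table: given
    # bpe_ranks = {('h', 'e'): 0, ('l', 'l'): 1, ('he', 'll'): 2}, bpe('hello')
    # starts from ('h', 'e', 'l', 'l', 'o</w>'), merges ('h', 'e') -> 'he', then
    # ('l', 'l') -> 'll', then ('he', 'll') -> 'hell'; no remaining pair is ranked,
    # so the loop breaks and the cached result is the string 'hell o</w>'.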
    def tokenize(self, text):
        """ Tokenize a string. """
        split_tokens = []
        if self.fix_text is None:
            # Using BERT's BasicTokenizer
            text = self.nlp.tokenize(text)
            for token in text:
                split_tokens.extend([t for t in self.bpe(token).split(' ')])
        else:
            # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
            text = self.nlp(text_standardize(self.fix_text(text)))
            for token in text:
                split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
        return split_tokens
    def convert_tokens_to_ids(self, tokens):
        """ Converts a sequence of tokens into ids using the vocab. """
        ids = []
        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.encoder.get(tokens, 0)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.encoder.get(token, 0))
        if len(ids) > self.max_len:
            logger.warning(
                "Token indices sequence length is longer than the specified maximum "
                "sequence length for this XLM model ({} > {}). Running this "
                "sequence through the model will result in indexing errors".format(len(ids), self.max_len)
            )
        return ids
    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Converts a sequence of ids into BPE tokens using the vocab."""
        tokens = []
        for i in ids:
            if i in self.special_tokens_decoder:
                if not skip_special_tokens:
                    tokens.append(self.special_tokens_decoder[i])
            else:
                tokens.append(self.decoder[i])
        return tokens
    def encode(self, text):
        return self.convert_tokens_to_ids(self.tokenize(text))
    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """Converts a sequence of ids into a string."""
        tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
        out_string = ''.join(tokens).replace('</w>', ' ').strip()
        if clean_up_tokenization_spaces:
            out_string = out_string.replace('<unk>', '')
            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',')
            out_string = out_string.replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m")
            out_string = out_string.replace(" do not", " don't").replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
        return out_string
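    # Illustration (not part of the commit): if convert_ids_to_tokens returned
    # ['hell', 'o</w>', ',</w>', 'world</w>'], joining gives 'hello</w>,</w>world</w>',
    # replacing '</w>' with spaces gives 'hello , world', and the clean-up pass turns
    # ' ,' into ',' so decode returns 'hello, world'.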
    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(vocab_path):
            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
            return
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        merge_file = os.path.join(vocab_path, MERGES_NAME)
        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)

        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write(u'#version: 0.2\n')
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1

        index = len(self.encoder)
        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1

        return vocab_file, merge_file, special_tokens_file
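A short usage sketch for the tokenizer above. The example sentence is a placeholder, and loading 'xlm-mlm-en-2048' assumes the S3 vocabulary and merges URLs in this file are reachable:

# Hypothetical example -- not part of the commit.
from pytorch_pretrained_bert.tokenization_xlm import XLMTokenizer

# Downloads (or reads from cache) the published English XLM vocab and BPE merges.
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')

tokens = tokenizer.tokenize("Hello, world!")    # BPE tokens, word ends marked with '</w>'
ids = tokenizer.convert_tokens_to_ids(tokens)   # same as tokenizer.encode("Hello, world!")
text = tokenizer.decode(ids)                    # joins tokens and cleans up spacing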