Commit 1b35d05d authored Jul 16, 2019 by thomwolf

update conversion scripts and __main__

parent 352e3ff9
Showing 11 changed files with 53 additions and 20 deletions.
pytorch_transformers/__main__.py                                  +21  -7
pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py         +4  -1
pytorch_transformers/convert_openai_checkpoint_to_pytorch.py       +4  -1
pytorch_transformers/convert_tf_checkpoint_to_pytorch.py           +4  -5
pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py   +3  -0
pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py          +2  -1
pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py        +7  -2
pytorch_transformers/modeling_xlnet.py                             +2  -0
pytorch_transformers/tokenization_transfo_xl.py                    +1  -1
pytorch_transformers/tokenization_utils.py                         +2  -1
pytorch_transformers/tokenization_xlnet.py                         +3  -1
pytorch_transformers/__main__.py

 # coding: utf8
 def main():
     import sys
-    if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet"]:
+    if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]:
         print(
         "Should be used as one of: \n"
-        ">> `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
-        ">> `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
-        ">> `pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
-        ">> `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]` or \n"
-        ">> `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
+        ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
+        ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
+        ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
+        ">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
+        ">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n"
+        ">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT")
     else:
         if sys.argv[1] == "bert":
             try:
...
@@ -86,7 +87,7 @@ def main():
             else:
                 TF_CONFIG = ""
             convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
-        else:
+        elif sys.argv[1] == "xlnet":
             try:
                 from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
             except ImportError:
...
@@ -104,11 +105,24 @@ def main():
             PYTORCH_DUMP_OUTPUT = sys.argv[4]
             if len(sys.argv) == 6:
                 FINETUNING_TASK = sys.argv[5]
+            else:
+                FINETUNING_TASK = None
             convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
                                                 TF_CONFIG,
                                                 PYTORCH_DUMP_OUTPUT,
                                                 FINETUNING_TASK)
+        elif sys.argv[1] == "xlm":
+            from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
+            if len(sys.argv) != 4:
+                # pylint: disable=line-too-long
+                print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`")
+            else:
+                XLM_CHECKPOINT_PATH = sys.argv[2]
+                PYTORCH_DUMP_OUTPUT = sys.argv[3]
+                convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT)

 if __name__ == '__main__':
     main()
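
The net effect here is a new `xlm` subcommand (plus an `else` branch so FINETUNING_TASK is defined when only five arguments are given). The same conversion can also be invoked directly from Python; a minimal sketch, assuming the package is importable — the two paths below are placeholders, not values from this commit:

    # Direct-call equivalent of `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`
    from pytorch_transformers.convert_xlm_checkpoint_to_pytorch import (
        convert_xlm_checkpoint_to_pytorch,
    )

    convert_xlm_checkpoint_to_pytorch(
        "/path/to/xlm_checkpoint.pth",  # XLM_CHECKPOINT_PATH (placeholder)
        "/path/to/pytorch_dump",        # PYTORCH_DUMP_OUTPUT (placeholder)
    )
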
pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py

...
@@ -26,6 +26,9 @@ from pytorch_transformers.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
                                                 GPT2Model,
                                                 load_tf_weights_in_gpt2)
+
+import logging
+logging.basicConfig(level=logging.INFO)

 def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
     # Construct model
...
@@ -36,7 +39,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
     model = GPT2Model(config)

     # Load weights from numpy
-    load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
+    load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

     # Save pytorch-model
     pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
...
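
Note the signature change: `load_tf_weights_in_gpt2` now takes the config as its second argument, so any external caller must be updated. A hedged usage sketch of the converter itself, with placeholder paths; passing an empty config string mirrors the `TF_CONFIG = ""` default in `__main__`, which presumably falls back to a default `GPT2Config`:

    from pytorch_transformers.convert_gpt2_checkpoint_to_pytorch import (
        convert_gpt2_checkpoint_to_pytorch,
    )

    convert_gpt2_checkpoint_to_pytorch(
        "/path/to/gpt2/model.ckpt",  # TF_CHECKPOINT (placeholder)
        "",                          # TF_CONFIG: empty string, as in __main__'s default
        "/path/to/pytorch_dump",     # PYTORCH_DUMP_OUTPUT (placeholder)
    )
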
pytorch_transformers/convert_openai_checkpoint_to_pytorch.py

...
@@ -26,6 +26,9 @@ from pytorch_transformers.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
                                                   OpenAIGPTModel,
                                                   load_tf_weights_in_openai_gpt)
+
+import logging
+logging.basicConfig(level=logging.INFO)

 def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
     # Construct model
...
@@ -36,7 +39,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
     model = OpenAIGPTModel(config)

     # Load weights from numpy
-    load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)
+    load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)

     # Save pytorch-model
     pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
...
pytorch_transformers/convert_tf_checkpoint_to_pytorch.py

...
@@ -18,15 +18,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
-import re
 import argparse
-import tensorflow as tf
 import torch
-import numpy as np

 from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert
+
+import logging
+logging.basicConfig(level=logging.INFO)

 def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
     # Initialise PyTorch model
     config = BertConfig.from_json_file(bert_config_file)
...
@@ -34,7 +33,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
     model = BertForPreTraining(config)

     # Load weights from tf checkpoint
-    load_tf_weights_in_bert(model, tf_checkpoint_path)
+    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

     # Save pytorch-model
     print("Save PyTorch model to {}".format(pytorch_dump_path))
...
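
With `os`, `re`, `tensorflow` and `numpy` no longer imported at module level, the script can be imported without TensorFlow installed; TensorFlow is still needed at conversion time, since `load_tf_weights_in_bert` (which now also receives the config) reads the checkpoint. A minimal usage sketch with placeholder paths:

    from pytorch_transformers.convert_tf_checkpoint_to_pytorch import (
        convert_tf_checkpoint_to_pytorch,
    )

    convert_tf_checkpoint_to_pytorch(
        "/path/to/bert_model.ckpt",    # TF checkpoint prefix (placeholder)
        "/path/to/bert_config.json",   # BERT config JSON (placeholder)
        "/path/to/pytorch_model.bin",  # output file (placeholder)
    )
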
pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py

...
@@ -36,6 +36,9 @@ if sys.version_info[0] == 2:
 else:
     import pickle

+import logging
+logging.basicConfig(level=logging.INFO)
+
 # We do this to be able to load python 2 datasets pickles
 # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
 data_utils.Vocab = data_utils.TransfoXLTokenizer
...
pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py

...
@@ -24,9 +24,10 @@ import torch
 import numpy

 from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME
-from pytorch_transformers.modeling_xlm import (XLMConfig, XLMModel)
 from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES

+import logging
+logging.basicConfig(level=logging.INFO)

 def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
     # Load checkpoint
...
pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py

...
@@ -40,6 +40,8 @@ GLUE_TASKS_NUM_LABELS = {
     "wnli": 2,
 }

+import logging
+logging.basicConfig(level=logging.INFO)

 def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
     # Initialise PyTorch model
...
@@ -48,14 +50,17 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
     finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
     if finetuning_task in GLUE_TASKS_NUM_LABELS:
         print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
-        model = XLNetForSequenceClassification(config, num_labels=GLUE_TASKS_NUM_LABELS[finetuning_task])
+        config.finetuning_task = finetuning_task
+        config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
+        model = XLNetForSequenceClassification(config)
     elif 'squad' in finetuning_task:
+        config.finetuning_task = finetuning_task
         model = XLNetForQuestionAnswering(config)
     else:
         model = XLNetLMHeadModel(config)

     # Load weights from tf checkpoint
-    load_tf_weights_in_xlnet(model, config, tf_checkpoint_path, finetuning_task)
+    load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)

     # Save pytorch-model
     pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
...
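
The converter now records the task on the configuration (`config.finetuning_task`, `config.num_labels`) instead of passing `num_labels` to the model constructor, so the saved config fully describes the classification head. A sketch of the new pattern; the task name and label count are illustrative and the config path is a placeholder:

    from pytorch_transformers.modeling_xlnet import XLNetConfig, XLNetForSequenceClassification

    config = XLNetConfig.from_json_file("/path/to/xlnet_config.json")  # placeholder path
    config.finetuning_task = "sst-2"  # illustrative GLUE task name
    config.num_labels = 2             # GLUE_TASKS_NUM_LABELS["sst-2"]
    model = XLNetForSequenceClassification(config)
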
pytorch_transformers/modeling_xlnet.py

...
@@ -37,9 +37,11 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra
 logger = logging.getLogger(__name__)

 XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
+    'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-pytorch_model.bin",
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
 }

 XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
 }
...
pytorch_transformers/tokenization_transfo_xl.py

...
@@ -50,7 +50,7 @@ PRETRAINED_VOCAB_FILES_MAP = {
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    'transfo-xl-wt103': 512,
+    'transfo-xl-wt103': None,
 }

 PRETRAINED_CORPUS_ARCHIVE_MAP = {
...
pytorch_transformers/tokenization_utils.py

...
@@ -208,7 +208,8 @@ class PreTrainedTokenizer(object):
             # if we're using a pretrained model, ensure the tokenizer
             # wont index sequences longer than the number of positional embeddings
             max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+            if max_len is not None and isinstance(max_len, (int, float)):
+                kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)

         # Merge resolved_vocab_files arguments in kwargs.
         added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
...
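
This guard matters because several `PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES` entries in this commit change from 512 to None: under Python 3, `min(..., None)` raises a TypeError. A standalone illustration of the new logic (the helper name is hypothetical, not library code):

    def resolve_max_len(kwargs, max_len):
        # Clamp only when the pretrained model declares a finite size;
        # None now means "no fixed positional-embedding limit".
        if max_len is not None and isinstance(max_len, (int, float)):
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        return kwargs

    assert resolve_max_len({}, 512) == {'max_len': 512}
    assert resolve_max_len({}, None) == {}  # previously a TypeError under Python 3
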
pytorch_transformers/tokenization_xlnet.py

...
@@ -32,12 +32,14 @@ VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
 PRETRAINED_VOCAB_FILES_MAP = {
     'vocab_file':
     {
+    'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
     }
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    'xlnet-large-cased': 512,
+    'xlnet-base-cased': None,
+    'xlnet-large-cased': None,
 }

 SPIECE_UNDERLINE = u'▁'
...