Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
603c513b
Commit
603c513b
authored
Jun 25, 2019
by
thomwolf
Browse files
update main conversion script and readme
parent
7de17404
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
95 additions
and
48 deletions
+95
-48
README.md
README.md
+20
-4
pytorch_pretrained_bert/__main__.py
pytorch_pretrained_bert/__main__.py
+71
-40
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
...ch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
+4
-4
No files found.
README.md
View file @
603c513b
...
...
@@ -1690,7 +1690,7 @@ Here is an example of the conversion process for a pre-trained `BERT-Base Uncase
```
shell
export
BERT_BASE_DIR
=
/path/to/bert/uncased_L-12_H-768_A-12
pytorch_pretrained_bert
convert_tf_checkpoint_to_pytorch
\
pytorch_pretrained_bert
bert
\
$BERT_BASE_DIR
/bert_model.ckpt
\
$BERT_BASE_DIR
/bert_config.json
\
$BERT_BASE_DIR
/pytorch_model.bin
...
...
@@ -1705,7 +1705,7 @@ Here is an example of the conversion process for a pre-trained OpenAI GPT model,
```
shell
export
OPENAI_GPT_CHECKPOINT_FOLDER_PATH
=
/path/to/openai/pretrained/numpy/weights
pytorch_pretrained_bert
convert_openai_checkpoin
t
\
pytorch_pretrained_bert
gp
t
\
$OPENAI_GPT_CHECKPOINT_FOLDER_PATH
\
$PYTORCH_DUMP_OUTPUT
\
[
OPENAI_GPT_CONFIG]
...
...
@@ -1718,7 +1718,7 @@ Here is an example of the conversion process for a pre-trained Transformer-XL mo
```
shell
export
TRANSFO_XL_CHECKPOINT_FOLDER_PATH
=
/path/to/transfo/xl/checkpoint
pytorch_pretrained_bert
convert_
transfo_xl
_checkpoint
\
pytorch_pretrained_bert transfo_xl
\
$TRANSFO_XL_CHECKPOINT_FOLDER_PATH
\
$PYTORCH_DUMP_OUTPUT
\
[
TRANSFO_XL_CONFIG]
...
...
@@ -1731,12 +1731,28 @@ Here is an example of the conversion process for a pre-trained OpenAI's GPT-2 mo
```
shell
export
GPT2_DIR
=
/path/to/gpt2/checkpoint
pytorch_pretrained_bert
convert_gpt2_checkpoint
\
pytorch_pretrained_bert
gpt2
\
$GPT2_DIR
/model.ckpt
\
$PYTORCH_DUMP_OUTPUT
\
[
GPT2_CONFIG]
```
### XLNet
Here is an example of the conversion process for a pre-trained XLNet model, fine-tuned on STS-B using the TensorFlow script:
```
shell
export
XLNET_CHECKPOINT_PATH
=
/path/to/xlnet/checkpoint
export
XLNET_CONFIG_PATH
=
/path/to/xlnet/config
pytorch_pretrained_bert xlnet
\
$XLNET_CHECKPOINT_PATH
\
$XLNET_CONFIG_PATH
\
$PYTORCH_DUMP_OUTPUT
\
STS-B
```
## TPU
TPU support and pretraining scripts
...
...
pytorch_pretrained_bert/__main__.py
View file @
603c513b
# coding: utf8
def
main
():
import
sys
if
(
len
(
sys
.
argv
)
!=
4
and
len
(
sys
.
argv
)
!=
5
)
or
sys
.
argv
[
1
]
not
in
[
"convert_tf_checkpoint_to_pytorch"
,
"convert_openai_checkpoint"
,
"convert_transfo_xl_checkpoint"
,
"convert_gpt2_checkpoint"
,
]:
if
(
len
(
sys
.
argv
)
<
4
or
len
(
sys
.
argv
)
>
6
)
or
sys
.
argv
[
1
]
not
in
[
"bert"
,
"gpt"
,
"transfo_xl"
,
"gpt2"
,
"xlnet"
]:
print
(
"Should be used as one of:
\n
"
">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`,
\n
"
">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`,
\n
"
">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or
\n
"
">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`"
)
">> `pytorch_pretrained_bert bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`,
\n
"
">> `pytorch_pretrained_bert gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`,
\n
"
">> `pytorch_pretrained_bert transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or
\n
"
">> `pytorch_pretrained_bert gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]` or
\n
"
">> `pytorch_pretrained_bert xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`"
)
else
:
if
sys
.
argv
[
1
]
==
"
convert_tf_checkpoint_to_pytorch
"
:
if
sys
.
argv
[
1
]
==
"
bert
"
:
try
:
from
.convert_tf_checkpoint_to_pytorch
import
convert_tf_checkpoint_to_pytorch
except
ImportError
:
...
...
@@ -25,24 +21,28 @@ def main():
if
len
(
sys
.
argv
)
!=
5
:
# pylint: disable=line-too-long
print
(
"Should be used as `pytorch_pretrained_bert
convert_tf_checkpoint_to_pytorch
TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`"
)
print
(
"Should be used as `pytorch_pretrained_bert
bert
TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`"
)
else
:
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
.
pop
()
TF_CONFIG
=
sys
.
argv
.
pop
()
TF_CHECKPOINT
=
sys
.
argv
.
pop
()
convert_tf_checkpoint_to_pytorch
(
TF_CHECKPOINT
,
TF_CONFIG
,
PYTORCH_DUMP_OUTPUT
)
elif
sys
.
argv
[
1
]
==
"
convert_openai_checkpoin
t"
:
elif
sys
.
argv
[
1
]
==
"
gp
t"
:
from
.convert_openai_checkpoint_to_pytorch
import
convert_openai_checkpoint_to_pytorch
OPENAI_GPT_CHECKPOINT_FOLDER_PATH
=
sys
.
argv
[
2
]
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
[
3
]
if
len
(
sys
.
argv
)
==
5
:
OPENAI_GPT_CONFIG
=
sys
.
argv
[
4
]
if
len
(
sys
.
argv
)
<
4
or
len
(
sys
.
argv
)
>
5
:
# pylint: disable=line-too-long
print
(
"Should be used as `pytorch_pretrained_bert gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`"
)
else
:
OPENAI_GPT_CONFIG
=
""
convert_openai_checkpoint_to_pytorch
(
OPENAI_GPT_CHECKPOINT_FOLDER_PATH
,
OPENAI_GPT_CONFIG
,
PYTORCH_DUMP_OUTPUT
)
elif
sys
.
argv
[
1
]
==
"convert_transfo_xl_checkpoint"
:
OPENAI_GPT_CHECKPOINT_FOLDER_PATH
=
sys
.
argv
[
2
]
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
[
3
]
if
len
(
sys
.
argv
)
==
5
:
OPENAI_GPT_CONFIG
=
sys
.
argv
[
4
]
else
:
OPENAI_GPT_CONFIG
=
""
convert_openai_checkpoint_to_pytorch
(
OPENAI_GPT_CHECKPOINT_FOLDER_PATH
,
OPENAI_GPT_CONFIG
,
PYTORCH_DUMP_OUTPUT
)
elif
sys
.
argv
[
1
]
==
"transfo_xl"
:
try
:
from
.convert_transfo_xl_checkpoint_to_pytorch
import
convert_transfo_xl_checkpoint_to_pytorch
except
ImportError
:
...
...
@@ -50,34 +50,65 @@ def main():
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
if
'ckpt'
in
sys
.
argv
[
2
].
lower
():
TF_CHECKPOINT
=
sys
.
argv
[
2
]
TF_DATASET_FILE
=
""
if
len
(
sys
.
argv
)
<
4
or
len
(
sys
.
argv
)
>
5
:
# pylint: disable=line-too-long
print
(
"Should be used as `pytorch_pretrained_bert transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`"
)
else
:
TF_DATASET_FILE
=
sys
.
argv
[
2
]
TF_CHECKPOINT
=
""
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
[
3
]
if
len
(
sys
.
argv
)
==
5
:
TF_CONFIG
=
sys
.
argv
[
4
]
if
'ckpt'
in
sys
.
argv
[
2
].
lower
():
TF_CHECKPOINT
=
sys
.
argv
[
2
]
TF_DATASET_FILE
=
""
else
:
TF_DATASET_FILE
=
sys
.
argv
[
2
]
TF_CHECKPOINT
=
""
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
[
3
]
if
len
(
sys
.
argv
)
==
5
:
TF_CONFIG
=
sys
.
argv
[
4
]
else
:
TF_CONFIG
=
""
convert_transfo_xl_checkpoint_to_pytorch
(
TF_CHECKPOINT
,
TF_CONFIG
,
PYTORCH_DUMP_OUTPUT
,
TF_DATASET_FILE
)
elif
sys
.
argv
[
1
]
==
"gpt2"
:
try
:
from
.convert_gpt2_checkpoint_to_pytorch
import
convert_gpt2_checkpoint_to_pytorch
except
ImportError
:
print
(
"pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
if
len
(
sys
.
argv
)
<
4
or
len
(
sys
.
argv
)
>
5
:
# pylint: disable=line-too-long
print
(
"Should be used as `pytorch_pretrained_bert gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`"
)
else
:
TF_CONFIG
=
""
convert_transfo_xl_checkpoint_to_pytorch
(
TF_CHECKPOINT
,
TF_CONFIG
,
PYTORCH_DUMP_OUTPUT
,
TF_DATASET_FILE
)
TF_CHECKPOINT
=
sys
.
argv
[
2
]
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
[
3
]
if
len
(
sys
.
argv
)
==
5
:
TF_CONFIG
=
sys
.
argv
[
4
]
else
:
TF_CONFIG
=
""
convert_gpt2_checkpoint_to_pytorch
(
TF_CHECKPOINT
,
TF_CONFIG
,
PYTORCH_DUMP_OUTPUT
)
else
:
try
:
from
.convert_
gpt2
_checkpoint_to_pytorch
import
convert_
gpt2
_checkpoint_to_pytorch
from
.convert_
xlnet
_checkpoint_to_pytorch
import
convert_
xlnet
_checkpoint_to_pytorch
except
ImportError
:
print
(
"pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
TF_CHECKPOINT
=
sys
.
argv
[
2
]
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
[
3
]
if
len
(
sys
.
argv
)
==
5
:
TF_CONFIG
=
sys
.
argv
[
4
]
if
len
(
sys
.
argv
)
<
5
or
len
(
sys
.
argv
)
>
6
:
# pylint: disable=line-too-long
print
(
"Should be used as `pytorch_pretrained_bert xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`"
)
else
:
TF_CONFIG
=
""
convert_gpt2_checkpoint_to_pytorch
(
TF_CHECKPOINT
,
TF_CONFIG
,
PYTORCH_DUMP_OUTPUT
)
TF_CHECKPOINT
=
sys
.
argv
[
2
]
TF_CONFIG
=
sys
.
argv
[
3
]
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
[
4
]
if
len
(
sys
.
argv
)
==
6
:
FINETUNING_TASK
=
sys
.
argv
[
5
]
convert_xlnet_checkpoint_to_pytorch
(
TF_CHECKPOINT
,
TF_CONFIG
,
PYTORCH_DUMP_OUTPUT
,
FINETUNING_TASK
)
if
__name__
==
'__main__'
:
main
()
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
View file @
603c513b
...
...
@@ -70,7 +70,7 @@ if __name__ == "__main__":
required
=
True
,
help
=
"The config json file corresponding to the pre-trained XLNet model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_folder_path"
,
finetuning_task
parser
.
add_argument
(
"--pytorch_dump_folder_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
...
...
@@ -81,6 +81,6 @@ if __name__ == "__main__":
help
=
"Name of a task on which the XLNet TensorFlow model was fine-tuned"
)
args
=
parser
.
parse_args
()
convert_xlnet_checkpoint_to_pytorch
(
args
.
tf_checkpoint_path
,
args
.
xlnet_config_file
,
args
.
pytorch_dump_folder_path
,
args
.
finetuning_task
)
args
.
xlnet_config_file
,
args
.
pytorch_dump_folder_path
,
args
.
finetuning_task
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment