Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
603c513b
Commit
603c513b
authored
Jun 25, 2019
by
thomwolf
Browse files
update main conversion script and readme
parent
7de17404
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
95 additions
and
48 deletions
+95
-48
README.md
README.md
+20
-4
pytorch_pretrained_bert/__main__.py
pytorch_pretrained_bert/__main__.py
+71
-40
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
...ch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
+4
-4
No files found.
README.md
View file @
603c513b
...
...
@@ -1690,7 +1690,7 @@ Here is an example of the conversion process for a pre-trained `BERT-Base Uncase
```
shell
export
BERT_BASE_DIR
=
/path/to/bert/uncased_L-12_H-768_A-12
pytorch_pretrained_bert
convert_tf_checkpoint_to_pytorch
\
pytorch_pretrained_bert
bert
\
$BERT_BASE_DIR
/bert_model.ckpt
\
$BERT_BASE_DIR
/bert_config.json
\
$BERT_BASE_DIR
/pytorch_model.bin
...
...
@@ -1705,7 +1705,7 @@ Here is an example of the conversion process for a pre-trained OpenAI GPT model,
```
shell
export
OPENAI_GPT_CHECKPOINT_FOLDER_PATH
=
/path/to/openai/pretrained/numpy/weights
pytorch_pretrained_bert
convert_openai_checkpoin
t
\
pytorch_pretrained_bert
gp
t
\
$OPENAI_GPT_CHECKPOINT_FOLDER_PATH
\
$PYTORCH_DUMP_OUTPUT
\
[
OPENAI_GPT_CONFIG]
...
...
@@ -1718,7 +1718,7 @@ Here is an example of the conversion process for a pre-trained Transformer-XL mo
```
shell
export
TRANSFO_XL_CHECKPOINT_FOLDER_PATH
=
/path/to/transfo/xl/checkpoint
pytorch_pretrained_bert
convert_
transfo_xl
_checkpoint
\
pytorch_pretrained_bert transfo_xl
\
$TRANSFO_XL_CHECKPOINT_FOLDER_PATH
\
$PYTORCH_DUMP_OUTPUT
\
[
TRANSFO_XL_CONFIG]
...
...
@@ -1731,12 +1731,28 @@ Here is an example of the conversion process for a pre-trained OpenAI's GPT-2 mo
```
shell
export
GPT2_DIR
=
/path/to/gpt2/checkpoint
pytorch_pretrained_bert
convert_gpt2_checkpoint
\
pytorch_pretrained_bert
gpt2
\
$GPT2_DIR
/model.ckpt
\
$PYTORCH_DUMP_OUTPUT
\
[
GPT2_CONFIG]
```
### XLNet
Here is an example of the conversion process for a pre-trained XLNet model, fine-tuned on STS-B using the TensorFlow script:
```
shell
export
TRANSFO_XL_CHECKPOINT_PATH
=
/path/to/xlnet/checkpoint
export
TRANSFO_XL_CONFIG_PATH
=
/path/to/xlnet/config
pytorch_pretrained_bert xlnet
\
$TRANSFO_XL_CHECKPOINT_PATH
\
$TRANSFO_XL_CONFIG_PATH
\
$PYTORCH_DUMP_OUTPUT
\
STS-B
\
```
## TPU
TPU support and pretraining scripts
...
...
pytorch_pretrained_bert/__main__.py
View file @
603c513b
# coding: utf8
def
main
():
import
sys
if
(
len
(
sys
.
argv
)
!=
4
and
len
(
sys
.
argv
)
!=
5
)
or
sys
.
argv
[
1
]
not
in
[
"convert_tf_checkpoint_to_pytorch"
,
"convert_openai_checkpoint"
,
"convert_transfo_xl_checkpoint"
,
"convert_gpt2_checkpoint"
,
]:
if
(
len
(
sys
.
argv
)
<
4
or
len
(
sys
.
argv
)
>
6
)
or
sys
.
argv
[
1
]
not
in
[
"bert"
,
"gpt"
,
"transfo_xl"
,
"gpt2"
,
"xlnet"
]:
print
(
"Should be used as one of:
\n
"
">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`,
\n
"
">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`,
\n
"
">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or
\n
"
">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`"
)
">> `pytorch_pretrained_bert bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`,
\n
"
">> `pytorch_pretrained_bert gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`,
\n
"
">> `pytorch_pretrained_bert transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or
\n
"
">> `pytorch_pretrained_bert gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]` or
\n
"
">> `pytorch_pretrained_bert xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`"
)
else
:
if
sys
.
argv
[
1
]
==
"
convert_tf_checkpoint_to_pytorch
"
:
if
sys
.
argv
[
1
]
==
"
bert
"
:
try
:
from
.convert_tf_checkpoint_to_pytorch
import
convert_tf_checkpoint_to_pytorch
except
ImportError
:
...
...
@@ -25,14 +21,18 @@ def main():
if
len
(
sys
.
argv
)
!=
5
:
# pylint: disable=line-too-long
print
(
"Should be used as `pytorch_pretrained_bert
convert_tf_checkpoint_to_pytorch
TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`"
)
print
(
"Should be used as `pytorch_pretrained_bert
bert
TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`"
)
else
:
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
.
pop
()
TF_CONFIG
=
sys
.
argv
.
pop
()
TF_CHECKPOINT
=
sys
.
argv
.
pop
()
convert_tf_checkpoint_to_pytorch
(
TF_CHECKPOINT
,
TF_CONFIG
,
PYTORCH_DUMP_OUTPUT
)
elif
sys
.
argv
[
1
]
==
"
convert_openai_checkpoin
t"
:
elif
sys
.
argv
[
1
]
==
"
gp
t"
:
from
.convert_openai_checkpoint_to_pytorch
import
convert_openai_checkpoint_to_pytorch
if
len
(
sys
.
argv
)
<
4
or
len
(
sys
.
argv
)
>
5
:
# pylint: disable=line-too-long
print
(
"Should be used as `pytorch_pretrained_bert gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`"
)
else
:
OPENAI_GPT_CHECKPOINT_FOLDER_PATH
=
sys
.
argv
[
2
]
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
[
3
]
if
len
(
sys
.
argv
)
==
5
:
...
...
@@ -42,7 +42,7 @@ def main():
convert_openai_checkpoint_to_pytorch
(
OPENAI_GPT_CHECKPOINT_FOLDER_PATH
,
OPENAI_GPT_CONFIG
,
PYTORCH_DUMP_OUTPUT
)
elif
sys
.
argv
[
1
]
==
"
convert_
transfo_xl
_checkpoint
"
:
elif
sys
.
argv
[
1
]
==
"transfo_xl"
:
try
:
from
.convert_transfo_xl_checkpoint_to_pytorch
import
convert_transfo_xl_checkpoint_to_pytorch
except
ImportError
:
...
...
@@ -50,7 +50,10 @@ def main():
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
if
len
(
sys
.
argv
)
<
4
or
len
(
sys
.
argv
)
>
5
:
# pylint: disable=line-too-long
print
(
"Should be used as `pytorch_pretrained_bert transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`"
)
else
:
if
'ckpt'
in
sys
.
argv
[
2
].
lower
():
TF_CHECKPOINT
=
sys
.
argv
[
2
]
TF_DATASET_FILE
=
""
...
...
@@ -63,7 +66,7 @@ def main():
else
:
TF_CONFIG
=
""
convert_transfo_xl_checkpoint_to_pytorch
(
TF_CHECKPOINT
,
TF_CONFIG
,
PYTORCH_DUMP_OUTPUT
,
TF_DATASET_FILE
)
el
se
:
el
if
sys
.
argv
[
1
]
==
"gpt2"
:
try
:
from
.convert_gpt2_checkpoint_to_pytorch
import
convert_gpt2_checkpoint_to_pytorch
except
ImportError
:
...
...
@@ -72,6 +75,10 @@ def main():
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
if
len
(
sys
.
argv
)
<
4
or
len
(
sys
.
argv
)
>
5
:
# pylint: disable=line-too-long
print
(
"Should be used as `pytorch_pretrained_bert gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`"
)
else
:
TF_CHECKPOINT
=
sys
.
argv
[
2
]
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
[
3
]
if
len
(
sys
.
argv
)
==
5
:
...
...
@@ -79,5 +86,29 @@ def main():
else
:
TF_CONFIG
=
""
convert_gpt2_checkpoint_to_pytorch
(
TF_CHECKPOINT
,
TF_CONFIG
,
PYTORCH_DUMP_OUTPUT
)
else
:
try
:
from
.convert_xlnet_checkpoint_to_pytorch
import
convert_xlnet_checkpoint_to_pytorch
except
ImportError
:
print
(
"pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
if
len
(
sys
.
argv
)
<
5
or
len
(
sys
.
argv
)
>
6
:
# pylint: disable=line-too-long
print
(
"Should be used as `pytorch_pretrained_bert xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`"
)
else
:
TF_CHECKPOINT
=
sys
.
argv
[
2
]
TF_CONFIG
=
sys
.
argv
[
3
]
PYTORCH_DUMP_OUTPUT
=
sys
.
argv
[
4
]
if
len
(
sys
.
argv
)
==
6
:
FINETUNING_TASK
=
sys
.
argv
[
5
]
convert_xlnet_checkpoint_to_pytorch
(
TF_CHECKPOINT
,
TF_CONFIG
,
PYTORCH_DUMP_OUTPUT
,
FINETUNING_TASK
)
if
__name__
==
'__main__'
:
main
()
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
View file @
603c513b
...
...
@@ -70,7 +70,7 @@ if __name__ == "__main__":
required
=
True
,
help
=
"The config json file corresponding to the pre-trained XLNet model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_folder_path"
,
finetuning_task
parser
.
add_argument
(
"--pytorch_dump_folder_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment