Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6f152572
Commit
6f152572
authored
Sep 05, 2019
by
thomwolf
Browse files
add conversion script, rename conversion scripts
parent
a4704b12
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
0 deletions
+65
-0
pytorch_transformers/convert_bert_original_tf_checkpoint_to_pytorch.py
...formers/convert_bert_original_tf_checkpoint_to_pytorch.py
+0
-0
pytorch_transformers/convert_bert_pytorch_checkpoint_to_original_tf.py
...formers/convert_bert_pytorch_checkpoint_to_original_tf.py
+0
-0
pytorch_transformers/convert_bert_pytorch_checkpoint_to_tf.py
...rch_transformers/convert_bert_pytorch_checkpoint_to_tf.py
+65
-0
No files found.
pytorch_transformers/convert_tf_checkpoint_to_pytorch.py
→
pytorch_transformers/convert_
bert_original_
tf_checkpoint_to_pytorch.py
View file @
6f152572
File moved
pytorch_transformers/convert_pytorch_checkpoint_to_tf.py
→
pytorch_transformers/convert_
bert_
pytorch_checkpoint_to_
original_
tf.py
View file @
6f152572
File moved
pytorch_transformers/convert_bert_pytorch_checkpoint_to_tf.py
0 → 100644
View file @
6f152572
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
tensorflow
as
tf
from
pytorch_transformers
import
BertConfig
,
TFBertForPreTraining
,
load_pt_weights_in_bert
import
logging
logging
.
basicConfig
(
level
=
logging
.
INFO
)
def
convert_bert_checkpoint_to_tf
(
pytorch_checkpoint_path
,
bert_config_file
,
tf_dump_path
):
# Initialise TF model
config
=
BertConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building TensorFlow model from configuration: {}"
.
format
(
str
(
config
)))
model
=
TFBertForPreTraining
(
config
)
# Load weights from tf checkpoint
model
=
load_pt_weights_in_bert
(
model
,
config
,
pytorch_checkpoint_path
)
# Save pytorch-model
print
(
"Save TensorFlow model to {}"
.
format
(
tf_dump_path
))
model
.
save_weights
(
tf_dump_path
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
## Required parameters
parser
.
add_argument
(
"--pytorch_checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the PyTorch checkpoint path."
)
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--tf_dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output Tensorflow dump file."
)
args
=
parser
.
parse_args
()
convert_bert_checkpoint_to_tf
(
args
.
pytorch_checkpoint_path
,
args
.
bert_config_file
,
args
.
tf_dump_path
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment