Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5951d860
Commit
5951d860
authored
Sep 05, 2019
by
thomwolf
Browse files
add conversion script, rename conversion scripts
parent
aa4c8804
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
0 deletions
+65
-0
pytorch_transformers/convert_bert_original_tf_checkpoint_to_pytorch.py
...formers/convert_bert_original_tf_checkpoint_to_pytorch.py
+0
-0
pytorch_transformers/convert_bert_pytorch_checkpoint_to_original_tf.py
...formers/convert_bert_pytorch_checkpoint_to_original_tf.py
+0
-0
pytorch_transformers/convert_bert_pytorch_checkpoint_to_tf.py
...rch_transformers/convert_bert_pytorch_checkpoint_to_tf.py
+65
-0
No files found.
pytorch_transformers/convert_tf_checkpoint_to_pytorch.py
→
pytorch_transformers/convert_
bert_original_
tf_checkpoint_to_pytorch.py
View file @
5951d860
File moved
pytorch_transformers/convert_pytorch_checkpoint_to_tf.py
→
pytorch_transformers/convert_
bert_
pytorch_checkpoint_to_
original_
tf.py
View file @
5951d860
File moved
pytorch_transformers/convert_bert_pytorch_checkpoint_to_tf.py
0 → 100644
View file @
5951d860
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
tensorflow
as
tf
from
pytorch_transformers
import
BertConfig
,
TFBertForPreTraining
,
load_pt_weights_in_bert
import
logging
logging
.
basicConfig
(
level
=
logging
.
INFO
)
def
convert_bert_checkpoint_to_tf
(
pytorch_checkpoint_path
,
bert_config_file
,
tf_dump_path
):
# Initialise TF model
config
=
BertConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building TensorFlow model from configuration: {}"
.
format
(
str
(
config
)))
model
=
TFBertForPreTraining
(
config
)
# Load weights from tf checkpoint
model
=
load_pt_weights_in_bert
(
model
,
config
,
pytorch_checkpoint_path
)
# Save pytorch-model
print
(
"Save TensorFlow model to {}"
.
format
(
tf_dump_path
))
model
.
save_weights
(
tf_dump_path
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
## Required parameters
parser
.
add_argument
(
"--pytorch_checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the PyTorch checkpoint path."
)
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--tf_dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output Tensorflow dump file."
)
args
=
parser
.
parse_args
()
convert_bert_checkpoint_to_tf
(
args
.
pytorch_checkpoint_path
,
args
.
bert_config_file
,
args
.
tf_dump_path
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment