Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
96c2b77f
Commit
96c2b77f
authored
May 02, 2019
by
Chris
Browse files
added file to convert pytorch->tf
parent
3ae8c8be
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
153 additions
and
0 deletions
+153
-0
pytorch_pretrained_bert/convert_hf_checkpoint_to_tf.py
pytorch_pretrained_bert/convert_hf_checkpoint_to_tf.py
+153
-0
No files found.
pytorch_pretrained_bert/convert_hf_checkpoint_to_tf.py
0 → 100644
View file @
96c2b77f
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
import
os
import
argparse
import
numpy
as
np
from
pytorch_pretrained_bert.modeling
import
BertConfig
,
BertModel
# def __get_var_names(config):
#
# models = {
# 'BertModel': BertModel(config),
# 'BertForMaskedLM': BertForMaskedLM(config),
# 'BertForPreTraining': BertForPreTraining(config),
# 'BertForMultipleChoice': BertForMultipleChoice(config, num_choices=100),
# 'BertForNextSentencePrediction': BertForNextSentencePrediction(config),
# 'BertForSequenceClassification': BertForSequenceClassification(config, num_labels=100),
# 'BertForQuestionAnswering': BertForQuestionAnswering(config)
# }
#
# for name, model in models.items():
# state_dict = model.state_dict()
# torch_vars = []
# for var_ in state_dict:
# torch_vars.append(var_ + ', ' + str(tuple(state_dict[var_].shape)))
# json.dump(torch_vars, fp=open('torch_var_names_{}.json'.format(name), 'w'), indent=3)
def convert_hf_checkpoint_to_tf(model: BertModel, ckpt_dir: str):
    """Convert a pytorch-pretrained-bert ``BertModel`` into a TF 1.x checkpoint.

    :param model: BertModel Pytorch model instance to be converted
    :param ckpt_dir: directory to save Tensorflow model (created if missing)

    Supported HF models:
        Y BertModel
        N BertForMaskedLM
        N BertForPreTraining
        N BertForMultipleChoice
        N BertForNextSentencePrediction
        N BertForSequenceClassification
        N BertForQuestionAnswering

    Note:
        TF isn't & shouldn't be a package-level requirement; this
        feature is requested enough to warrant a local import.
    """
    import tensorflow as tf

    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    # Graph-mode TF 1.x: variables are created below, then materialized
    # through this session and saved from it.
    session = tf.Session()
    state_dict = model.state_dict()
    tf_vars = []

    def to_tf_var_name(name: str):
        """Translate a pytorch state_dict key into the canonical TF BERT
        variable name (e.g. 'encoder.layer.0...' -> 'bert/encoder/layer_0...').
        todo: compile as regex"""
        name = name.replace('layer.', 'layer_')
        name = name.replace('word_embeddings.weight', 'word_embeddings')
        name = name.replace('position_embeddings.weight', 'position_embeddings')
        name = name.replace('token_type_embeddings.weight', 'token_type_embeddings')
        name = name.replace('.', '/')
        name = name.replace('LayerNorm/weight', 'LayerNorm/gamma')
        name = name.replace('LayerNorm/bias', 'LayerNorm/beta')
        name = name.replace('weight', 'kernel')
        return 'bert/{}'.format(name)

    def assign_tf_var(tensor: np.ndarray, name: str):
        """Create a TF variable called `name`, copy `tensor` into it through
        the session, and return the initialized variable."""
        # tmp_var carries the numpy value; tf_var carries the canonical name.
        tmp_var = tf.Variable(initial_value=tensor)
        tf_var = tf.get_variable(dtype=tmp_var.dtype, shape=tmp_var.shape, name=name)
        op = tf.assign(ref=tf_var, value=tmp_var)
        session.run(tf.variables_initializer([tmp_var, tf_var]))
        session.run(fetches=[op, tf_var])
        return tf_var

    # Fix: torch.nn.Linear stores its weight as (out_features, in_features)
    # while a TF dense kernel is (in_features, out_features), so every weight
    # feeding a matmul must be transposed.  The original only transposed
    # '*dense.weight'; the attention query/key/value kernels are square, so
    # they slipped through the shape check and were exported untransposed
    # (i.e. with wrong values).
    transposed_suffixes = (
        'dense.weight',
        'attention.self.query.weight',
        'attention.self.key.weight',
        'attention.self.value.weight',
    )

    for var_name in state_dict:
        tf_name = to_tf_var_name(var_name)
        torch_tensor = state_dict[var_name].numpy()
        # str.endswith accepts a tuple of suffixes.
        if var_name.endswith(transposed_suffixes):
            torch_tensor = torch_tensor.T
        tf_tensor = assign_tf_var(tensor=torch_tensor, name=tf_name)
        tf_vars.append(tf_tensor)
        print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))

    saver = tf.train.Saver(tf_vars)
    saver.save(session, os.path.join(ckpt_dir, 'model'))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="pretrained_model_name_or_path: either:\n"
             "- a str with the name of a pre-trained model to load selected in the list of:\n"
             "    . `bert-base-uncased`\n"
             "    . `bert-large-uncased`\n"
             "    . `bert-base-cased`\n"
             "    . `bert-large-cased`\n"
             "    . `bert-base-multilingual-uncased`\n"
             "    . `bert-base-multilingual-cased`\n"
             "    . `bert-base-chinese`\n"
             "- a path or url to a pretrained model archive containing:\n"
             "    . `bert_config.json` a configuration file for the model\n"
             "    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance\n"
             "- a path or url to a pretrained model archive containing:\n"
             "    . `bert_config.json` a configuration file for the model\n"
             "    . `model.ckpt` a TensorFlow checkpoint",
    )
    parser.add_argument(
        "--config_file_path",
        default=None,
        type=str,
        required=True,
        help="Path to bert config file.",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        required=True,
        help="path to a folder in which the TF model will be cached.",
    )
    args = parser.parse_args()

    # NOTE(review): `from_pretrained` is a classmethod, so the freshly built
    # `BertModel(config=...)` instance is discarded and the model is loaded
    # entirely from `pretrained_model_name_or_path`; `--config_file_path` is
    # effectively only parse-checked — confirm this is intended.
    model = BertModel(
        config=BertConfig(args.config_file_path)
    ).from_pretrained(args.pretrained_model_name_or_path)

    # Fix: the original passed `args.cache_`, a nonexistent attribute
    # (the flag is `--cache_dir`), which raised AttributeError at runtime.
    convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.cache_dir)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment