"doc/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "b9d5a3e756c5c4fd7e3c99f173b92de7e942949c"
Commit fa0c5a2e authored by lukovnikov
Browse files

clean up pr

parent f4d79f44
......@@ -26,14 +26,35 @@ import numpy as np
from modeling import BertConfig, BertModel
parser = argparse.ArgumentParser()
def convert(config_path, ckpt_path, out_path=None):
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--bert_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args()
def convert():
# Initialise PyTorch model
config = BertConfig.from_json_file(config_path)
config = BertConfig.from_json_file(args.bert_config_file)
model = BertModel(config)
# Load weights from TF model
path = ckpt_path
path = args.tf_checkpoint_path
print("Converting TensorFlow checkpoint from {}".format(path))
init_vars = tf.train.list_variables(path)
......@@ -47,17 +68,11 @@ def convert(config_path, ckpt_path, out_path=None):
arrays.append(array)
for name, array in zip(names, arrays):
if not name.startswith("bert"):
print("Skipping {}".format(name))
continue
else:
name = name.replace("bert/", "") # skip "bert/"
name = name[5:] # skip "bert/"
print("Loading {}".format(name))
name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m":
print("Skipping {}".format("/".join(name)))
if name[0] in ['redictions', 'eq_relationship']:
print("Skipping")
continue
pointer = model
for m_name in name:
......@@ -84,32 +99,7 @@ def convert(config_path, ckpt_path, out_path=None):
pointer.data = torch.from_numpy(array)
# Save pytorch-model
if out_path is not None:
torch.save(model.state_dict(), out_path)
return model
torch.save(model.state_dict(), args.pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the TensorFlow checkpoint path.")
parser.add_argument("--bert_config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default=None,
type=str,
required=False,
help="Path to the output PyTorch model.")
args = parser.parse_args()
print(args)
convert(args.bert_config_file, args.tf_checkpoint_path, args.pytorch_dump_path)
convert()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment