Commit f1433db4 authored by Chris

update to hf->tf args

parent 077a5b0d
@@ -18,16 +18,18 @@
 import os
 import argparse
 import numpy as np
-import tensorflow as tf
 from pytorch_pretrained_bert.modeling import BertConfig, BertModel
 
 
-def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
+def convert_hf_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str):
     """
     :param model:BertModel Pytorch model instance to be converted
     :param ckpt_dir: directory to save Tensorflow model
-    :return:
-    Supported HF models:
+    Currently supported HF models:
         Y BertModel
         N BertForMaskedLM
         N BertForPreTraining
@@ -35,20 +37,13 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
         N BertForNextSentencePrediction
         N BertForSequenceClassification
         N BertForQuestionAnswering
+    Note:
+        To keep tf out of package-level requirements, it's imported locally.
     """
+    import tensorflow as tf
     if not os.path.isdir(ckpt_dir):
         os.makedirs(ckpt_dir)
     session = tf.Session()
     state_dict = model.state_dict()
     tf_vars = []
 
     def to_tf_var_name(name:str):
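Aside: the function above walks the model's state_dict, so every dotted PyTorch parameter name has to be translated into a TF variable name. A minimal sketch of what those names look like, assuming pytorch_pretrained_bert is installed; the default-sized BertConfig here is an illustrative stand-in for a real bert_config.json:

    from pytorch_pretrained_bert.modeling import BertConfig, BertModel

    # Default-sized toy config; real conversions load the checkpoint's own bert_config.json.
    config = BertConfig(vocab_size_or_config_json_file=30522)
    model = BertModel(config)

    # Dotted PyTorch names such as 'encoder.layer.0.attention.self.query.weight'
    # are what to_tf_var_name() below has to rewrite.
    for name, tensor in list(model.state_dict().items())[:5]:
        print(name, tuple(tensor.shape))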
@@ -61,6 +56,7 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
         name = name.replace('LayerNorm/weight', 'LayerNorm/gamma')
         name = name.replace('LayerNorm/bias', 'LayerNorm/beta')
         name = name.replace('weight', 'kernel')
+        # name += ':0'
         return 'bert/{}'.format(name)
 
     def assign_tf_var(tensor:np.ndarray, name:str):
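For reference, the replace chain above is the whole naming translation. A standalone sketch of the same idea applied to typical keys; the function's earlier substitutions are elided from this hunk, so the dot-to-slash step below is an assumption, not taken verbatim from the script:

    def to_tf_var_name_sketch(name: str) -> str:
        name = name.replace('.', '/')                               # assumed elided step
        name = name.replace('LayerNorm/weight', 'LayerNorm/gamma')  # shown in the diff
        name = name.replace('LayerNorm/bias', 'LayerNorm/beta')     # shown in the diff
        name = name.replace('weight', 'kernel')                     # shown in the diff
        return 'bert/{}'.format(name)

    print(to_tf_var_name_sketch('encoder.layer.0.attention.self.query.weight'))
    # bert/encoder/layer/0/attention/self/query/kernel
    print(to_tf_var_name_sketch('embeddings.LayerNorm.weight'))
    # bert/embeddings/LayerNorm/gamma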
@@ -81,44 +77,35 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name)))) print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))
saver = tf.train.Saver(tf_vars) saver = tf.train.Saver(tf_vars)
saver.save(session, os.path.join(ckpt_dir, 'model')) saver.save(session, os.path.join(ckpt_dir, args.pytorch_model_name))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model_name_or_path", parser.add_argument("--pytorch_model_dir",
default=None, default=None,
type=str, type=str,
required=True, required=True,
help="pretrained_model_name_or_path: either: \ help="Directory containing pytorch model")
- a str with the name of a pre-trained model to load selected in the list of: \ parser.add_argument("--pytorch_model_name",
. `bert-base-uncased` \
. `bert-large-uncased` \
. `bert-base-cased` \
. `bert-large-cased` \
. `bert-base-multilingual-uncased` \
. `bert-base-multilingual-cased` \
. `bert-base-chinese` \
- a path or url to a pretrained model archive containing: \
. `bert_config.json` a configuration file for the model \
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance \
- a path or url to a pretrained model archive containing: \
. `bert_config.json` a configuration file for the model \
. `model.ckpt` a TensorFlow checkpoint")
parser.add_argument("--config_file_path",
default=None, default=None,
type=str, type=str,
required=True, required=True,
help="Path to bert config file.") help="model name (e.g. bert-base-uncased)")
parser.add_argument("--cache_dir", parser.add_argument("--config_file_path",
default=None, default=None,
type=str, type=str,
required=True, required=True,
help="Path to a folder in which the TF model will be cached.") help="Path to bert config file")
parser.add_argument("--tf_checkpoint_dir",
default="",
type=str,
required=True,
help="Directory in which to save tensorflow model")
args = parser.parse_args() args = parser.parse_args()
model = BertModel( model = BertModel(
config=BertConfig(args.config_file_path) config=BertConfig(args.config_file_path)
).from_pretrained(args.pretrained_model_name_or_path) ).from_pretrained(args.pytorch_model_name, cache_dir=args.pytorch_model_dir)
convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.cache_dir) \ No newline at end of file
\ No newline at end of file
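The save step above relies on the TF 1.x graph/session API. Since assign_tf_var's body is elided from this diff, the following is only a minimal sketch of the general pattern it builds up to: create a variable from a numpy array, initialize it in the session, and checkpoint it with tf.train.Saver. The variable name, shape, and path are illustrative:

    import os
    import numpy as np
    import tensorflow as tf  # TF 1.x, matching the script

    ckpt_dir = '/tmp/tf_ckpt'  # illustrative path
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    session = tf.Session()
    tensor = np.zeros((768, 768), dtype=np.float32)  # stand-in for one state_dict entry
    tf_var = tf.Variable(initial_value=tensor, name='bert/example/kernel')
    session.run(tf.global_variables_initializer())

    saver = tf.train.Saver([tf_var])
    saver.save(session, os.path.join(ckpt_dir, 'model'))

With the renamed flags, an invocation of the updated script would look something like this (the script filename and paths are illustrative, only the flag names come from the diff):

    python convert_hf_checkpoint_to_tf.py \
        --pytorch_model_dir /path/to/hf_cache \
        --pytorch_model_name bert-base-uncased \
        --config_file_path /path/to/bert_config.json \
        --tf_checkpoint_dir /path/to/tf_ckpt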