Unverified Commit 0e918707 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #907 from dhpollack/fix_convert_to_tf

Fix convert to tf
parents 44dd941e c90119e5
...@@ -41,7 +41,7 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s ...@@ -41,7 +41,7 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
N BertForQuestionAnswering N BertForQuestionAnswering
""" """
tensors_to_transopse = ( tensors_to_transpose = (
"dense.weight", "dense.weight",
"attention.self.query", "attention.self.query",
"attention.self.key", "attention.self.key",
...@@ -62,34 +62,34 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s ...@@ -62,34 +62,34 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
if not os.path.isdir(ckpt_dir): if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir) os.makedirs(ckpt_dir)
session = tf.Session()
state_dict = model.state_dict() state_dict = model.state_dict()
tf_vars = []
def to_tf_var_name(name:str): def to_tf_var_name(name:str):
for patt, repl in iter(var_map): for patt, repl in iter(var_map):
name = name.replace(patt, repl) name = name.replace(patt, repl)
return 'bert/{}'.format(name) return 'bert/{}'.format(name)
def assign_tf_var(tensor:np.ndarray, name:str): def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session):
tmp_var = tf.Variable(initial_value=tensor) tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
tf_var = tf.get_variable(dtype=tmp_var.dtype, shape=tmp_var.shape, name=name) tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
op = tf.assign(ref=tf_var, value=tmp_var) session.run(tf.variables_initializer([tf_var]))
session.run(tf.variables_initializer([tmp_var, tf_var])) session.run(tf_var)
session.run(fetches=[op, tf_var])
return tf_var return tf_var
for var_name in state_dict: tf.reset_default_graph()
tf_name = to_tf_var_name(var_name) with tf.Session() as session:
torch_tensor = state_dict[var_name].numpy() for var_name in state_dict:
if any([x in var_name for x in tensors_to_transopse]): tf_name = to_tf_var_name(var_name)
torch_tensor = torch_tensor.T torch_tensor = state_dict[var_name].numpy()
tf_tensor = assign_tf_var(tensor=torch_tensor, name=tf_name) if any([x in var_name for x in tensors_to_transpose]):
tf_vars.append(tf_tensor) torch_tensor = torch_tensor.T
print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name)))) tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
tf.keras.backend.set_value(tf_var, torch_tensor)
saver = tf.train.Saver(tf_vars) tf_weight = session.run(tf_var)
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))
saver = tf.train.Saver(tf.trainable_variables())
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
def main(raw_args=None): def main(raw_args=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment