Commit edfd965a authored by David Pollack's avatar David Pollack
Browse files

fix convert_to_tf

parent 46cc9dd2
......@@ -72,11 +72,11 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
return 'bert/{}'.format(name)
def assign_tf_var(tensor:np.ndarray, name:str):
tmp_var = tf.Variable(initial_value=tensor)
tf_var = tf.get_variable(dtype=tmp_var.dtype, shape=tmp_var.shape, name=name)
op = tf.assign(ref=tf_var, value=tmp_var)
session.run(tf.variables_initializer([tmp_var, tf_var]))
session.run(fetches=[op, tf_var])
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name)
session.run(tf.variables_initializer([tf_var]))
tf.keras.backend.set_value(tf_var, tensor)
session.run(tf_var)
return tf_var
for var_name in state_dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment