Commit a309459b authored by Chris's avatar Chris
Browse files

fn change; pytorch_model_dir required=False

parent 69749f3f
...@@ -56,7 +56,6 @@ def convert_hf_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str): ...@@ -56,7 +56,6 @@ def convert_hf_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str):
name = name.replace('LayerNorm/weight', 'LayerNorm/gamma') name = name.replace('LayerNorm/weight', 'LayerNorm/gamma')
name = name.replace('LayerNorm/bias', 'LayerNorm/beta') name = name.replace('LayerNorm/bias', 'LayerNorm/beta')
name = name.replace('weight', 'kernel') name = name.replace('weight', 'kernel')
# name += ':0'
return 'bert/{}'.format(name) return 'bert/{}'.format(name)
def assign_tf_var(tensor:np.ndarray, name:str): def assign_tf_var(tensor:np.ndarray, name:str):
...@@ -86,7 +85,7 @@ if __name__ == "__main__": ...@@ -86,7 +85,7 @@ if __name__ == "__main__":
parser.add_argument("--pytorch_model_dir", parser.add_argument("--pytorch_model_dir",
default=None, default=None,
type=str, type=str,
required=True, required=False,
help="Directory containing pytorch model") help="Directory containing pytorch model")
parser.add_argument("--pytorch_model_name", parser.add_argument("--pytorch_model_name",
default=None, default=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment