"examples/vscode:/vscode.git/clone" did not exist on "fbcd3ba6b27a2b019c7209aeb3073c41b72bff43"
Commit 73368963 authored by monologg, committed by Lysandre Debut

Fix importing unofficial TF models with extra optimizer weights

parent d7dabfef
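
For context, a minimal runnable sketch (not part of the commit) of what the change does: TF checkpoints produced by some unofficial BERT training code store the Adam moment slots under segments named after the optimizer itself, so the loaders must skip AdamWeightDecayOptimizer and AdamWeightDecayOptimizer_1 in addition to adam_m, adam_v and global_step. The variable names below are illustrative assumptions, not taken from a real checkpoint.

SKIP = ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]

# Hypothetical checkpoint variable names: one real weight, three optimizer
# bookkeeping entries that should never be mapped onto the PyTorch model.
tf_variable_names = [
    "bert/encoder/layer_0/attention/self/query/kernel",
    "bert/encoder/layer_0/attention/self/query/kernel/adam_m",
    "bert/encoder/layer_0/attention/self/query/kernel/AdamWeightDecayOptimizer",
    "bert/encoder/layer_0/attention/self/query/kernel/AdamWeightDecayOptimizer_1",
    "global_step",
]

for full_name in tf_variable_names:
    name = full_name.split("/")
    # Before this commit only adam_v, adam_m and global_step were filtered, so
    # the two AdamWeightDecayOptimizer entries fell through to the
    # attribute-walking code below the skip and broke the import.
    if any(n in SKIP for n in name):
        print("Skipping", full_name)
        continue
    print("Loading", full_name)
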
@@ -117,7 +117,13 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
         name = name.split("/")
         # Ignore the gradients applied by the LAMB/ADAM optimizers.
-        if "adam_m" in name or "adam_v" in name or "global_step" in name:
+        if (
+            "adam_m" in name
+            or "adam_v" in name
+            or "AdamWeightDecayOptimizer" in name
+            or "AdamWeightDecayOptimizer_1" in name
+            or "global_step" in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
@@ -86,7 +86,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
         # which are not required for using the pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
         pointer = model
@@ -79,7 +79,10 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
         name = txt_name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
         # which are not required for using the pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             tf_weights.pop(txt_name, None)
             continue
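
Unlike the other loaders, the T5 version also pops each skipped name out of tf_weights. A hedged sketch of why, assuming (as the surrounding code suggests) that whatever is still in that dict after loading is reported as weights that were not copied into the PyTorch model; the names below are hypothetical:

# Hypothetical illustration of why skipped names are popped from tf_weights.
tf_weights = {
    "decoder/block_0/layer_0/SelfAttention/q": 0,  # stand-in for a real array
    "decoder/block_0/layer_0/SelfAttention/q/AdamWeightDecayOptimizer": 0,
}
SKIP = ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]

for txt_name in list(tf_weights):
    if any(n in SKIP for n in txt_name.split("/")):
        tf_weights.pop(txt_name, None)  # drop optimizer slots so they are not flagged below
        continue
    tf_weights.pop(txt_name)  # this weight is consumed by the model

# Anything left at this point is a genuinely unmatched checkpoint entry.
print("Weights not copied to the PyTorch model:", list(tf_weights))
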
@@ -76,7 +76,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
         # which are not required for using the pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
         pointer = model