Unverified Commit bc9332b5 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[TF] from_pt should respect authorized_unexpected_keys (#8056)

parent 7ff7c493
......@@ -208,6 +208,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
if tf_model.authorized_missing_keys is not None:
for pat in tf_model.authorized_missing_keys:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
if tf_model.authorized_unexpected_keys is not None:
for pat in tf_model.authorized_unexpected_keys:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment