Unverified Commit 2d91e3c3 authored by Yongliang Shen's avatar Yongliang Shen Committed by GitHub
Browse files

use original loaded keys to find mismatched keys (#16920)

parent d365f507
......@@ -2022,6 +2022,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return key.replace("gamma", "weight")
return key
original_loaded_keys = loaded_keys
loaded_keys = [_fix_key(key) for key in loaded_keys]
if len(prefix) > 0:
......@@ -2114,7 +2115,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
......@@ -2140,7 +2141,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment