Unverified commit 12bfa97a, authored by Stas Bekman, committed by GitHub
Browse files

[from_pretrained] refactor find_mismatched_keys (#16706)

parent 9f8bfe70
...@@ -1966,8 +1966,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1966,8 +1966,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"properly saved?" "properly saved?"
) )
if state_dict is not None: def _find_mismatched_keys(
# Whole checkpoint state_dict,
model_state_dict,
loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
):
mismatched_keys = [] mismatched_keys = []
if ignore_mismatched_sizes: if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys: for checkpoint_key in loaded_keys:
...@@ -1988,6 +1994,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1988,6 +1994,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
) )
del state_dict[checkpoint_key] del state_dict[checkpoint_key]
return mismatched_keys
if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
else: else:
# Sharded checkpoint # Sharded checkpoint
...@@ -1996,30 +2014,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1996,30 +2014,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
resolved_archive_file = [resolved_archive_file] resolved_archive_file = [resolved_archive_file]
error_msgs = [] error_msgs = []
mismatched_keys = []
for shard_file in resolved_archive_file: for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file) state_dict = load_state_dict(shard_file)
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model. # matching the weights in the model.
mismatched_keys = [] mismatched_keys += _find_mismatched_keys(
if ignore_mismatched_sizes: state_dict,
for checkpoint_key in loaded_keys: model_state_dict,
model_key = checkpoint_key loaded_keys,
if remove_prefix_from_model: add_prefix_to_model,
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. remove_prefix_from_model,
model_key = f"{prefix}.{checkpoint_key}" ignore_mismatched_sizes,
elif add_prefix_to_model: )
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
if len(error_msgs) > 0: if len(error_msgs) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment