Commit 00a09ed4 authored by Arthur Zucker's avatar Arthur Zucker
Browse files

fix 😭

parent 8e9a2207
...@@ -600,7 +600,9 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): ...@@ -600,7 +600,9 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
# so we need to apply the function recursively. # so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix=""): def load(module: nn.Module, state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) unexpected_keys = []
missing_keys = []
args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this # Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict # state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0: if len([key for key in state_dict if key.startswith(prefix)]) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment