Commit 11b13e94 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Add type to help my IDE out

parent 1ce3fb5c
...@@ -512,7 +512,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -512,7 +512,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively. # so we need to apply the function recursively.
def load(module, prefix=""): def load(module: nn.Module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict( module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment