Commit 851ef592 authored by Rémi Louf's avatar Rémi Louf
Browse files

add comment on recursive weights loading

parent 770b15b5
......@@ -383,6 +383,8 @@ class PreTrainedModel(nn.Module):
if metadata is not None:
state_dict._metadata = metadata
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment