Unverified commit 4099817b, authored by Francesco Saverio Zuppichini and committed by GitHub
Browse files

Updated _load_pretrained_model_low_mem to check if keys are in the state_dict (#16643)

* Updated _load_pretrained_model_low_mem to check if keys are in the stored state_dict

* update after conversions
parent 389f6615
......@@ -2162,13 +2162,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
param_dtype = getattr(submodule, param_name).dtype
new_val = state_dict[k].to(param_dtype)
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)
if k in state_dict:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
param_dtype = getattr(submodule, param_name).dtype
new_val = state_dict[k].to(param_dtype)
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)
del state_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment