Unverified Commit 8bf6d28c authored by Francesco Saverio Zuppichini's avatar Francesco Saverio Zuppichini Committed by GitHub
Browse files

made _load_pretrained_model_low_mem static + bug fix (#16548)

parent 02214cb3
...@@ -2103,8 +2103,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2103,8 +2103,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return retrieved_modules return retrieved_modules
@classmethod @staticmethod
def _load_pretrained_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file): def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file):
""" """
This is an experimental function that loads the model using ~1.x model size CPU memory This is an experimental function that loads the model using ~1.x model size CPU memory
...@@ -2159,7 +2159,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2159,7 +2159,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
resolved_archive_file = [resolved_archive_file] resolved_archive_file = [resolved_archive_file]
for archive_file in resolved_archive_file: for archive_file in resolved_archive_file:
state_dict = torch.load(resolved_archive_file, map_location="cpu") state_dict = torch.load(archive_file, map_location="cpu")
# materialize state_dict entries one by one on CPU # materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys: for k in loaded_state_dict_keys:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment