Unverified Commit 83d2d745 authored by Nouamane Tazi's avatar Nouamane Tazi Committed by GitHub
Browse files

fix loading from pretrained for sharded model with `torch_dtype="auto" (#18061)

parent 7996ef74
...@@ -2073,7 +2073,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2073,7 +2073,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif not is_sharded: elif not is_sharded:
torch_dtype = get_state_dict_dtype(state_dict) torch_dtype = get_state_dict_dtype(state_dict)
else: else:
one_state_dict = load_state_dict(resolved_archive_file) one_state_dict = load_state_dict(resolved_archive_file[0])
torch_dtype = get_state_dict_dtype(one_state_dict) torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory del one_state_dict # free CPU memory
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment