Unverified Commit 83d2d745 authored by Nouamane Tazi's avatar Nouamane Tazi Committed by GitHub
Browse files

fix loading from pretrained for sharded model with `torch_dtype="auto" (#18061)

parent 7996ef74
......@@ -2073,7 +2073,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif not is_sharded:
torch_dtype = get_state_dict_dtype(state_dict)
else:
one_state_dict = load_state_dict(resolved_archive_file)
one_state_dict = load_state_dict(resolved_archive_file[0])
torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment