"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5164ea91a7b4d35cb03867233527fa383a651775"
Unverified Commit 344e2664 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix dtype in radnomly initialized head (#19690)

parent 07f66902
......@@ -2446,9 +2446,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
param = model_state_dict[key]
if param.device == torch.device("meta"):
if not load_in_8bit:
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
else:
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment