"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ea5dda22904e807437e07cdfa5d3d89f17a590f7"
Unverified Commit 344e2664 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix dtype in randomly initialized head (#19690)

parent 07f66902
...@@ -2446,9 +2446,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2446,9 +2446,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
param = model_state_dict[key] param = model_state_dict[key]
if param.device == torch.device("meta"): if param.device == torch.device("meta"):
if not load_in_8bit: if not load_in_8bit:
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size())) set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
else: else:
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size())) set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights. # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init: if _fast_init:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment