"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4f1b31c2ee2822618d8433a71627ec18e9f2e2d3"
Unverified Commit 7586a1a3 authored by Sylvain Gugger, committed by GitHub

Fix dtype of weights in from_pretrained when device_map is set (#20602)

parent bf9a5882
@@ -593,13 +593,22 @@ def _load_state_dict_into_meta_model(
         module_name = param_name
 
-        # We convert floating dtypes to the `dtype` passed.We want to keep the buffers/params
+        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
         # in int/uint/bool and not cast them.
         if dtype is not None and torch.is_floating_point(param):
             param = param.to(dtype)
-        # For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
-        if is_safetensors and dtype is None and torch.is_floating_point(param):
-            param = param.to(torch.float32)
+
+        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
+        if dtype is None:
+            old_param = model
+            splits = param_name.split(".")
+            for split in splits:
+                old_param = getattr(old_param, split)
+                if old_param is None:
+                    break
+
+            if old_param is not None:
+                param = param.to(old_param.dtype)
 
         if device_map is None:
             param_device = "cpu"
...
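The added branch walks the dotted parameter name through the already-instantiated model to find the existing parameter and reuses its dtype when no explicit dtype was passed, instead of unconditionally upcasting safetensors weights to float32. Below is a minimal, self-contained sketch of that lookup, not the transformers implementation itself: the helper name resolve_param_dtype and the toy nn.Linear model are assumptions made purely for illustration.

import torch
from torch import nn


def resolve_param_dtype(model: nn.Module, param_name: str, param: torch.Tensor, dtype=None) -> torch.Tensor:
    # Hypothetical helper mirroring the branch added to _load_state_dict_into_meta_model.
    if dtype is not None and torch.is_floating_point(param):
        # An explicit dtype is applied to floating-point tensors only,
        # so int/uint/bool buffers keep their types.
        return param.to(dtype)
    if dtype is None:
        # Walk e.g. "encoder.layer.0.weight" attribute by attribute to reach
        # the parameter that already exists in the model.
        old_param = model
        for split in param_name.split("."):
            old_param = getattr(old_param, split, None)
            if old_param is None:
                break
        if old_param is not None:
            # Keep whatever dtype the model was instantiated with (e.g. float16).
            param = param.to(old_param.dtype)
    return param


# Toy usage: a float16 model receiving a float32 checkpoint tensor.
model = nn.Linear(4, 4).to(torch.float16)
loaded = torch.randn(4, 4)  # defaults to float32
print(resolve_param_dtype(model, "weight", loaded).dtype)  # torch.float16

Run against a float16 model, a float32 checkpoint tensor comes back as float16, which matches the commit title: the dtype of weights loaded in from_pretrained now follows the model when device_map is set and no dtype is given.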