"docs/vscode:/vscode.git/clone" did not exist on "30ed3adf474aaf2972ab56f5624089bc24a6adf3"
Unverified commit 7586a1a3 authored by Sylvain Gugger, committed by GitHub

Fix dtype of weights in from_pretrained when device_map is set (#20602)

parent bf9a5882
@@ -593,13 +593,22 @@ def _load_state_dict_into_meta_model(
         module_name = param_name
 
-        # We convert floating dtypes to the `dtype` passed.We want to keep the buffers/params
+        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
         # in int/uint/bool and not cast them.
         if dtype is not None and torch.is_floating_point(param):
             param = param.to(dtype)
 
         # For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
         if is_safetensors and dtype is None and torch.is_floating_point(param):
             param = param.to(torch.float32)
 
+        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
+        if dtype is None:
+            old_param = model
+            splits = param_name.split(".")
+            for split in splits:
+                old_param = getattr(old_param, split)
+                if old_param is None:
+                    break
+
+            if old_param is not None:
+                param = param.to(old_param.dtype)
+
         if device_map is None:
             param_device = "cpu"
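The added branch only runs when the caller did not pass an explicit `dtype`: it walks the dotted parameter name down to the parameter already materialized in the model and casts the incoming checkpoint tensor to that parameter's dtype, which is what `nn.Module.load_state_dict` does implicitly. Below is a minimal, self-contained sketch of that walk in isolation, not the transformers implementation itself; `TinyModel` and `param_name` are illustrative stand-ins, and the sketch passes a `None` default to `getattr` where the committed code relies on the attribute chain existing.

# Sketch of the dtype-matching logic from the diff above (assumptions noted
# in the surrounding text). A fp32 checkpoint tensor is cast to fp16 because
# the corresponding model parameter is fp16, mirroring load_state_dict.
import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2).to(torch.float16)  # model weights live in fp16


model = TinyModel()
param_name = "linear.weight"
param = torch.zeros(2, 2, dtype=torch.float32)  # checkpoint tensor arrives in fp32

# Walk the dotted parameter name down to the existing parameter.
old_param = model
for split in param_name.split("."):
    old_param = getattr(old_param, split, None)
    if old_param is None:
        break

# Cast the incoming tensor to the dtype the model already uses.
if old_param is not None:
    param = param.to(old_param.dtype)

print(param.dtype)  # torch.float16, matching the model instead of staying fp32

This is exactly the behavior that was lost when `device_map` is set, since that loading path goes through `_load_state_dict_into_meta_model` rather than PyTorch's `load_state_dict`.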