Unverified Commit accad48e authored by Younes Belkada, committed by GitHub
Browse files

[ `T5`] fix fp16 loading issue (#20878)

* fix fp16 loading issue

* add backward compatibility

* better refactor

* better readability

- remove `force_upcast_dtype` as it is used once
- use `inspect`
- add `TODO`
parent 47146721
......@@ -609,6 +609,7 @@ def _load_state_dict_into_meta_model(
param_name = param_name[len(start_prefix) :]
module_name = param_name
set_module_kwargs = {}
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# in int/uint/bool and not cast them.
......@@ -619,6 +620,11 @@ def _load_state_dict_into_meta_model(
and dtype == torch.float16
):
param = param.to(torch.float32)
# For backward compatibility with older versions of `accelerate`
# TODO: @sgugger replace this check with version check at the next `accelerate` release
if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters):
set_module_kwargs["dtype"] = torch.float32
else:
param = param.to(dtype)
......@@ -634,6 +640,8 @@ def _load_state_dict_into_meta_model(
if old_param is not None:
param = param.to(old_param.dtype)
set_module_kwargs["value"] = param
if device_map is None:
param_device = "cpu"
else:
......@@ -651,7 +659,8 @@ def _load_state_dict_into_meta_model(
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif not load_in_8bit:
set_module_tensor_to_device(model, param_name, param_device, value=param)
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
set_module_8bit_tensor_to_device(model, param_name, param_device, value=param)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment