"vscode:/vscode.git/clone" did not exist on "a98f6a1da012ca7847e4dceb3ffcedfd75a77b08"
Unverified commit accad48e authored by Younes Belkada, committed by GitHub

[`T5`] fix fp16 loading issue (#20878)

* fix fp16 loading issue

* add backward compatibility

* better refactor

* better readability

- remove `force_upcast_dtype` as it is used once
- use `inspect`
- add `TODO`
parent 47146721
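
The commit keeps the relevant weights in fp32 when the rest of the model is loaded in fp16, and it guards the new `dtype` keyword behind a signature check so older `accelerate` releases keep working. Below is a minimal, self-contained sketch of that signature-gating pattern; the `set_module_tensor_to_device` defined here is a hypothetical stand-in that only mimics the shape of accelerate's helper, not the real implementation.

```python
# Minimal sketch of the backward-compatibility pattern used in this commit:
# only forward a keyword argument when the installed helper actually accepts it.
# NOTE: this `set_module_tensor_to_device` is a hypothetical stand-in, not the
# real accelerate function.
import inspect

import torch


def set_module_tensor_to_device(module, tensor_name, device, value=None, dtype=None):
    """Stand-in for accelerate's helper; newer releases accept `dtype`, older ones do not."""
    if dtype is not None:
        value = value.to(dtype)
    setattr(module, tensor_name, torch.nn.Parameter(value.to(device)))


set_module_kwargs = {"value": torch.ones(4, dtype=torch.float16)}

# Ask for an fp32 upcast only if the installed version exposes a `dtype` parameter.
if "dtype" in inspect.signature(set_module_tensor_to_device).parameters:
    set_module_kwargs["dtype"] = torch.float32

model = torch.nn.Module()
set_module_tensor_to_device(model, "weight", "cpu", **set_module_kwargs)
print(model.weight.dtype)  # torch.float32 when `dtype` is supported, torch.float16 otherwise
```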
```diff
@@ -609,6 +609,7 @@ def _load_state_dict_into_meta_model(
             param_name = param_name[len(start_prefix) :]
 
         module_name = param_name
+        set_module_kwargs = {}
 
         # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
         # in int/uint/bool and not cast them.
@@ -619,6 +620,11 @@ def _load_state_dict_into_meta_model(
                 and dtype == torch.float16
             ):
                 param = param.to(torch.float32)
+
+                # For backward compatibility with older versions of `accelerate`
+                # TODO: @sgugger replace this check with version check at the next `accelerate` release
+                if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters):
+                    set_module_kwargs["dtype"] = torch.float32
             else:
                 param = param.to(dtype)
@@ -634,6 +640,8 @@ def _load_state_dict_into_meta_model(
         if old_param is not None:
             param = param.to(old_param.dtype)
 
+        set_module_kwargs["value"] = param
+
         if device_map is None:
             param_device = "cpu"
         else:
@@ -651,7 +659,8 @@ def _load_state_dict_into_meta_model(
         elif param_device == "cpu" and state_dict_index is not None:
             state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
         elif not load_in_8bit:
-            set_module_tensor_to_device(model, param_name, param_device, value=param)
+            # For backward compatibility with older versions of `accelerate`
+            set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
         else:
             set_module_8bit_tensor_to_device(model, param_name, param_device, value=param)
```
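
For the fp16 issue itself, the surrounding code (only partially visible in the second hunk above) upcasts any parameter whose name matches an entry in the model's keep-in-fp32 list back to fp32 while the rest of the checkpoint is loaded in fp16. A rough illustration of that substring test follows; the module name `"wo"` and the parameter names are assumed examples, not necessarily the exact T5 configuration.

```python
# Illustration (assumed names, not the exact transformers logic) of how a
# keep-in-fp32 list selects which weights are upcast during fp16 loading.
import torch

keep_in_fp32_modules = ["wo"]  # modules whose weights are known to overflow in fp16
dtype = torch.float16  # dtype requested for the rest of the model


def target_dtype(param_name: str) -> torch.dtype:
    # Upcast parameters whose name contains any protected module name.
    if dtype == torch.float16 and any(m in param_name for m in keep_in_fp32_modules):
        return torch.float32
    return dtype


print(target_dtype("decoder.block.0.layer.2.DenseReluDense.wo.weight"))  # torch.float32
print(target_dtype("decoder.block.0.layer.2.DenseReluDense.wi.weight"))  # torch.float16
```

Accumulating the call arguments in `set_module_kwargs` and unpacking them once at the call site keeps a single code path for both old and new `accelerate` versions.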