Unverified Commit ab62a23d authored by Younes Belkada, committed by GitHub

Let's not cast them all (#18471)



* add correct dtypes when checking for params dtype

* forward contrib credits

* Update src/transformers/modeling_utils.py
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>

* more comments

- added more comments on why we cast only floating point parameters

* Update src/transformers/modeling_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: sgugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
parent 499450ed
@@ -543,8 +543,10 @@ def _load_state_dict_into_meta_model(
         param_name = param_name[len(start_prefix) :]
         module_name = param_name
-        # We convert floating dtypes to the `dtype` passed.
-        if dtype is not None and not str(param.dtype).startswith("torch.int"):
+        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
+        # in int/uint/bool and not cast them.
+        if dtype is not None and torch.is_floating_point(param):
             param = param.to(dtype)
         if device_map is None:
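For reference, a minimal standalone sketch (the tensor names below are made up for illustration, not taken from the library) of why the old string check was too permissive: dtypes such as torch.bool and torch.uint8 do not start with "torch.int", so the previous condition would also have cast them to the target float dtype, whereas torch.is_floating_point only matches floating-point tensors.

import torch

# Hypothetical state-dict entries; only the floating-point tensor should be cast.
dtype = torch.float16
params = {
    "weight": torch.randn(2, 2, dtype=torch.float32),    # floating point -> cast
    "position_ids": torch.arange(4, dtype=torch.int64),  # int buffer -> keep
    "mask": torch.zeros(4, dtype=torch.bool),             # bool buffer -> keep
    "quant_weight": torch.zeros(4, dtype=torch.uint8),    # uint8 weight -> keep
}

for name, param in params.items():
    old_would_cast = not str(param.dtype).startswith("torch.int")  # also True for bool/uint8
    new_would_cast = torch.is_floating_point(param)                # True only for float dtypes
    print(f"{name:13s} {str(param.dtype):13s} old={old_would_cast} new={new_would_cast}")

# Applying the cast as in the patched check: non-float tensors are left untouched.
casted = {n: p.to(dtype) if torch.is_floating_point(p) else p for n, p in params.items()}
assert casted["weight"].dtype == torch.float16
assert casted["position_ids"].dtype == torch.int64
assert casted["mask"].dtype == torch.bool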