Commit 7586a1a3 (unverified), authored Dec 06, 2022 by Sylvain Gugger, committed by GitHub Dec 06, 2022
Fix dtype of weights in from_pretrained when device_map is set (#20602)
parent bf9a5882
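A minimal sketch of the scenario the title describes, from the user side. The checkpoint name, kwargs, and expected outcomes below are illustrative assumptions, not taken from the commit; device_map loading also assumes accelerate is installed.

    # Hypothetical check, not part of the commit: inspect parameter dtypes after
    # loading with a device_map, with and without an explicit torch_dtype.
    import torch
    from transformers import AutoModelForCausalLM

    # Explicit dtype requested: floating-point weights should come out as float16.
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2", device_map="auto", torch_dtype=torch.float16
    )
    print({p.dtype for p in model.parameters()})

    # No dtype requested: weights should keep the dtype the model already defines,
    # rather than being cast differently just because a device_map is used.
    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
    print({p.dtype for p in model.parameters()})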
1 changed file with 13 additions and 4 deletions:

src/transformers/modeling_utils.py (+13, -4)
@@ -593,13 +593,22 @@ def _load_state_dict_into_meta_model(
        module_name = param_name

        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
        # in int/uint/bool and not cast them.
        if dtype is not None and torch.is_floating_point(param):
            param = param.to(dtype)

        # For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
        if is_safetensors and dtype is None and torch.is_floating_point(param):
            param = param.to(torch.float32)

        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
        if dtype is None:
            old_param = model
            splits = param_name.split(".")
            for split in splits:
                old_param = getattr(old_param, split)
                if old_param is None:
                    break
            if old_param is not None:
                param = param.to(old_param.dtype)

        if device_map is None:
            param_device = "cpu"
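Read in isolation, the dtype-resolution rule this hunk adds can be sketched as a standalone helper. The function name resolve_dtype and the toy model below are assumptions for illustration, not the library's API.

    # Hypothetical helper mirroring the hunk above: floating tensors follow an explicit
    # `dtype` if one is given; with no `dtype`, a tensor is cast to the dtype of whatever
    # already sits at `param_name` in the model; integer/bool tensors never follow the
    # explicit `dtype`.
    import torch
    import torch.nn as nn

    def resolve_dtype(model: nn.Module, param_name: str, param: torch.Tensor, dtype=None):
        if dtype is not None and torch.is_floating_point(param):
            return param.to(dtype)
        if dtype is None:
            # Walk the dotted name ("block.0.weight") down to the existing param/buffer.
            old_param = model
            for split in param_name.split("."):
                old_param = getattr(old_param, split, None)  # slightly looser than the hunk's getattr
                if old_param is None:
                    break
            if old_param is not None:
                return param.to(old_param.dtype)
        return param

    model = nn.Linear(4, 4).to(torch.float16)
    loaded = torch.zeros(4, 4)  # fp32 tensor as it might come out of a state dict
    print(resolve_dtype(model, "weight", loaded).dtype)                  # torch.float16
    print(resolve_dtype(model, "weight", loaded, torch.bfloat16).dtype)  # torch.bfloat16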