Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
af0e4b7b
Unverified
Commit
af0e4b7b
authored
Jul 24, 2024
by
Marc Sun
Committed by
GitHub
Jul 24, 2024
Browse files
Fix float8_e4m3fn in modeling_utils (#32193)
* Fix float8_e4m3fn in modeling_utils * style * fix * comment
parent
1392a686
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
2 deletions
+5
-2
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+5
-2
No files found.
src/transformers/modeling_utils.py
View file @
af0e4b7b
...
@@ -855,6 +855,8 @@ def _load_state_dict_into_meta_model(
...
@@ -855,6 +855,8 @@ def _load_state_dict_into_meta_model(
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
is_torch_e4m3fn_available
=
hasattr
(
torch
,
"float8_e4m3fn"
)
for
param_name
,
param
in
state_dict
.
items
():
for
param_name
,
param
in
state_dict
.
items
():
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if
param_name
not
in
loaded_state_dict_keys
or
param_name
not
in
expected_keys
:
if
param_name
not
in
loaded_state_dict_keys
or
param_name
not
in
expected_keys
:
...
@@ -866,9 +868,10 @@ def _load_state_dict_into_meta_model(
...
@@ -866,9 +868,10 @@ def _load_state_dict_into_meta_model(
module_name
=
param_name
module_name
=
param_name
set_module_kwargs
=
{}
set_module_kwargs
=
{}
# We convert floating dtypes to the `dtype` passed
. We
want to keep the buffers/params
# We convert floating dtypes to the `dtype` passed
except for float8_e4m3fn type. We also
want to keep the buffers/params
# in int/uint/bool and not cast them.
# in int/uint/bool and not cast them.
if
dtype
is
not
None
and
torch
.
is_floating_point
(
param
)
and
param
.
dtype
!=
torch
.
float8_e4m3fn
:
is_param_float8_e4m3fn
=
is_torch_e4m3fn_available
and
param
.
dtype
==
torch
.
float8_e4m3fn
if
dtype
is
not
None
and
torch
.
is_floating_point
(
param
)
and
not
is_param_float8_e4m3fn
:
if
(
if
(
keep_in_fp32_modules
is
not
None
keep_in_fp32_modules
is
not
None
and
any
(
and
any
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment