renzhc / diffusers_dcu · Commits · dec18c86

Unverified commit dec18c86, authored Oct 21, 2022 by Suraj Patil, committed by GitHub on Oct 21, 2022.

[Flax] dont warn for bf16 weights (#923)

Parent: 25dfd0f8
Changes: 1 changed file with 0 additions and 23 deletions (+0 −23)
File: src/diffusers/modeling_flax_utils.py
```diff
--- a/src/diffusers/modeling_flax_utils.py
+++ b/src/diffusers/modeling_flax_utils.py
@@ -482,29 +482,6 @@ class FlaxModelMixin:
                 " training."
             )
 
-        # dictionary of key: dtypes for the model params
-        param_dtypes = jax.tree_map(lambda x: x.dtype, state)
-        # extract keys of parameters not in jnp.float32
-        fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
-        bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
-
-        # raise a warning if any of the parameters are not in jnp.float32
-        if len(fp16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
-        if len(bf16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
         return model, unflatten_dict(state)
 
     def save_pretrained(
```
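The dtype check that this commit deletes is a generic JAX pattern: map `.dtype` over a parameter tree, then filter the keys. A minimal, self-contained sketch of that pattern, assuming a toy flat `state` dict as a stand-in for the real checkpoint state (it uses `jax.tree_util.tree_map`, which mirrors the `jax.tree_map` alias in the removed code):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for `state`: a flat dict of param-name -> array.
# The real dict comes from the loaded checkpoint; these names are illustrative.
state = {
    "unet/kernel": jnp.zeros((2, 2), dtype=jnp.bfloat16),
    "clip/bias": jnp.zeros((2,), dtype=jnp.float32),
}

# Same pattern as the removed lines: map dtype over every leaf...
param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)

# ...then collect the keys whose parameters are not in float32.
bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
print(bf16_params)  # -> ['unet/kernel']
```

Dropping the bf16 branch of this check makes sense for Flax: loading weights in bfloat16 is the normal, intended configuration on TPU, so warning on it was noise rather than a signal of a mistake.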