Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
4b8880a3
Unverified
Commit
4b8880a3
authored
Sep 22, 2022
by
Mishig Davaadorj
Committed by
GitHub
Sep 22, 2022
Browse files
Make flax from_pretrained work with local subfolder (#608)
parent
dd350c8a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
9 deletions
+14
-9
src/diffusers/modeling_flax_utils.py
src/diffusers/modeling_flax_utils.py
+14
-9
No files found.
src/diffusers/modeling_flax_utils.py
View file @
4b8880a3
...
@@ -310,26 +310,31 @@ class FlaxModelMixin:
...
@@ -310,26 +310,31 @@ class FlaxModelMixin:
)
)
# Load model
# Load model
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
pretrained_path_with_subfolder
=
(
pretrained_model_name_or_path
if
subfolder
is
None
else
os
.
path
.
join
(
pretrained_model_name_or_path
,
subfolder
)
)
if
os
.
path
.
isdir
(
pretrained_path_with_subfolder
):
if
from_pt
:
if
from_pt
:
if
not
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_
model_name_or_path
,
WEIGHTS_NAME
)):
if
not
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_
path_with_subfolder
,
WEIGHTS_NAME
)):
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"Error no file named
{
WEIGHTS_NAME
}
found in directory
{
pretrained_
model_name_or_path
}
"
f
"Error no file named
{
WEIGHTS_NAME
}
found in directory
{
pretrained_
path_with_subfolder
}
"
)
)
model_file
=
os
.
path
.
join
(
pretrained_
model_name_or_path
,
WEIGHTS_NAME
)
model_file
=
os
.
path
.
join
(
pretrained_
path_with_subfolder
,
WEIGHTS_NAME
)
elif
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_
model_name_or_path
,
FLAX_WEIGHTS_NAME
)):
elif
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_
path_with_subfolder
,
FLAX_WEIGHTS_NAME
)):
# Load from a Flax checkpoint
# Load from a Flax checkpoint
model_file
=
os
.
path
.
join
(
pretrained_
model_name_or_path
,
FLAX_WEIGHTS_NAME
)
model_file
=
os
.
path
.
join
(
pretrained_
path_with_subfolder
,
FLAX_WEIGHTS_NAME
)
# Check if pytorch weights exist instead
# Check if pytorch weights exist instead
elif
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_
model_name_or_path
,
WEIGHTS_NAME
)):
elif
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_
path_with_subfolder
,
WEIGHTS_NAME
)):
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"
{
WEIGHTS_NAME
}
file found in directory
{
pretrained_
model_name_or_path
}
. Please load the model"
f
"
{
WEIGHTS_NAME
}
file found in directory
{
pretrained_
path_with_subfolder
}
. Please load the model"
" using `from_pt=True`."
" using `from_pt=True`."
)
)
else
:
else
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"Error no file named
{
FLAX_WEIGHTS_NAME
}
or
{
WEIGHTS_NAME
}
found in directory "
f
"Error no file named
{
FLAX_WEIGHTS_NAME
}
or
{
WEIGHTS_NAME
}
found in directory "
f
"
{
pretrained_
model_name_or_path
}
."
f
"
{
pretrained_
path_with_subfolder
}
."
)
)
else
:
else
:
try
:
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment