Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
f8100600
".github/git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "4e79c0db2143e565bd176a832c5d3332f7ebd4f5"
Unverified
Commit
f8100600
authored
Sep 21, 2022
by
Mishig Davaadorj
Committed by
GitHub
Sep 21, 2022
Browse files
Fix flax from_pretrained pytorch weight check (#603)
parent
fb2fbab1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
src/diffusers/modeling_flax_utils.py
src/diffusers/modeling_flax_utils.py
+3
-3
No files found.
src/diffusers/modeling_flax_utils.py
View file @
f8100600
...
@@ -307,7 +307,7 @@ class FlaxModelMixin:
...
@@ -307,7 +307,7 @@ class FlaxModelMixin:
# Load model
# Load model
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
from_pt
:
if
from_pt
:
if
not
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
):
if
not
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
)
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"Error no file named
{
WEIGHTS_NAME
}
found in directory
{
pretrained_model_name_or_path
}
"
f
"Error no file named
{
WEIGHTS_NAME
}
found in directory
{
pretrained_model_name_or_path
}
"
)
)
...
@@ -315,8 +315,8 @@ class FlaxModelMixin:
...
@@ -315,8 +315,8 @@ class FlaxModelMixin:
elif
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
FLAX_WEIGHTS_NAME
)):
elif
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
FLAX_WEIGHTS_NAME
)):
# Load from a Flax checkpoint
# Load from a Flax checkpoint
model_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
FLAX_WEIGHTS_NAME
)
model_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
FLAX_WEIGHTS_NAME
)
#
At this stage we don't have a weight file so we will raise an error.
#
Check if pytorch weights exist instead
elif
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
):
elif
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
)
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"
{
WEIGHTS_NAME
}
file found in directory
{
pretrained_model_name_or_path
}
. Please load the model"
f
"
{
WEIGHTS_NAME
}
file found in directory
{
pretrained_model_name_or_path
}
. Please load the model"
" using `from_pt=True`."
" using `from_pt=True`."
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment