Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
6d6a08f1
Unverified
Commit
6d6a08f1
authored
Sep 13, 2023
by
Patrick von Platen
Committed by
GitHub
Sep 13, 2023
Browse files
[Flax->PT] Fix flaky testing (#5011)
fix flaky flax class name
parent
34bfe98e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
5 deletions
+4
-5
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+4
-5
No files found.
src/diffusers/pipelines/pipeline_utils.py
View file @
6d6a08f1
...
@@ -343,9 +343,7 @@ def _get_pipeline_class(
...
@@ -343,9 +343,7 @@ def _get_pipeline_class(
diffusers_module
=
importlib
.
import_module
(
class_obj
.
__module__
.
split
(
"."
)[
0
])
diffusers_module
=
importlib
.
import_module
(
class_obj
.
__module__
.
split
(
"."
)[
0
])
class_name
=
config
[
"_class_name"
]
class_name
=
config
[
"_class_name"
]
class_name
=
class_name
[
4
:]
if
class_name
.
startswith
(
"Flax"
)
else
class_name
if
class_name
.
startswith
(
"Flax"
):
class_name
=
class_name
[
4
:]
pipeline_cls
=
getattr
(
diffusers_module
,
class_name
)
pipeline_cls
=
getattr
(
diffusers_module
,
class_name
)
...
@@ -1083,8 +1081,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
...
@@ -1083,8 +1081,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# 6. Load each module in the pipeline
# 6. Load each module in the pipeline
for
name
,
(
library_name
,
class_name
)
in
tqdm
(
init_dict
.
items
(),
desc
=
"Loading pipeline components..."
):
for
name
,
(
library_name
,
class_name
)
in
tqdm
(
init_dict
.
items
(),
desc
=
"Loading pipeline components..."
):
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if
class_name
.
startswith
(
"Flax"
):
class_name
=
class_name
[
4
:]
if
class_name
.
startswith
(
"Flax"
)
else
class_name
class_name
=
class_name
[
4
:]
# 6.2 Define all importable classes
# 6.2 Define all importable classes
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
...
@@ -1611,6 +1608,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
...
@@ -1611,6 +1608,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# retrieve pipeline class from local file
# retrieve pipeline class from local file
cls_name
=
cls
.
load_config
(
os
.
path
.
join
(
cached_folder
,
"model_index.json"
)).
get
(
"_class_name"
,
None
)
cls_name
=
cls
.
load_config
(
os
.
path
.
join
(
cached_folder
,
"model_index.json"
)).
get
(
"_class_name"
,
None
)
cls_name
=
cls_name
[
4
:]
if
cls_name
.
startswith
(
"Flax"
)
else
cls_name
pipeline_class
=
getattr
(
diffusers
,
cls_name
,
None
)
pipeline_class
=
getattr
(
diffusers
,
cls_name
,
None
)
if
pipeline_class
is
not
None
and
pipeline_class
.
_load_connected_pipes
:
if
pipeline_class
is
not
None
and
pipeline_class
.
_load_connected_pipes
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment