"vscode:/vscode.git/clone" did not exist on "eab4d8232d558f2e6bd7f7cc3d00e2e6e94f4e80"
Unverified Commit 6d6a08f1 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax->PT] Fix flaky testing (#5011)

fix flaky flax class name
parent 34bfe98e
...@@ -343,9 +343,7 @@ def _get_pipeline_class( ...@@ -343,9 +343,7 @@ def _get_pipeline_class(
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
class_name = config["_class_name"] class_name = config["_class_name"]
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
if class_name.startswith("Flax"):
class_name = class_name[4:]
pipeline_cls = getattr(diffusers_module, class_name) pipeline_cls = getattr(diffusers_module, class_name)
...@@ -1083,8 +1081,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1083,8 +1081,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# 6. Load each module in the pipeline # 6. Load each module in the pipeline
for name, (library_name, class_name) in tqdm(init_dict.items(), desc="Loading pipeline components..."): for name, (library_name, class_name) in tqdm(init_dict.items(), desc="Loading pipeline components..."):
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"): class_name = class_name[4:] if class_name.startswith("Flax") else class_name
class_name = class_name[4:]
# 6.2 Define all importable classes # 6.2 Define all importable classes
is_pipeline_module = hasattr(pipelines, library_name) is_pipeline_module = hasattr(pipelines, library_name)
...@@ -1611,6 +1608,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1611,6 +1608,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# retrieve pipeline class from local file # retrieve pipeline class from local file
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None) cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
cls_name = cls_name[4:] if cls_name.startswith("Flax") else cls_name
pipeline_class = getattr(diffusers, cls_name, None) pipeline_class = getattr(diffusers, cls_name, None)
if pipeline_class is not None and pipeline_class._load_connected_pipes: if pipeline_class is not None and pipeline_class._load_connected_pipes:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment