Unverified Commit 6bbee104 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Make sure Flax pipelines can be loaded into PyTorch (#4971)

* Make sure Flax pipelines can be loaded into PyTorch

* add test

* Update src/diffusers/pipelines/pipeline_utils.py
parent 2c60f7d1
...@@ -342,7 +342,12 @@ def _get_pipeline_class( ...@@ -342,7 +342,12 @@ def _get_pipeline_class(
return class_obj return class_obj
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
pipeline_cls = getattr(diffusers_module, config["_class_name"]) class_name = config["_class_name"]
if class_name.startswith("Flax"):
class_name = class_name[4:]
pipeline_cls = getattr(diffusers_module, class_name)
if load_connected_pipeline: if load_connected_pipeline:
from .auto_pipeline import _get_connected_pipeline from .auto_pipeline import _get_connected_pipeline
......
...@@ -57,7 +57,7 @@ from diffusers import ( ...@@ -57,7 +57,7 @@ from diffusers import (
UniPCMultistepScheduler, UniPCMultistepScheduler,
logging, logging,
) )
from diffusers.pipelines.pipeline_utils import variant_compatible_siblings from diffusers.pipelines.pipeline_utils import _get_pipeline_class, variant_compatible_siblings
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import ( from diffusers.utils import (
CONFIG_NAME, CONFIG_NAME,
...@@ -805,6 +805,14 @@ class DownloadTests(unittest.TestCase): ...@@ -805,6 +805,14 @@ class DownloadTests(unittest.TestCase):
assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files) assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files)
assert len(files) == 14 assert len(files) == 14
def test_get_pipeline_class_from_flax(self):
    """Resolving a Flax-authored `model_index.json` must yield the PyTorch pipeline class."""
    # A repo exported from Flax stores a "Flax"-prefixed class name, e.g.:
    # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-lms-pipe/blob/7a9063578b325779f0f1967874a6771caa973cad/model_index.json#L2
    # When such a config is loaded into PyTorch, `_get_pipeline_class` should
    # strip the prefix and resolve the same class as a native PyTorch config.
    config_from_flax = {"_class_name": "FlaxStableDiffusionPipeline"}
    config_from_pt = {"_class_name": "StableDiffusionPipeline"}
    resolved_from_flax = _get_pipeline_class(DiffusionPipeline, config_from_flax)
    resolved_from_pt = _get_pipeline_class(DiffusionPipeline, config_from_pt)
    assert resolved_from_flax == resolved_from_pt
class CustomPipelineTests(unittest.TestCase): class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self): def test_load_custom_pipeline(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment