"src/diffusers/models/consistency_decoder_vae.py" did not exist on "43346adc1ffa9051fc71be9af33fd982ee14c383"
Unverified Commit 6bbee104 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Make sure Flax pipelines can be loaded into PyTorch (#4971)

* Make sure Flax pipelines can be loaded into PyTorch

* add test

* Update src/diffusers/pipelines/pipeline_utils.py
parent 2c60f7d1
......@@ -342,7 +342,12 @@ def _get_pipeline_class(
return class_obj
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
pipeline_cls = getattr(diffusers_module, config["_class_name"])
class_name = config["_class_name"]
if class_name.startswith("Flax"):
class_name = class_name[4:]
pipeline_cls = getattr(diffusers_module, class_name)
if load_connected_pipeline:
from .auto_pipeline import _get_connected_pipeline
......
......@@ -57,7 +57,7 @@ from diffusers import (
UniPCMultistepScheduler,
logging,
)
from diffusers.pipelines.pipeline_utils import variant_compatible_siblings
from diffusers.pipelines.pipeline_utils import _get_pipeline_class, variant_compatible_siblings
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import (
CONFIG_NAME,
......@@ -805,6 +805,14 @@ class DownloadTests(unittest.TestCase):
assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files)
assert len(files) == 14
def test_get_pipeline_class_from_flax(self):
flax_config = {"_class_name": "FlaxStableDiffusionPipeline"}
config = {"_class_name": "StableDiffusionPipeline"}
# when loading a PyTorch Pipeline from a FlaxPipeline `model_index.json`, e.g.: https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-lms-pipe/blob/7a9063578b325779f0f1967874a6771caa973cad/model_index.json#L2
# we need to make sure that we don't load the Flax Pipeline class, but instead the PyTorch pipeline class
assert _get_pipeline_class(DiffusionPipeline, flax_config) == _get_pipeline_class(DiffusionPipeline, config)
class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment