Unverified Commit c9c5436c authored by Lukas Kuhn's avatar Lukas Kuhn Committed by GitHub
Browse files

download_from_original_stable_diffusion_ckpt initializes correct default pipeline for SDXL (#5784)

* feat: sdxl will be automatically detected as pipeline_class

* fix: formatting

* fix: formatting with black

* fix: import pipeline wrongly sorted
parent 9c8eca70
...@@ -1232,13 +1232,11 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1232,13 +1232,11 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionUpscalePipeline, StableDiffusionUpscalePipeline,
StableDiffusionXLImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline, StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline, StableUnCLIPPipeline,
) )
if pipeline_class is None:
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
if prediction_type == "v-prediction": if prediction_type == "v-prediction":
prediction_type = "v_prediction" prediction_type = "v_prediction"
...@@ -1333,6 +1331,13 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1333,6 +1331,13 @@ def download_from_original_stable_diffusion_ckpt(
if image_size is None: if image_size is None:
image_size = 1024 image_size = 1024
if pipeline_class is None:
# Check if we have a SDXL or SD model and initialize default pipeline
if model_type not in ["SDXL", "SDXL-Refiner"]:
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
else:
pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
num_in_channels = 9 num_in_channels = 9
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline: if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment