"src/vscode:/vscode.git/clone" did not exist on "2de9e2df368241cf13f859cf51514cea4e53aed5"
Unverified Commit c9c5436c authored by Lukas Kuhn's avatar Lukas Kuhn Committed by GitHub
Browse files

download_from_original_stable_diffusion_ckpt initializes correct default pipeline for SDXL (#5784)

* feat: sdxl will be automatically detected as pipeline_class

* fix: formatting

* fix: formatting with black

* fix: import pipeline wrongly sorted
parent 9c8eca70
......@@ -1232,13 +1232,11 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)
if pipeline_class is None:
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
if prediction_type == "v-prediction":
prediction_type = "v_prediction"
......@@ -1333,6 +1331,13 @@ def download_from_original_stable_diffusion_ckpt(
if image_size is None:
image_size = 1024
if pipeline_class is None:
# Check if we have a SDXL or SD model and initialize default pipeline
if model_type not in ["SDXL", "SDXL-Refiner"]:
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
else:
pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
num_in_channels = 9
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment