Unverified Commit 5ed761a6 authored by hlky's avatar hlky Committed by GitHub
Browse files

Add ControlNetUnion to AutoPipeline from_pretrained (#10219)

parent 2f023d7b
...@@ -18,6 +18,7 @@ from collections import OrderedDict ...@@ -18,6 +18,7 @@ from collections import OrderedDict
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..models.controlnets import ControlNetUnionModel
from ..utils import is_sentencepiece_available from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline from .aura_flow import AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline from .cogview3 import CogView3PlusPipeline
...@@ -28,6 +29,9 @@ from .controlnet import ( ...@@ -28,6 +29,9 @@ from .controlnet import (
StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetPipeline,
StableDiffusionXLControlNetUnionImg2ImgPipeline,
StableDiffusionXLControlNetUnionInpaintPipeline,
StableDiffusionXLControlNetUnionPipeline,
) )
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import ( from .flux import (
...@@ -108,6 +112,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( ...@@ -108,6 +112,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("kandinsky3", Kandinsky3Pipeline), ("kandinsky3", Kandinsky3Pipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
("wuerstchen", WuerstchenCombinedPipeline), ("wuerstchen", WuerstchenCombinedPipeline),
("cascade", StableCascadeCombinedPipeline), ("cascade", StableCascadeCombinedPipeline),
("lcm", LatentConsistencyModelPipeline), ("lcm", LatentConsistencyModelPipeline),
...@@ -139,6 +144,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( ...@@ -139,6 +144,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline), ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline),
...@@ -158,6 +164,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict( ...@@ -158,6 +164,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline), ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline), ("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline), ("flux-controlnet", FluxControlNetInpaintPipeline),
...@@ -396,6 +403,9 @@ class AutoPipelineForText2Image(ConfigMixin): ...@@ -396,6 +403,9 @@ class AutoPipelineForText2Image(ConfigMixin):
orig_class_name = config["_class_name"] orig_class_name = config["_class_name"]
if "controlnet" in kwargs: if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
else:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
if "enable_pag" in kwargs: if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag") enable_pag = kwargs.pop("enable_pag")
...@@ -688,6 +698,9 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -688,6 +698,9 @@ class AutoPipelineForImage2Image(ConfigMixin):
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
if "controlnet" in kwargs: if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
else:
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs: if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag") enable_pag = kwargs.pop("enable_pag")
...@@ -985,6 +998,9 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -985,6 +998,9 @@ class AutoPipelineForInpainting(ConfigMixin):
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
if "controlnet" in kwargs: if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
else:
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs: if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag") enable_pag = kwargs.pop("enable_pag")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment