Unverified Commit 8cdcdd9e authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

add flux inpaint + img2img + controlnet to auto pipeline (#9367)

parent d269cc8a
......@@ -29,7 +29,7 @@ from .controlnet import (
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import FluxPipeline
from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
......@@ -108,6 +108,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline),
("flux", FluxPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
]
)
......@@ -126,6 +127,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
]
)
......@@ -140,6 +142,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
]
)
......@@ -660,12 +663,17 @@ class AutoPipelineForImage2Image(ConfigMixin):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
# the `orig_class_name` can be:
# `- *Pipeline` (for regular text-to-image checkpoint)
# `- *Img2ImgPipeline` (for refiner checkpoint)
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
......@@ -952,14 +960,17 @@ class AutoPipelineForInpainting(ConfigMixin):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
# The `orig_class_name`` can be:
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
# - or *Pipeline (for regular text-to-image checkpoint)
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace)
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
......
......@@ -235,9 +235,32 @@ class AutoPipelineFastTest(unittest.TestCase):
pipe = AutoPipelineForImage2Image.from_pretrained(repo)
assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline"
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet)
assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline"
pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True)
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline"
pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True)
assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
def test_from_pretrained_img2img_refiner(self):
repo = "hf-internal-testing/tiny-stable-diffusion-xl-refiner-pipe"
pipe = AutoPipelineForImage2Image.from_pretrained(repo)
assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline"
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet)
assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline"
pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True)
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline"
pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True)
assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
def test_from_pipe_pag_img2img(self):
# test from tableDiffusionXLPAGImg2ImgPipeline
pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
......@@ -265,6 +288,16 @@ class AutoPipelineFastTest(unittest.TestCase):
pipe_pag = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True)
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline"
def test_from_pretrained_inpaint_from_inpaint(self):
repo = "hf-internal-testing/tiny-stable-diffusion-xl-inpaint-pipe"
pipe = AutoPipelineForInpainting.from_pretrained(repo)
assert pipe.__class__.__name__ == "StableDiffusionXLInpaintPipeline"
# make sure you can use pag with inpaint-specific pipeline
pipe = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True)
assert pipe.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline"
def test_from_pipe_pag_inpaint(self):
# test from tableDiffusionXLPAGInpaintPipeline
pipe = AutoPipelineForInpainting.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment