Unverified Commit 8cdcdd9e authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

add flux inpaint + img2img + controlnet to auto pipeline (#9367)

parent d269cc8a
...@@ -29,7 +29,7 @@ from .controlnet import ( ...@@ -29,7 +29,7 @@ from .controlnet import (
StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetPipeline,
) )
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import FluxPipeline from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline
from .hunyuandit import HunyuanDiTPipeline from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import ( from .kandinsky import (
KandinskyCombinedPipeline, KandinskyCombinedPipeline,
...@@ -108,6 +108,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( ...@@ -108,6 +108,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("pixart-sigma-pag", PixArtSigmaPAGPipeline), ("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline), ("auraflow", AuraFlowPipeline),
("flux", FluxPipeline), ("flux", FluxPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline), ("lumina", LuminaText2ImgPipeline),
] ]
) )
...@@ -126,6 +127,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( ...@@ -126,6 +127,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
] ]
) )
...@@ -140,6 +142,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict( ...@@ -140,6 +142,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
] ]
) )
...@@ -660,12 +663,17 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -660,12 +663,17 @@ class AutoPipelineForImage2Image(ConfigMixin):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"] orig_class_name = config["_class_name"]
# the `orig_class_name` can be:
# `- *Pipeline` (for regular text-to-image checkpoint)
# `- *Img2ImgPipeline` (for refiner checkpoint)
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
if "controlnet" in kwargs: if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs: if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag") enable_pag = kwargs.pop("enable_pag")
if enable_pag: if enable_pag:
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
...@@ -952,14 +960,17 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -952,14 +960,17 @@ class AutoPipelineForInpainting(ConfigMixin):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"] orig_class_name = config["_class_name"]
# The `orig_class_name`` can be:
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
# - or *Pipeline (for regular text-to-image checkpoint)
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
if "controlnet" in kwargs: if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs: if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag") enable_pag = kwargs.pop("enable_pag")
if enable_pag: if enable_pag:
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace)
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs} kwargs = {**load_config_kwargs, **kwargs}
......
...@@ -235,9 +235,32 @@ class AutoPipelineFastTest(unittest.TestCase): ...@@ -235,9 +235,32 @@ class AutoPipelineFastTest(unittest.TestCase):
pipe = AutoPipelineForImage2Image.from_pretrained(repo) pipe = AutoPipelineForImage2Image.from_pretrained(repo)
assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline" assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline"
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet)
assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline"
pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True)
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline"
pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True)
assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
def test_from_pretrained_img2img_refiner(self):
repo = "hf-internal-testing/tiny-stable-diffusion-xl-refiner-pipe"
pipe = AutoPipelineForImage2Image.from_pretrained(repo)
assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline"
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet)
assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline"
pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True) pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True)
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline" assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline"
pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True)
assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
def test_from_pipe_pag_img2img(self): def test_from_pipe_pag_img2img(self):
# test from tableDiffusionXLPAGImg2ImgPipeline # test from tableDiffusionXLPAGImg2ImgPipeline
pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
...@@ -265,6 +288,16 @@ class AutoPipelineFastTest(unittest.TestCase): ...@@ -265,6 +288,16 @@ class AutoPipelineFastTest(unittest.TestCase):
pipe_pag = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True) pipe_pag = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True)
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline"
def test_from_pretrained_inpaint_from_inpaint(self):
repo = "hf-internal-testing/tiny-stable-diffusion-xl-inpaint-pipe"
pipe = AutoPipelineForInpainting.from_pretrained(repo)
assert pipe.__class__.__name__ == "StableDiffusionXLInpaintPipeline"
# make sure you can use pag with inpaint-specific pipeline
pipe = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True)
assert pipe.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline"
def test_from_pipe_pag_inpaint(self): def test_from_pipe_pag_inpaint(self):
# test from tableDiffusionXLPAGInpaintPipeline # test from tableDiffusionXLPAGInpaintPipeline
pipe = AutoPipelineForInpainting.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") pipe = AutoPipelineForInpainting.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment