Unverified Commit d8b6f5d0 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

support AutoPipeline.from_pipe between a pipeline and its ControlNet pipeline counterpart (#4861)

add 
parent 30a5acc3
......@@ -366,6 +366,18 @@ class AutoPipelineForText2Image(ConfigMixin):
# derive the pipeline class to instantiate
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, original_cls_name)
if "controlnet" in kwargs:
if kwargs["controlnet"] is not None:
text_2_image_cls = _get_task_class(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
text_2_image_cls.__name__.replace("Pipeline", "ControlNetPipeline"),
)
else:
text_2_image_cls = _get_task_class(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
text_2_image_cls.__name__.replace("ControlNetPipeline", "Pipeline"),
)
# define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = _get_signature_keys(text_2_image_cls)
......@@ -631,6 +643,18 @@ class AutoPipelineForImage2Image(ConfigMixin):
# derive the pipeline class to instantiate
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, original_cls_name)
if "controlnet" in kwargs:
if kwargs["controlnet"] is not None:
image_2_image_cls = _get_task_class(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
image_2_image_cls.__name__.replace("Img2ImgPipeline", "ControlNetImg2ImgPipeline"),
)
else:
image_2_image_cls = _get_task_class(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
image_2_image_cls.__name__.replace("ControlNetImg2ImgPipeline", "Img2ImgPipeline"),
)
# define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = _get_signature_keys(image_2_image_cls)
......@@ -894,6 +918,18 @@ class AutoPipelineForInpainting(ConfigMixin):
# derive the pipeline class to instantiate
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, original_cls_name)
if "controlnet" in kwargs:
if kwargs["controlnet"] is not None:
inpainting_cls = _get_task_class(
AUTO_INPAINT_PIPELINES_MAPPING,
inpainting_cls.__name__.replace("InpaintPipeline", "ControlNetInpaintPipeline"),
)
else:
inpainting_cls = _get_task_class(
AUTO_INPAINT_PIPELINES_MAPPING,
inpainting_cls.__name__.replace("ControlNetInpaintPipeline", "InpaintPipeline"),
)
# define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = _get_signature_keys(inpainting_cls)
......
......@@ -108,6 +108,54 @@ class AutoPipelineFastTest(unittest.TestCase):
shutil.rmtree(tmpdirname.parent.parent)
def test_from_pipe_controlnet_text2img(self):
pipe = AutoPipelineForText2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet)
assert pipe.__class__.__name__ == "StableDiffusionControlNetPipeline"
assert "controlnet" in pipe.components
pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None)
assert pipe.__class__.__name__ == "StableDiffusionPipeline"
assert "controlnet" not in pipe.components
def test_from_pipe_controlnet_img2img(self):
pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
pipe = AutoPipelineForImage2Image.from_pipe(pipe, controlnet=controlnet)
assert pipe.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
assert "controlnet" in pipe.components
pipe = AutoPipelineForImage2Image.from_pipe(pipe, controlnet=None)
assert pipe.__class__.__name__ == "StableDiffusionImg2ImgPipeline"
assert "controlnet" not in pipe.components
def test_from_pipe_controlnet_inpaint(self):
pipe = AutoPipelineForInpainting.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
pipe = AutoPipelineForInpainting.from_pipe(pipe, controlnet=controlnet)
assert pipe.__class__.__name__ == "StableDiffusionControlNetInpaintPipeline"
assert "controlnet" in pipe.components
pipe = AutoPipelineForInpainting.from_pipe(pipe, controlnet=None)
assert pipe.__class__.__name__ == "StableDiffusionInpaintPipeline"
assert "controlnet" not in pipe.components
def test_from_pipe_controlnet_new_task(self):
pipe_text2img = AutoPipelineForText2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
pipe_control_img2img = AutoPipelineForImage2Image.from_pipe(pipe_text2img, controlnet=controlnet)
assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
assert "controlnet" in pipe_control_img2img.components
pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_control_img2img, controlnet=None)
assert pipe_inpaint.__class__.__name__ == "StableDiffusionInpaintPipeline"
assert "controlnet" not in pipe_inpaint.components
@slow
class AutoPipelineIntegrationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment