Unverified Commit d1222064 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix a bug in `AutoPipeline.from_pipe()` when creating a controlnet pipeline...


fix a bug in `AutoPipeline.from_pipe()` when creating a controlnet pipeline from an existing controlnet (#5638)

fix
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent c84982a8
...@@ -372,7 +372,7 @@ class AutoPipelineForText2Image(ConfigMixin): ...@@ -372,7 +372,7 @@ class AutoPipelineForText2Image(ConfigMixin):
if kwargs["controlnet"] is not None: if kwargs["controlnet"] is not None:
text_2_image_cls = _get_task_class( text_2_image_cls = _get_task_class(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING, AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
text_2_image_cls.__name__.replace("Pipeline", "ControlNetPipeline"), text_2_image_cls.__name__.replace("ControlNet", "").replace("Pipeline", "ControlNetPipeline"),
) )
else: else:
text_2_image_cls = _get_task_class( text_2_image_cls = _get_task_class(
...@@ -645,7 +645,9 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -645,7 +645,9 @@ class AutoPipelineForImage2Image(ConfigMixin):
if kwargs["controlnet"] is not None: if kwargs["controlnet"] is not None:
image_2_image_cls = _get_task_class( image_2_image_cls = _get_task_class(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
image_2_image_cls.__name__.replace("Img2ImgPipeline", "ControlNetImg2ImgPipeline"), image_2_image_cls.__name__.replace("ControlNet", "").replace(
"Img2ImgPipeline", "ControlNetImg2ImgPipeline"
),
) )
else: else:
image_2_image_cls = _get_task_class( image_2_image_cls = _get_task_class(
...@@ -916,7 +918,9 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -916,7 +918,9 @@ class AutoPipelineForInpainting(ConfigMixin):
if kwargs["controlnet"] is not None: if kwargs["controlnet"] is not None:
inpainting_cls = _get_task_class( inpainting_cls = _get_task_class(
AUTO_INPAINT_PIPELINES_MAPPING, AUTO_INPAINT_PIPELINES_MAPPING,
inpainting_cls.__name__.replace("InpaintPipeline", "ControlNetInpaintPipeline"), inpainting_cls.__name__.replace("ControlNet", "").replace(
"InpaintPipeline", "ControlNetInpaintPipeline"
),
) )
else: else:
inpainting_cls = _get_task_class( inpainting_cls = _get_task_class(
......
...@@ -156,6 +156,54 @@ class AutoPipelineFastTest(unittest.TestCase): ...@@ -156,6 +156,54 @@ class AutoPipelineFastTest(unittest.TestCase):
assert pipe_inpaint.__class__.__name__ == "StableDiffusionInpaintPipeline" assert pipe_inpaint.__class__.__name__ == "StableDiffusionInpaintPipeline"
assert "controlnet" not in pipe_inpaint.components assert "controlnet" not in pipe_inpaint.components
# testing `from_pipe` for text2img controlnet
## 1. from a different controlnet pipe, without controlnet argument
pipe_control_text2img = AutoPipelineForText2Image.from_pipe(pipe_control_img2img)
assert pipe_control_text2img.__class__.__name__ == "StableDiffusionControlNetPipeline"
assert "controlnet" in pipe_control_text2img.components
## 2. from a different controlnet pipe, with controlnet argument
pipe_control_text2img = AutoPipelineForText2Image.from_pipe(pipe_control_img2img, controlnet=controlnet)
assert pipe_control_text2img.__class__.__name__ == "StableDiffusionControlNetPipeline"
assert "controlnet" in pipe_control_text2img.components
## 3. from same controlnet pipeline class, with a different controlnet component
pipe_control_text2img = AutoPipelineForText2Image.from_pipe(pipe_control_text2img, controlnet=controlnet)
assert pipe_control_text2img.__class__.__name__ == "StableDiffusionControlNetPipeline"
assert "controlnet" in pipe_control_text2img.components
# testing from_pipe for inpainting
## 1. from a different controlnet pipeline class
pipe_control_inpaint = AutoPipelineForInpainting.from_pipe(pipe_control_img2img)
assert pipe_control_inpaint.__class__.__name__ == "StableDiffusionControlNetInpaintPipeline"
assert "controlnet" in pipe_control_inpaint.components
## from a different controlnet pipe, with a different controlnet
pipe_control_inpaint = AutoPipelineForInpainting.from_pipe(pipe_control_img2img, controlnet=controlnet)
assert pipe_control_inpaint.__class__.__name__ == "StableDiffusionControlNetInpaintPipeline"
assert "controlnet" in pipe_control_inpaint.components
## from same controlnet pipe, with a different controlnet
pipe_control_inpaint = AutoPipelineForInpainting.from_pipe(pipe_control_inpaint, controlnet=controlnet)
assert pipe_control_inpaint.__class__.__name__ == "StableDiffusionControlNetInpaintPipeline"
assert "controlnet" in pipe_control_inpaint.components
# testing from_pipe from img2img controlnet
## from a different controlnet pipe, without controlnet argument
pipe_control_img2img = AutoPipelineForImage2Image.from_pipe(pipe_control_text2img)
assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
assert "controlnet" in pipe_control_img2img.components
# from a different controlnet pipe, with a different controlnet component
pipe_control_img2img = AutoPipelineForImage2Image.from_pipe(pipe_control_text2img, controlnet=controlnet)
assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
assert "controlnet" in pipe_control_img2img.components
# from same controlnet pipeline class, with a different controlnet
pipe_control_img2img = AutoPipelineForImage2Image.from_pipe(pipe_control_img2img, controlnet=controlnet)
assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
assert "controlnet" in pipe_control_img2img.components
@slow @slow
class AutoPipelineIntegrationTest(unittest.TestCase): class AutoPipelineIntegrationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment