Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d8b6f5d0
Unverified
Commit
d8b6f5d0
authored
Sep 01, 2023
by
YiYi Xu
Committed by
GitHub
Sep 01, 2023
Browse files
support AutoPipeline.from_pipe between a pipeline and its ControlNet pipeline counterpart (#4861)
add
parent
30a5acc3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
84 additions
and
0 deletions
+84
-0
src/diffusers/pipelines/auto_pipeline.py
src/diffusers/pipelines/auto_pipeline.py
+36
-0
tests/pipelines/test_pipelines_auto.py
tests/pipelines/test_pipelines_auto.py
+48
-0
No files found.
src/diffusers/pipelines/auto_pipeline.py
View file @
d8b6f5d0
...
@@ -366,6 +366,18 @@ class AutoPipelineForText2Image(ConfigMixin):
...
@@ -366,6 +366,18 @@ class AutoPipelineForText2Image(ConfigMixin):
# derive the pipeline class to instantiate
# derive the pipeline class to instantiate
text_2_image_cls
=
_get_task_class
(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING
,
original_cls_name
)
text_2_image_cls
=
_get_task_class
(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING
,
original_cls_name
)
if
"controlnet"
in
kwargs
:
if
kwargs
[
"controlnet"
]
is
not
None
:
text_2_image_cls
=
_get_task_class
(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING
,
text_2_image_cls
.
__name__
.
replace
(
"Pipeline"
,
"ControlNetPipeline"
),
)
else
:
text_2_image_cls
=
_get_task_class
(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING
,
text_2_image_cls
.
__name__
.
replace
(
"ControlNetPipeline"
,
"Pipeline"
),
)
# define expected module and optional kwargs given the pipeline signature
# define expected module and optional kwargs given the pipeline signature
expected_modules
,
optional_kwargs
=
_get_signature_keys
(
text_2_image_cls
)
expected_modules
,
optional_kwargs
=
_get_signature_keys
(
text_2_image_cls
)
...
@@ -631,6 +643,18 @@ class AutoPipelineForImage2Image(ConfigMixin):
...
@@ -631,6 +643,18 @@ class AutoPipelineForImage2Image(ConfigMixin):
# derive the pipeline class to instantiate
# derive the pipeline class to instantiate
image_2_image_cls
=
_get_task_class
(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
,
original_cls_name
)
image_2_image_cls
=
_get_task_class
(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
,
original_cls_name
)
if
"controlnet"
in
kwargs
:
if
kwargs
[
"controlnet"
]
is
not
None
:
image_2_image_cls
=
_get_task_class
(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
,
image_2_image_cls
.
__name__
.
replace
(
"Img2ImgPipeline"
,
"ControlNetImg2ImgPipeline"
),
)
else
:
image_2_image_cls
=
_get_task_class
(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
,
image_2_image_cls
.
__name__
.
replace
(
"ControlNetImg2ImgPipeline"
,
"Img2ImgPipeline"
),
)
# define expected module and optional kwargs given the pipeline signature
# define expected module and optional kwargs given the pipeline signature
expected_modules
,
optional_kwargs
=
_get_signature_keys
(
image_2_image_cls
)
expected_modules
,
optional_kwargs
=
_get_signature_keys
(
image_2_image_cls
)
...
@@ -894,6 +918,18 @@ class AutoPipelineForInpainting(ConfigMixin):
...
@@ -894,6 +918,18 @@ class AutoPipelineForInpainting(ConfigMixin):
# derive the pipeline class to instantiate
# derive the pipeline class to instantiate
inpainting_cls
=
_get_task_class
(
AUTO_INPAINT_PIPELINES_MAPPING
,
original_cls_name
)
inpainting_cls
=
_get_task_class
(
AUTO_INPAINT_PIPELINES_MAPPING
,
original_cls_name
)
if
"controlnet"
in
kwargs
:
if
kwargs
[
"controlnet"
]
is
not
None
:
inpainting_cls
=
_get_task_class
(
AUTO_INPAINT_PIPELINES_MAPPING
,
inpainting_cls
.
__name__
.
replace
(
"InpaintPipeline"
,
"ControlNetInpaintPipeline"
),
)
else
:
inpainting_cls
=
_get_task_class
(
AUTO_INPAINT_PIPELINES_MAPPING
,
inpainting_cls
.
__name__
.
replace
(
"ControlNetInpaintPipeline"
,
"InpaintPipeline"
),
)
# define expected module and optional kwargs given the pipeline signature
# define expected module and optional kwargs given the pipeline signature
expected_modules
,
optional_kwargs
=
_get_signature_keys
(
inpainting_cls
)
expected_modules
,
optional_kwargs
=
_get_signature_keys
(
inpainting_cls
)
...
...
tests/pipelines/test_pipelines_auto.py
View file @
d8b6f5d0
...
@@ -108,6 +108,54 @@ class AutoPipelineFastTest(unittest.TestCase):
...
@@ -108,6 +108,54 @@ class AutoPipelineFastTest(unittest.TestCase):
shutil
.
rmtree
(
tmpdirname
.
parent
.
parent
)
shutil
.
rmtree
(
tmpdirname
.
parent
.
parent
)
def
test_from_pipe_controlnet_text2img
(
self
):
pipe
=
AutoPipelineForText2Image
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-pipe"
)
controlnet
=
ControlNetModel
.
from_pretrained
(
"hf-internal-testing/tiny-controlnet"
)
pipe
=
AutoPipelineForText2Image
.
from_pipe
(
pipe
,
controlnet
=
controlnet
)
assert
pipe
.
__class__
.
__name__
==
"StableDiffusionControlNetPipeline"
assert
"controlnet"
in
pipe
.
components
pipe
=
AutoPipelineForText2Image
.
from_pipe
(
pipe
,
controlnet
=
None
)
assert
pipe
.
__class__
.
__name__
==
"StableDiffusionPipeline"
assert
"controlnet"
not
in
pipe
.
components
def
test_from_pipe_controlnet_img2img
(
self
):
pipe
=
AutoPipelineForImage2Image
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-pipe"
)
controlnet
=
ControlNetModel
.
from_pretrained
(
"hf-internal-testing/tiny-controlnet"
)
pipe
=
AutoPipelineForImage2Image
.
from_pipe
(
pipe
,
controlnet
=
controlnet
)
assert
pipe
.
__class__
.
__name__
==
"StableDiffusionControlNetImg2ImgPipeline"
assert
"controlnet"
in
pipe
.
components
pipe
=
AutoPipelineForImage2Image
.
from_pipe
(
pipe
,
controlnet
=
None
)
assert
pipe
.
__class__
.
__name__
==
"StableDiffusionImg2ImgPipeline"
assert
"controlnet"
not
in
pipe
.
components
def
test_from_pipe_controlnet_inpaint
(
self
):
pipe
=
AutoPipelineForInpainting
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
)
controlnet
=
ControlNetModel
.
from_pretrained
(
"hf-internal-testing/tiny-controlnet"
)
pipe
=
AutoPipelineForInpainting
.
from_pipe
(
pipe
,
controlnet
=
controlnet
)
assert
pipe
.
__class__
.
__name__
==
"StableDiffusionControlNetInpaintPipeline"
assert
"controlnet"
in
pipe
.
components
pipe
=
AutoPipelineForInpainting
.
from_pipe
(
pipe
,
controlnet
=
None
)
assert
pipe
.
__class__
.
__name__
==
"StableDiffusionInpaintPipeline"
assert
"controlnet"
not
in
pipe
.
components
def
test_from_pipe_controlnet_new_task
(
self
):
pipe_text2img
=
AutoPipelineForText2Image
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
)
controlnet
=
ControlNetModel
.
from_pretrained
(
"hf-internal-testing/tiny-controlnet"
)
pipe_control_img2img
=
AutoPipelineForImage2Image
.
from_pipe
(
pipe_text2img
,
controlnet
=
controlnet
)
assert
pipe_control_img2img
.
__class__
.
__name__
==
"StableDiffusionControlNetImg2ImgPipeline"
assert
"controlnet"
in
pipe_control_img2img
.
components
pipe_inpaint
=
AutoPipelineForInpainting
.
from_pipe
(
pipe_control_img2img
,
controlnet
=
None
)
assert
pipe_inpaint
.
__class__
.
__name__
==
"StableDiffusionInpaintPipeline"
assert
"controlnet"
not
in
pipe_inpaint
.
components
@
slow
@
slow
class
AutoPipelineIntegrationTest
(
unittest
.
TestCase
):
class
AutoPipelineIntegrationTest
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment