Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
4450d26b
Unverified
Commit
4450d26b
authored
Dec 19, 2024
by
hlky
Committed by
GitHub
Dec 18, 2024
Browse files
Add Flux Control to AutoPipeline (#10292)
parent
f781b8c3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
5 deletions
+32
-5
src/diffusers/pipelines/auto_pipeline.py
src/diffusers/pipelines/auto_pipeline.py
+32
-5
No files found.
src/diffusers/pipelines/auto_pipeline.py
View file @
4450d26b
...
...
@@ -35,9 +35,12 @@ from .controlnet import (
)
from
.deepfloyd_if
import
IFImg2ImgPipeline
,
IFInpaintingPipeline
,
IFPipeline
from
.flux
import
(
FluxControlImg2ImgPipeline
,
FluxControlInpaintPipeline
,
FluxControlNetImg2ImgPipeline
,
FluxControlNetInpaintPipeline
,
FluxControlNetPipeline
,
FluxControlPipeline
,
FluxImg2ImgPipeline
,
FluxInpaintPipeline
,
FluxPipeline
,
...
...
@@ -125,6 +128,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
(
"pixart-sigma-pag"
,
PixArtSigmaPAGPipeline
),
(
"auraflow"
,
AuraFlowPipeline
),
(
"flux"
,
FluxPipeline
),
(
"flux-control"
,
FluxControlPipeline
),
(
"flux-controlnet"
,
FluxControlNetPipeline
),
(
"lumina"
,
LuminaText2ImgPipeline
),
(
"cogview3"
,
CogView3PlusPipeline
),
...
...
@@ -150,6 +154,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
(
"lcm"
,
LatentConsistencyModelImg2ImgPipeline
),
(
"flux"
,
FluxImg2ImgPipeline
),
(
"flux-controlnet"
,
FluxControlNetImg2ImgPipeline
),
(
"flux-control"
,
FluxControlImg2ImgPipeline
),
]
)
...
...
@@ -168,6 +173,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
(
"stable-diffusion-xl-pag"
,
StableDiffusionXLPAGInpaintPipeline
),
(
"flux"
,
FluxInpaintPipeline
),
(
"flux-controlnet"
,
FluxControlNetInpaintPipeline
),
(
"flux-control"
,
FluxControlInpaintPipeline
),
(
"stable-diffusion-pag"
,
StableDiffusionPAGInpaintPipeline
),
]
)
...
...
@@ -401,16 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin):
config
=
cls
.
load_config
(
pretrained_model_or_path
,
**
load_config_kwargs
)
orig_class_name
=
config
[
"_class_name"
]
if
"ControlPipeline"
in
orig_class_name
:
to_replace
=
"ControlPipeline"
else
:
to_replace
=
"Pipeline"
if
"controlnet"
in
kwargs
:
if
isinstance
(
kwargs
[
"controlnet"
],
ControlNetUnionModel
):
orig_class_name
=
config
[
"_class_name"
].
replace
(
"Pipeline"
,
"ControlNetUnionPipeline"
)
orig_class_name
=
config
[
"_class_name"
].
replace
(
to_replace
,
"ControlNetUnionPipeline"
)
else
:
orig_class_name
=
config
[
"_class_name"
].
replace
(
"Pipeline"
,
"ControlNetPipeline"
)
orig_class_name
=
config
[
"_class_name"
].
replace
(
to_replace
,
"ControlNetPipeline"
)
if
"enable_pag"
in
kwargs
:
enable_pag
=
kwargs
.
pop
(
"enable_pag"
)
if
enable_pag
:
orig_class_name
=
orig_class_name
.
replace
(
"Pipeline"
,
"PAGPipeline"
)
orig_class_name
=
orig_class_name
.
replace
(
to_replace
,
"PAGPipeline"
)
text_2_image_cls
=
_get_task_class
(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING
,
orig_class_name
)
...
...
@@ -694,8 +704,14 @@ class AutoPipelineForImage2Image(ConfigMixin):
# the `orig_class_name` can be:
# `- *Pipeline` (for regular text-to-image checkpoint)
# - `*ControlPipeline` (for Flux tools specific checkpoint)
# `- *Img2ImgPipeline` (for refiner checkpoint)
to_replace
=
"Img2ImgPipeline"
if
"Img2Img"
in
config
[
"_class_name"
]
else
"Pipeline"
if
"Img2Img"
in
orig_class_name
:
to_replace
=
"Img2ImgPipeline"
elif
"ControlPipeline"
in
orig_class_name
:
to_replace
=
"ControlPipeline"
else
:
to_replace
=
"Pipeline"
if
"controlnet"
in
kwargs
:
if
isinstance
(
kwargs
[
"controlnet"
],
ControlNetUnionModel
):
...
...
@@ -707,6 +723,9 @@ class AutoPipelineForImage2Image(ConfigMixin):
if
enable_pag
:
orig_class_name
=
orig_class_name
.
replace
(
to_replace
,
"PAG"
+
to_replace
)
if
to_replace
==
"ControlPipeline"
:
orig_class_name
=
orig_class_name
.
replace
(
to_replace
,
"ControlImg2ImgPipeline"
)
image_2_image_cls
=
_get_task_class
(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
,
orig_class_name
)
kwargs
=
{
**
load_config_kwargs
,
**
kwargs
}
...
...
@@ -994,8 +1013,14 @@ class AutoPipelineForInpainting(ConfigMixin):
# The `orig_class_name`` can be:
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
# - `*ControlPipeline` (for Flux tools specific checkpoint)
# - or *Pipeline (for regular text-to-image checkpoint)
to_replace
=
"InpaintPipeline"
if
"Inpaint"
in
config
[
"_class_name"
]
else
"Pipeline"
if
"Inpaint"
in
orig_class_name
:
to_replace
=
"InpaintPipeline"
elif
"ControlPipeline"
in
orig_class_name
:
to_replace
=
"ControlPipeline"
else
:
to_replace
=
"Pipeline"
if
"controlnet"
in
kwargs
:
if
isinstance
(
kwargs
[
"controlnet"
],
ControlNetUnionModel
):
...
...
@@ -1006,6 +1031,8 @@ class AutoPipelineForInpainting(ConfigMixin):
enable_pag
=
kwargs
.
pop
(
"enable_pag"
)
if
enable_pag
:
orig_class_name
=
orig_class_name
.
replace
(
to_replace
,
"PAG"
+
to_replace
)
if
to_replace
==
"ControlPipeline"
:
orig_class_name
=
orig_class_name
.
replace
(
to_replace
,
"ControlInpaintPipeline"
)
inpainting_cls
=
_get_task_class
(
AUTO_INPAINT_PIPELINES_MAPPING
,
orig_class_name
)
kwargs
=
{
**
load_config_kwargs
,
**
kwargs
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment