Unverified Commit 4450d26b authored by hlky's avatar hlky Committed by GitHub
Browse files

Add Flux Control to AutoPipeline (#10292)

parent f781b8c3
...@@ -35,9 +35,12 @@ from .controlnet import ( ...@@ -35,9 +35,12 @@ from .controlnet import (
) )
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import ( from .flux import (
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline, FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline, FluxControlNetInpaintPipeline,
FluxControlNetPipeline, FluxControlNetPipeline,
FluxControlPipeline,
FluxImg2ImgPipeline, FluxImg2ImgPipeline,
FluxInpaintPipeline, FluxInpaintPipeline,
FluxPipeline, FluxPipeline,
...@@ -125,6 +128,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( ...@@ -125,6 +128,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("pixart-sigma-pag", PixArtSigmaPAGPipeline), ("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline), ("auraflow", AuraFlowPipeline),
("flux", FluxPipeline), ("flux", FluxPipeline),
("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline), ("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline), ("lumina", LuminaText2ImgPipeline),
("cogview3", CogView3PlusPipeline), ("cogview3", CogView3PlusPipeline),
...@@ -150,6 +154,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( ...@@ -150,6 +154,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("lcm", LatentConsistencyModelImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline), ("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline), ("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
] ]
) )
...@@ -168,6 +173,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict( ...@@ -168,6 +173,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline), ("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline), ("flux-controlnet", FluxControlNetInpaintPipeline),
("flux-control", FluxControlInpaintPipeline),
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline), ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
] ]
) )
...@@ -401,16 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin): ...@@ -401,16 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"] orig_class_name = config["_class_name"]
if "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs: if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel): if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline") orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
else: else:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
if "enable_pag" in kwargs: if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag") enable_pag = kwargs.pop("enable_pag")
if enable_pag: if enable_pag:
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
...@@ -694,8 +704,14 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -694,8 +704,14 @@ class AutoPipelineForImage2Image(ConfigMixin):
# the `orig_class_name` can be: # the `orig_class_name` can be:
# `- *Pipeline` (for regular text-to-image checkpoint) # `- *Pipeline` (for regular text-to-image checkpoint)
# - `*ControlPipeline` (for Flux tools specific checkpoint)
# `- *Img2ImgPipeline` (for refiner checkpoint) # `- *Img2ImgPipeline` (for refiner checkpoint)
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" if "Img2Img" in orig_class_name:
to_replace = "Img2ImgPipeline"
elif "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs: if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel): if isinstance(kwargs["controlnet"], ControlNetUnionModel):
...@@ -707,6 +723,9 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -707,6 +723,9 @@ class AutoPipelineForImage2Image(ConfigMixin):
if enable_pag: if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
if to_replace == "ControlPipeline":
orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs} kwargs = {**load_config_kwargs, **kwargs}
...@@ -994,8 +1013,14 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -994,8 +1013,14 @@ class AutoPipelineForInpainting(ConfigMixin):
# The `orig_class_name`` can be: # The `orig_class_name`` can be:
# `- *InpaintPipeline` (for inpaint-specific checkpoint) # `- *InpaintPipeline` (for inpaint-specific checkpoint)
# - `*ControlPipeline` (for Flux tools specific checkpoint)
# - or *Pipeline (for regular text-to-image checkpoint) # - or *Pipeline (for regular text-to-image checkpoint)
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" if "Inpaint" in orig_class_name:
to_replace = "InpaintPipeline"
elif "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs: if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel): if isinstance(kwargs["controlnet"], ControlNetUnionModel):
...@@ -1006,6 +1031,8 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -1006,6 +1031,8 @@ class AutoPipelineForInpainting(ConfigMixin):
enable_pag = kwargs.pop("enable_pag") enable_pag = kwargs.pop("enable_pag")
if enable_pag: if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
if to_replace == "ControlPipeline":
orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs} kwargs = {**load_config_kwargs, **kwargs}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment