Unverified commit ff9a3876 authored by Sayak Paul, committed by GitHub

[core] add modular support for Flux I2I (#12086)

* start

* encoder.

* up

* up

* up

* up

* up

* up
parent 03c3f69a
@@ -13,11 +13,12 @@
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch

from ...models import AutoencoderKL
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
@@ -103,6 +104,62 @@ def calculate_shift(
    return mu


# Adapted from the original implementation.
def prepare_latents_img2img(
    vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
):
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    latent_channels = vae.config.latent_channels

    # VAE applies 8x compression on images but we must also account for packing which requires
    # latent height and width to be divisible by 2.
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))
    shape = (batch_size, num_channels_latents, height, width)
    latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

    image = image.to(device=device, dtype=dtype)
    if image.shape[1] != latent_channels:
        image_latents = _encode_vae_image(vae, image=image, generator=generator)
    else:
        image_latents = image

    if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
        # expand init_latents for batch_size
        additional_image_per_prompt = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        image_latents = torch.cat([image_latents], dim=0)

    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    latents = scheduler.scale_noise(image_latents, timestep, noise)
    latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)

    return latents, latent_image_ids
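As a standalone illustration of the packing geometry used above (values are illustrative, not part of the diff): height and width are first snapped to multiples of `2 * vae_scale_factor`, and `_pack_latents` then folds each 2x2 patch of the latent grid into the channel dimension, so the transformer sees a sequence of length `(height // 2) * (width // 2)`.

import torch

# Illustrative values: a 1024x1024 image, Flux-style 16-channel VAE latents.
vae_scale_factor, num_channels_latents, batch_size = 8, 16, 1
height = 2 * (1024 // (vae_scale_factor * 2))  # 128 latent rows
width = 2 * (1024 // (vae_scale_factor * 2))   # 128 latent cols

latents = torch.randn(batch_size, num_channels_latents, height, width)

# Same rearrangement as `_pack_latents` in this file.
packed = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

print(packed.shape)  # torch.Size([1, 4096, 64]): sequence length 4096, 64 channels per token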
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


def _pack_latents(latents, batch_size, num_channels_latents, height, width):
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
@@ -125,6 +182,55 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    return latent_image_ids.to(device=device, dtype=dtype)


# Cannot use "# Copied from" because it introduces weird indentation errors.
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(vae.encode(image), generator=generator)
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
return image_latents
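The shift/scale step above maps raw VAE posterior samples into the range the transformer was trained on. A small standalone sketch, using values close to the FLUX.1 VAE defaults (the exact `shift_factor`/`scaling_factor` are assumptions here):

import torch

# Assumed config values, approximately the FLUX.1 VAE defaults.
shift_factor, scaling_factor = 0.1159, 0.3611

raw_latents = torch.randn(1, 16, 128, 128)                       # what retrieve_latents(vae.encode(...)) returns
image_latents = (raw_latents - shift_factor) * scaling_factor    # normalized latents fed to the transformer

# The decode path inverts this before calling vae.decode():
restored = image_latents / scaling_factor + shift_factor
assert torch.allclose(restored, raw_latents, atol=1e-6)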
def _get_initial_timesteps_and_optionals(
    transformer,
    scheduler,
    batch_size,
    height,
    width,
    vae_scale_factor,
    num_inference_steps,
    guidance_scale,
    sigmas,
    device,
):
    image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)

    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
        sigmas = None
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
    if transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(batch_size)
    else:
        guidance = None

    return timesteps, num_inference_steps, sigmas, guidance
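`calculate_shift` (defined earlier in this file, outside the hunk) produces the resolution-dependent timestep shift `mu` from the latent sequence length. A standalone sketch of that interpolation, assuming the usual linear form between the base and max sequence lengths:

# Sketch only: assumes calculate_shift linearly interpolates between
# (base_image_seq_len, base_shift) and (max_image_seq_len, max_shift).
def approx_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept

# 1024x1024 with an 8x VAE: image_seq_len = (1024 // 8 // 2) ** 2 = 4096 -> shift ~= 1.15
print(approx_shift((1024 // 8 // 2) ** 2))
# 512x512: image_seq_len = (512 // 8 // 2) ** 2 = 1024 -> shift ~= 0.63
print(approx_shift((512 // 8 // 2) ** 2))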
class FluxInputStep(PipelineBlock):
    model_name = "flux"
@@ -234,18 +340,20 @@ class FluxSetTimestepsStep(PipelineBlock):
InputParam("timesteps"), InputParam("timesteps"),
InputParam("sigmas"), InputParam("sigmas"),
InputParam("guidance_scale", default=3.5), InputParam("guidance_scale", default=3.5),
InputParam("latents", type_hint=torch.Tensor), InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
] ]
@property @property
def intermediate_inputs(self) -> List[str]: def intermediate_inputs(self) -> List[str]:
return [ return [
InputParam( InputParam(
"latents", "batch_size",
required=True, required=True,
type_hint=torch.Tensor, type_hint=int,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
) ),
] ]
@property @property
...@@ -264,34 +372,127 @@ class FluxSetTimestepsStep(PipelineBlock): ...@@ -264,34 +372,127 @@ class FluxSetTimestepsStep(PipelineBlock):
    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        scheduler = components.scheduler
        transformer = components.transformer
        batch_size = block_state.batch_size * block_state.num_images_per_prompt
        timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
            transformer,
            scheduler,
            batch_size,
            block_state.height,
            block_state.width,
            components.vae_scale_factor,
            block_state.num_inference_steps,
            block_state.guidance_scale,
            block_state.sigmas,
            block_state.device,
        )
        block_state.timesteps = timesteps
        block_state.num_inference_steps = num_inference_steps
        block_state.sigmas = sigmas
        block_state.guidance = guidance

        self.set_block_state(state, block_state)
        return components, state


class FluxImg2ImgSetTimestepsStep(PipelineBlock):
    model_name = "flux"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]

    @property
    def description(self) -> str:
        return "Step that sets the scheduler's timesteps for inference"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=50),
            InputParam("timesteps"),
            InputParam("sigmas"),
            InputParam("strength", default=0.6),
            InputParam("guidance_scale", default=3.5),
            InputParam("num_images_per_prompt", default=1),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
        ]

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
            OutputParam(
                "latent_timestep",
                type_hint=torch.Tensor,
                description="The timestep that represents the initial noise level for image-to-image generation",
            ),
            OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
        ]

    @staticmethod
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler
    def get_timesteps(scheduler, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(num_inference_steps * strength, num_inference_steps)

        t_start = int(max(num_inference_steps - init_timestep, 0))
        timesteps = scheduler.timesteps[t_start * scheduler.order :]
        if hasattr(scheduler, "set_begin_index"):
            scheduler.set_begin_index(t_start * scheduler.order)

        return timesteps, num_inference_steps - t_start
    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        scheduler = components.scheduler
        transformer = components.transformer
        batch_size = block_state.batch_size * block_state.num_images_per_prompt
        timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
            transformer,
            scheduler,
            batch_size,
            block_state.height,
            block_state.width,
            components.vae_scale_factor,
            block_state.num_inference_steps,
            block_state.guidance_scale,
            block_state.sigmas,
            block_state.device,
        )
        timesteps, num_inference_steps = self.get_timesteps(
            scheduler, num_inference_steps, block_state.strength, block_state.device
        )
        block_state.timesteps = timesteps
        block_state.num_inference_steps = num_inference_steps
        block_state.sigmas = sigmas
        block_state.guidance = guidance
        block_state.latent_timestep = timesteps[:1].repeat(batch_size)

        self.set_block_state(state, block_state)
        return components, state
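A standalone arithmetic sketch of how `strength` truncates the schedule in `get_timesteps` above (illustrative numbers only):

num_inference_steps, strength, order = 50, 0.6, 1  # FlowMatchEulerDiscreteScheduler has order 1

init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20

# Denoising starts 20 steps into the 50-step schedule, i.e. only the last 30
# (noisier-to-cleaner) timesteps are run, and timesteps[0] becomes latent_timestep,
# the noise level at which the encoded image is perturbed.
print(t_start, num_inference_steps - t_start)  # 20 30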
@@ -305,7 +506,7 @@ class FluxPrepareLatentsStep(PipelineBlock):
    @property
    def description(self) -> str:
        return "Prepare latents step that prepares the latents for the text-to-image generation process"

    @property
    def inputs(self) -> List[InputParam]:
@@ -402,10 +603,10 @@ class FluxPrepareLatentsStep(PipelineBlock):
        block_state.num_channels_latents = components.num_channels_latents
        self.check_inputs(components, block_state)

        batch_size = block_state.batch_size * block_state.num_images_per_prompt
        block_state.latents, block_state.latent_image_ids = self.prepare_latents(
            components,
            batch_size,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
@@ -418,3 +619,95 @@ class FluxPrepareLatentsStep(PipelineBlock):
        self.set_block_state(state, block_state)
        return components, state


class FluxImg2ImgPrepareLatentsStep(PipelineBlock):
    model_name = "flux"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]

    @property
    def description(self) -> str:
        return "Step that prepares the latents for the image-to-image generation process"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_images_per_prompt", type_hint=int, default=1),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
            InputParam(
                "image_latents",
                required=True,
                type_hint=torch.Tensor,
                description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
            ),
            InputParam(
                "latent_timestep",
                required=True,
                type_hint=torch.Tensor,
                description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
            ),
            InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
            ),
            OutputParam(
                "latent_image_ids",
                type_hint=torch.Tensor,
                description="IDs computed from the image sequence needed for RoPE",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width
        block_state.device = components._execution_device
        block_state.dtype = torch.bfloat16  # TODO: okay to hardcode this?
        block_state.num_channels_latents = components.num_channels_latents
        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
        block_state.device = components._execution_device

        # TODO: implement `check_inputs`
        batch_size = block_state.batch_size * block_state.num_images_per_prompt
        if block_state.latents is None:
            block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
                components.vae,
                components.scheduler,
                block_state.image_latents,
                block_state.latent_timestep,
                batch_size,
                block_state.num_channels_latents,
                block_state.height,
                block_state.width,
                block_state.dtype,
                block_state.device,
                block_state.generator,
            )

        self.set_block_state(state, block_state)
        return components, state
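This step seeds denoising by blending the encoded image with fresh noise at `latent_timestep` via `scheduler.scale_noise` (inside `prepare_latents_img2img`). For flow-matching schedulers this is, to a good approximation, a linear interpolation; a sketch under that assumption:

import torch

# Assumed behaviour of FlowMatchEulerDiscreteScheduler.scale_noise at noise level sigma:
# x_t = sigma * noise + (1 - sigma) * x_0, with sigma in [0, 1] derived from latent_timestep.
def blend_with_noise(image_latents, noise, sigma):
    return sigma * noise + (1.0 - sigma) * image_latents

image_latents = torch.randn(1, 16, 128, 128)
noise = torch.randn_like(image_latents)

# Higher strength starts denoising at a larger sigma, so less of the reference
# image structure survives; lower strength keeps the result closer to the input.
partially_noised = blend_with_noise(image_latents, noise, sigma=0.6)
print(partially_noised.shape)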
@@ -226,5 +226,5 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `FluxLoopDenoiser`\n"
            " - `FluxLoopAfterDenoiser`\n"
            "This block supports both text2image and img2img tasks."
        )
@@ -19,7 +19,10 @@ import regex as re
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
@@ -50,6 +53,110 @@ def prompt_clean(text):
    return text


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class FluxVaeEncoderStep(PipelineBlock):
    model_name = "flux"

    @property
    def description(self) -> str:
        return "Vae Encoder step that encode the input image into a latent representation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("image", required=True), InputParam("height"), InputParam("width")]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
            InputParam(
                "preprocess_kwargs",
                type_hint=Optional[dict],
                description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="The latents representing the reference image for image-to-image/inpainting generation",
            )
        ]

    @staticmethod
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
    def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(vae.encode(image), generator=generator)

        image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor

        return image_latents

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
        block_state.device = components._execution_device
        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype

        block_state.image = components.image_processor.preprocess(
            block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
        )
        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)

        block_state.batch_size = block_state.image.shape[0]

        # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
        if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
                f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
            )

        block_state.image_latents = self._encode_vae_image(
            components.vae, image=block_state.image, generator=block_state.generator
        )

        self.set_block_state(state, block_state)
        return components, state


class FluxTextEncoderStep(PipelineBlock):
    model_name = "flux"
@@ -297,7 +404,7 @@ class FluxTextEncoderStep(PipelineBlock):
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=block_state.device,
            num_images_per_prompt=1,  # TODO: hardcoded for now.
            lora_scale=block_state.text_encoder_lora_scale,
        )
@@ -15,16 +15,38 @@
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    FluxImg2ImgPrepareLatentsStep,
    FluxImg2ImgSetTimestepsStep,
    FluxInputStep,
    FluxPrepareLatentsStep,
    FluxSetTimestepsStep,
)
from .decoders import FluxDecodeStep
from .denoise import FluxDenoiseStep
from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# vae encoder (run before before_denoise)
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
    block_classes = [FluxVaeEncoderStep]
    block_names = ["img2img"]
    block_trigger_inputs = ["image"]

    @property
    def description(self):
        return (
            "Vae encoder step that encode the image inputs into their latent representations.\n"
            + "This is an auto pipeline block that works for img2img tasks.\n"
            + " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
            + " - if `image` is not provided, step will be skipped."
        )


# before_denoise: text2img, img2img
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        FluxInputStep,
@@ -44,11 +66,27 @@ class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
        )


# before_denoise: img2img
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
    block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
    block_names = ["input", "set_timesteps", "prepare_latents"]

    @property
    def description(self):
        return (
            "Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
            + " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
        )


# before_denoise: all task (text2img, img2img)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
    block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
    block_names = ["text2image", "img2img"]
    block_trigger_inputs = [None, "image_latents"]

    @property
    def description(self):
@@ -56,6 +94,7 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
            "Before denoise step that prepare the inputs for the denoise step.\n"
            + "This is an auto pipeline block that works for text2image.\n"
            + " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
            + " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
        )
@@ -69,8 +108,8 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoise the latents. "
            "This is a auto pipeline block that works for text2image and img2img tasks."
            " - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
        )
@@ -82,19 +121,26 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
    @property
    def description(self):
        return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"


# text2image
class FluxAutoBlocks(SequentialPipelineBlocks):
    block_classes = [
        FluxTextEncoderStep,
        FluxAutoVaeEncoderStep,
        FluxAutoBeforeDenoiseStep,
        FluxAutoDenoiseStep,
        FluxAutoDecodeStep,
    ]
    block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]

    @property
    def description(self):
        return (
            "Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
            + "- for text-to-image generation, all you need to provide is `prompt`\n"
            + "- for image-to-image generation, you need to provide either `image` or `image_latents`"
        )


@@ -102,19 +148,29 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep),
        ("input", FluxInputStep),
        ("set_timesteps", FluxSetTimestepsStep),
        ("prepare_latents", FluxPrepareLatentsStep),
        ("denoise", FluxDenoiseStep),
        ("decode", FluxDecodeStep),
    ]
)


IMAGE2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep),
        ("image_encoder", FluxVaeEncoderStep),
        ("input", FluxInputStep),
        ("set_timesteps", FluxImg2ImgSetTimestepsStep),
        ("prepare_latents", FluxImg2ImgPrepareLatentsStep),
        ("denoise", FluxDenoiseStep),
        ("decode", FluxDecodeStep),
    ]
)


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep),
        ("image_encoder", FluxAutoVaeEncoderStep),
        ("before_denoise", FluxAutoBeforeDenoiseStep),
        ("denoise", FluxAutoDenoiseStep),
        ("decode", FluxAutoDecodeStep),
@@ -122,4 +178,4 @@ AUTO_BLOCKS = InsertableDict(
)

ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
@@ -13,7 +13,7 @@
# limitations under the License.
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -21,7 +21,7 @@ from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin):
    """
    A ModularPipeline for Flux.