Unverified commit 2dc31677 authored by Sayak Paul, committed by GitHub

Align Flux modular more and more with Qwen modular (#12445)

* start

* fix

* up
parent 1066de8c
@@ -76,18 +76,17 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
                 description="Pooled prompt embeddings",
             ),
             InputParam(
-                "text_ids",
+                "txt_ids",
                 required=True,
                 type_hint=torch.Tensor,
                 description="IDs computed from text sequence needed for RoPE",
             ),
             InputParam(
-                "latent_image_ids",
+                "img_ids",
                 required=True,
                 type_hint=torch.Tensor,
                 description="IDs computed from image sequence needed for RoPE",
             ),
-            # TODO: guidance
         ]

     @torch.no_grad()
@@ -101,8 +100,8 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
             encoder_hidden_states=block_state.prompt_embeds,
             pooled_projections=block_state.pooled_prompt_embeds,
             joint_attention_kwargs=block_state.joint_attention_kwargs,
-            txt_ids=block_state.text_ids,
-            img_ids=block_state.latent_image_ids,
+            txt_ids=block_state.txt_ids,
+            img_ids=block_state.img_ids,
             return_dict=False,
         )[0]
         block_state.noise_pred = noise_pred
@@ -195,9 +194,6 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
         block_state.num_warmup_steps = max(
             len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
         )
-        # We set the index here to remove DtoH sync, helpful especially during compilation.
-        # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
-        components.scheduler.set_begin_index(0)
         with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
             for i, t in enumerate(block_state.timesteps):
                 components, block_state = self.loop_step(components, block_state, i=i, t=t)
......
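For context on the renamed `txt_ids`/`img_ids` inputs above: Flux builds its RoPE position ids as one (0, 0, 0) triplet per text token and one (0, row, col) triplet per packed 2x2 latent patch. A minimal sketch of that layout in plain PyTorch (the helper name is illustrative and not part of this commit):

import torch

def make_flux_rope_ids(text_seq_len, latent_height, latent_width, device, dtype):
    # Text tokens get all-zero position ids: one (0, 0, 0) triplet per token.
    txt_ids = torch.zeros(text_seq_len, 3, device=device, dtype=dtype)

    # Image tokens get (0, row, col) ids over the packed 2x2-patch grid,
    # i.e. a grid of shape (latent_height // 2, latent_width // 2).
    h, w = latent_height // 2, latent_width // 2
    img_ids = torch.zeros(h, w, 3, device=device, dtype=dtype)
    img_ids[..., 1] = torch.arange(h, device=device, dtype=dtype)[:, None]
    img_ids[..., 2] = torch.arange(w, device=device, dtype=dtype)[None, :]
    return txt_ids, img_ids.reshape(h * w, 3)

The denoiser block now simply forwards these tensors under the same names the Flux transformer expects.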
@@ -25,7 +25,7 @@ from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL
 from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
-from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
 from .modular_pipeline import FluxModularPipeline
@@ -67,89 +67,148 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")


-class FluxVaeEncoderStep(ModularPipelineBlocks):
-    model_name = "flux"
+def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"):
+    if isinstance(generator, list):
+        image_latents = [
+            retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
+            for i in range(image.shape[0])
+        ]
+        image_latents = torch.cat(image_latents, dim=0)
+    else:
+        image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
+
+    image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
+
+    return image_latents
+
+
+class FluxProcessImagesInputStep(ModularPipelineBlocks):
+    model_name = "Flux"

     @property
     def description(self) -> str:
-        return "Vae Encoder step that encode the input image into a latent representation"
+        return "Image Preprocess step. Resizing is needed in Flux Kontext (will be implemented later.)"

     @property
     def expected_components(self) -> List[ComponentSpec]:
         return [
-            ComponentSpec("vae", AutoencoderKL),
             ComponentSpec(
                 "image_processor",
                 VaeImageProcessor,
-                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
+                config=FrozenDict({"vae_scale_factor": 16}),
                 default_creation_method="from_config",
             ),
         ]

     @property
     def inputs(self) -> List[InputParam]:
-        return [
-            InputParam("image", required=True),
-            InputParam("height"),
-            InputParam("width"),
-            InputParam("generator"),
-            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
-            InputParam(
-                "preprocess_kwargs",
-                type_hint=Optional[dict],
-                description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
-            ),
-        ]
+        return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]

     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
-            OutputParam(
-                "image_latents",
-                type_hint=torch.Tensor,
-                description="The latents representing the reference image for image-to-image/inpainting generation",
-            )
+            OutputParam(name="processed_image"),
         ]

     @staticmethod
-    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
-    def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
-        if isinstance(generator, list):
-            image_latents = [
-                retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
-            ]
-            image_latents = torch.cat(image_latents, dim=0)
+    def check_inputs(height, width, vae_scale_factor):
+        if height is not None and height % (vae_scale_factor * 2) != 0:
+            raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+        if width is not None and width % (vae_scale_factor * 2) != 0:
+            raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+    @torch.no_grad()
+    def __call__(self, components: FluxModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        if block_state.resized_image is None and block_state.image is None:
+            raise ValueError("`resized_image` and `image` cannot be None at the same time")
+
+        if block_state.resized_image is None:
+            image = block_state.image
+            self.check_inputs(
+                height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+            )
+            height = block_state.height or components.default_height
+            width = block_state.width or components.default_width
         else:
-            image_latents = retrieve_latents(vae.encode(image), generator=generator)
-
-        image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
-
-        return image_latents
+            width, height = block_state.resized_image[0].size
+            image = block_state.resized_image
+
+        block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width)
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
+    model_name = "flux"
+
+    def __init__(
+        self,
+        input_name: str = "processed_image",
+        output_name: str = "image_latents",
+    ):
+        """Initialize a VAE encoder step for converting images to latent representations.
+
+        Both the input and output names are configurable so this block can be configured to process to different image
+        inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
+
+        Args:
+            input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
+                Examples: "processed_image" or "processed_control_image"
+            output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
+                Examples: "image_latents" or "control_image_latents"
+
+        Examples:
+            # Basic usage with default settings (includes image processor):
+            # FluxImageVaeEncoderDynamicStep()
+
+            # Custom input/output names for control image:
+            # FluxImageVaeEncoderDynamicStep(
+            #     input_name="processed_control_image", output_name="control_image_latents"
+            # )
+        """
+        self._image_input_name = input_name
+        self._image_latents_output_name = output_name
+        super().__init__()
+
+    @property
+    def description(self) -> str:
+        return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        components = [ComponentSpec("vae", AutoencoderKL)]
+        return components
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
+        return inputs
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                self._image_latents_output_name,
+                type_hint=torch.Tensor,
+                description="The latents representing the reference image",
+            )
+        ]

     @torch.no_grad()
     def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
-        block_state.device = components._execution_device
-        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
-
-        block_state.image = components.image_processor.preprocess(
-            block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
-        )
-        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
-
-        block_state.batch_size = block_state.image.shape[0]
-
-        # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
-        if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
-                f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        block_state.image_latents = self._encode_vae_image(
-            components.vae, image=block_state.image, generator=block_state.generator
-        )
+
+        device = components._execution_device
+        dtype = components.vae.dtype
+
+        image = getattr(block_state, self._image_input_name)
+        image = image.to(device=device, dtype=dtype)
+
+        # Encode image into latents
+        image_latents = encode_vae_image(image=image, vae=components.vae, generator=block_state.generator)
+        setattr(block_state, self._image_latents_output_name, image_latents)

         self.set_block_state(state, block_state)
@@ -161,7 +220,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):

     @property
     def description(self) -> str:
-        return "Text Encoder step that generate text_embeddings to guide the video generation"
+        return "Text Encoder step that generate text_embeddings to guide the image generation"

     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -172,10 +231,6 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
             ComponentSpec("tokenizer_2", T5TokenizerFast),
         ]

-    @property
-    def expected_configs(self) -> List[ConfigSpec]:
-        return []
-
     @property
     def inputs(self) -> List[InputParam]:
         return [
@@ -200,12 +255,6 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
                 type_hint=torch.Tensor,
                 description="pooled text embeddings used to guide the image generation",
             ),
-            OutputParam(
-                "text_ids",
-                kwargs_type="denoiser_input_fields",
-                type_hint=torch.Tensor,
-                description="ids from the text sequence for RoPE",
-            ),
         ]

     @staticmethod
@@ -216,16 +265,10 @@ class FluxTextEncoderStep(ModularPipelineBlocks):

     @staticmethod
     def _get_t5_prompt_embeds(
-        components,
-        prompt: Union[str, List[str]],
-        num_images_per_prompt: int,
-        max_sequence_length: int,
-        device: torch.device,
+        components, prompt: Union[str, List[str]], max_sequence_length: int, device: torch.device
     ):
         dtype = components.text_encoder_2.dtype
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)

         if isinstance(components, TextualInversionLoaderMixin):
             prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2)
@@ -251,23 +294,11 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
         prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

-        _, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
         return prompt_embeds

     @staticmethod
-    def _get_clip_prompt_embeds(
-        components,
-        prompt: Union[str, List[str]],
-        num_images_per_prompt: int,
-        device: torch.device,
-    ):
+    def _get_clip_prompt_embeds(components, prompt: Union[str, List[str]], device: torch.device):
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)

         if isinstance(components, TextualInversionLoaderMixin):
             prompt = components.maybe_convert_prompt(prompt, components.tokenizer)
@@ -297,10 +328,6 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
         prompt_embeds = prompt_embeds.pooler_output
         prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device)

-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-
         return prompt_embeds

     @staticmethod
@@ -309,34 +336,11 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]],
         device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
     ):
-        r"""
-        Encodes the prompt into text encoder hidden states.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
-                used in all text-encoders
-            device: (`torch.device`):
-                torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
-        """
         device = device or components._execution_device

         # set lora scale so that monkey patched LoRA
@@ -361,12 +365,10 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
             components,
             prompt=prompt,
             device=device,
-            num_images_per_prompt=num_images_per_prompt,
         )
         prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds(
             components,
             prompt=prompt_2,
-            num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             device=device,
         )
@@ -381,10 +383,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(components.text_encoder_2, lora_scale)

-        dtype = components.text_encoder.dtype if components.text_encoder is not None else torch.bfloat16
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-
-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        return prompt_embeds, pooled_prompt_embeds

     @torch.no_grad()
     def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
@@ -400,14 +399,13 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
             if block_state.joint_attention_kwargs is not None
             else None
         )
-        (block_state.prompt_embeds, block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt(
+        block_state.prompt_embeds, block_state.pooled_prompt_embeds = self.encode_prompt(
             components,
             prompt=block_state.prompt,
             prompt_2=None,
             prompt_embeds=None,
             pooled_prompt_embeds=None,
             device=block_state.device,
-            num_images_per_prompt=1,  # TODO: hardcoded for now.
             max_sequence_length=block_state.max_sequence_length,
             lora_scale=block_state.text_encoder_lora_scale,
         )
......
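For context on the encoder refactor: the monolithic `FluxVaeEncoderStep` is replaced by a module-level `encode_vae_image` helper plus a configurable `FluxVaeEncoderDynamicStep`, and `text_ids` generation moves out of the text encoder (it is now produced alongside `img_ids` by `FluxRoPEInputsStep`). A small configuration sketch based on the docstring above; the control-image names are illustrative, not part of this commit:

# Default wiring: reads "processed_image" from state, writes "image_latents".
vae_encode = FluxVaeEncoderDynamicStep()

# Hypothetical second stream, e.g. a control image, using the configurable names:
control_vae_encode = FluxVaeEncoderDynamicStep(
    input_name="processed_control_image",
    output_name="control_image_latents",
)

# The standalone helper can also be applied directly to a preprocessed image batch:
# latents = encode_vae_image(vae=components.vae, image=image_tensor, generator=generator)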
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from ...pipelines import FluxPipeline
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import InputParam, OutputParam
# TODO: consider making these common utilities for modular if they are not pipeline-specific.
from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size
from .modular_pipeline import FluxModularPipeline
class FluxTextInputStep(ModularPipelineBlocks):
model_name = "flux"
@property
def description(self) -> str:
return (
"Text input processing step that standardizes text embeddings for the pipeline.\n"
"This step:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam(
"prompt_embeds",
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"pooled_prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
),
# TODO: support negative embeddings?
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="pooled text embeddings used to guide the image generation",
),
# TODO: support negative embeddings?
]
def check_inputs(self, components, block_state):
if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
raise ValueError(
"`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
f" {block_state.pooled_prompt_embeds.shape}."
)
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
# TODO: consider adding negative embeddings?
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
# Adapted from `QwenImageInputsDynamicStep`
class FluxInputsDynamicStep(ModularPipelineBlocks):
model_name = "flux"
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
super().__init__()
@property
def description(self) -> str:
# Functionality section
summary_section = (
"Input processing step that:\n"
" 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
)
# Inputs info
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
# Placement guidance
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
]
# Add image latent inputs
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
# Add additional batch inputs
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
return inputs
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
]
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# 1. Calculate height/width from latents
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
block_state.height = block_state.height or height
block_state.width = block_state.width or width
if not hasattr(block_state, "image_height"):
block_state.image_height = height
if not hasattr(block_state, "image_width"):
block_state.image_width = width
# 2. Patchify the image latent tensor
# TODO: Implement patchifier for Flux.
latent_height, latent_width = image_latent_tensor.shape[2:]
image_latent_tensor = FluxPipeline._pack_latents(
image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
)
# 3. Expand batch size
image_latent_tensor = repeat_tensor_to_batch_size(
input_name=image_latent_input_name,
input_tensor=image_latent_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, image_latent_input_name, image_latent_tensor)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, input_name, input_tensor)
self.set_block_state(state, block_state)
return components, state
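For context on the patchify step referenced in `FluxInputsDynamicStep`: packing turns each 2x2 patch of the latent grid into one token, which is what `FluxPipeline._pack_latents` does. A rough standalone sketch of the shape math in plain PyTorch; `pack_latents_sketch` is an illustrative stand-in, not the diffusers implementation:

import torch

def pack_latents_sketch(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, (H/2) * (W/2), C * 4): each 2x2 patch becomes one token.
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

image_latents = torch.randn(1, 16, 64, 64)      # one encoded reference image
packed = pack_latents_sketch(image_latents)     # (1, 1024, 64)
packed = packed.repeat(4, 1, 1)                 # expand to batch_size * num_images_per_prompt = 4

With a 16-channel, 64x64 latent this yields a (1, 1024, 64) sequence, which the block then repeats along the batch dimension to match the final batch size.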
@@ -18,21 +18,41 @@ from ..modular_pipeline_utils import InsertableDict
 from .before_denoise import (
     FluxImg2ImgPrepareLatentsStep,
     FluxImg2ImgSetTimestepsStep,
-    FluxInputStep,
     FluxPrepareLatentsStep,
+    FluxRoPEInputsStep,
     FluxSetTimestepsStep,
 )
 from .decoders import FluxDecodeStep
 from .denoise import FluxDenoiseStep
-from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
+from .encoders import FluxProcessImagesInputStep, FluxTextEncoderStep, FluxVaeEncoderDynamicStep
+from .inputs import FluxInputsDynamicStep, FluxTextInputStep

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # vae encoder (run before before_denoise)
+FluxImg2ImgVaeEncoderBlocks = InsertableDict(
+    [
+        ("preprocess", FluxProcessImagesInputStep()),
+        ("encode", FluxVaeEncoderDynamicStep()),
+    ]
+)
+
+
+class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+    model_name = "flux"
+    block_classes = FluxImg2ImgVaeEncoderBlocks.values()
+    block_names = FluxImg2ImgVaeEncoderBlocks.keys()
+
+    @property
+    def description(self) -> str:
+        return "Vae encoder step that preprocess and encode the image inputs into their latent representations."
+
+
 class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
-    block_classes = [FluxVaeEncoderStep]
+    block_classes = [FluxImg2ImgVaeEncoderStep]
     block_names = ["img2img"]
     block_trigger_inputs = ["image"]
@@ -41,45 +61,48 @@ class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
         return (
             "Vae encoder step that encode the image inputs into their latent representations.\n"
             + "This is an auto pipeline block that works for img2img tasks.\n"
-            + " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
-            + " - if `image` is provided, step will be skipped."
+            + " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided."
+            + " - if `image` is not provided, step will be skipped."
         )


-# before_denoise: text2img, img2img
-class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [
-        FluxInputStep,
-        FluxPrepareLatentsStep,
-        FluxSetTimestepsStep,
-    ]
-    block_names = ["input", "prepare_latents", "set_timesteps"]
+# before_denoise: text2img
+FluxBeforeDenoiseBlocks = InsertableDict(
+    [
+        ("prepare_latents", FluxPrepareLatentsStep()),
+        ("set_timesteps", FluxSetTimestepsStep()),
+        ("prepare_rope_inputs", FluxRoPEInputsStep()),
+    ]
+)
+
+
+class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
+    block_classes = FluxBeforeDenoiseBlocks.values()
+    block_names = FluxBeforeDenoiseBlocks.keys()

     @property
     def description(self):
-        return (
-            "Before denoise step that prepare the inputs for the denoise step.\n"
-            + "This is a sequential pipeline blocks:\n"
-            + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
-            + " - `FluxPrepareLatentsStep` is used to prepare the latents\n"
-            + " - `FluxSetTimestepsStep` is used to set the timesteps\n"
-        )
+        return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."


 # before_denoise: img2img
+FluxImg2ImgBeforeDenoiseBlocks = InsertableDict(
+    [
+        ("prepare_latents", FluxPrepareLatentsStep()),
+        ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
+        ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
+        ("prepare_rope_inputs", FluxRoPEInputsStep()),
+    ]
+)
+
+
 class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
-    block_names = ["input", "set_timesteps", "prepare_latents"]
+    block_classes = FluxImg2ImgBeforeDenoiseBlocks.values()
+    block_names = FluxImg2ImgBeforeDenoiseBlocks.keys()

     @property
     def description(self):
-        return (
-            "Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
-            + "This is a sequential pipeline blocks:\n"
-            + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
-            + " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
-            + " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
-        )
+        return "Before denoise step that prepare the inputs for the denoise step for img2img task."


 # before_denoise: all task (text2img, img2img)
@@ -113,7 +136,7 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
     )


-# decode: all task (text2img, img2img, inpainting)
+# decode: all task (text2img, img2img)
 class FluxAutoDecodeStep(AutoPipelineBlocks):
     block_classes = [FluxDecodeStep]
     block_names = ["non-inpaint"]
@@ -124,32 +147,73 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
         return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"


+# inputs: text2image/img2img
+FluxImg2ImgBlocks = InsertableDict(
+    [("text_inputs", FluxTextInputStep()), ("additional_inputs", FluxInputsDynamicStep())]
+)
+
+
+class FluxImg2ImgInputStep(SequentialPipelineBlocks):
+    model_name = "flux"
+    block_classes = FluxImg2ImgBlocks.values()
+    block_names = FluxImg2ImgBlocks.keys()
+
+    @property
+    def description(self):
+        return "Input step that prepares the inputs for the img2img denoising step. It:\n"
+        " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
+        " - update height/width based `image_latents`, patchify `image_latents`."
+
+
+class FluxImageAutoInputStep(AutoPipelineBlocks):
+    block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
+    block_names = ["img2img", "text2image"]
+    block_trigger_inputs = ["image_latents", None]
+
+    @property
+    def description(self):
+        return (
+            "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
+            " This is an auto pipeline block that works for text2image/img2img tasks.\n"
+            + " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+            + " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
+        )
+
+
 class FluxCoreDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [FluxInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
+    model_name = "flux"
+    block_classes = [FluxImageAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
     block_names = ["input", "before_denoise", "denoise"]

     @property
     def description(self):
         return (
             "Core step that performs the denoising process. \n"
-            + " - `FluxInputStep` (input) standardizes the inputs for the denoising step.\n"
+            + " - `FluxImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
             + " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
             + " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
-            + "This step support text-to-image and image-to-image tasks for Flux:\n"
+            + "This step supports text-to-image and image-to-image tasks for Flux:\n"
             + " - for image-to-image generation, you need to provide `image_latents`\n"
-            + " - for text-to-image generation, all you need to provide is prompt embeddings"
+            + " - for text-to-image generation, all you need to provide is prompt embeddings."
         )


-# text2image
-class FluxAutoBlocks(SequentialPipelineBlocks):
-    block_classes = [
-        FluxTextEncoderStep,
-        FluxAutoVaeEncoderStep,
-        FluxCoreDenoiseStep,
-        FluxAutoDecodeStep,
-    ]
-    block_names = ["text_encoder", "image_encoder", "denoise", "decode"]
+# Auto blocks (text2image and img2img)
+AUTO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", FluxTextEncoderStep()),
+        ("image_encoder", FluxAutoVaeEncoderStep()),
+        ("denoise", FluxCoreDenoiseStep()),
+        ("decode", FluxDecodeStep()),
+    ]
+)
+
+
+class FluxAutoBlocks(SequentialPipelineBlocks):
+    model_name = "flux"
+    block_classes = AUTO_BLOCKS.values()
+    block_names = AUTO_BLOCKS.keys()

     @property
     def description(self):
@@ -162,35 +226,28 @@ class FluxAutoBlocks(SequentialPipelineBlocks):

 TEXT2IMAGE_BLOCKS = InsertableDict(
     [
-        ("text_encoder", FluxTextEncoderStep),
-        ("input", FluxInputStep),
-        ("prepare_latents", FluxPrepareLatentsStep),
-        ("set_timesteps", FluxSetTimestepsStep),
-        ("denoise", FluxDenoiseStep),
-        ("decode", FluxDecodeStep),
+        ("text_encoder", FluxTextEncoderStep()),
+        ("input", FluxTextInputStep()),
+        ("prepare_latents", FluxPrepareLatentsStep()),
+        ("set_timesteps", FluxSetTimestepsStep()),
+        ("prepare_rope_inputs", FluxRoPEInputsStep()),
+        ("denoise", FluxDenoiseStep()),
+        ("decode", FluxDecodeStep()),
     ]
 )

 IMAGE2IMAGE_BLOCKS = InsertableDict(
     [
-        ("text_encoder", FluxTextEncoderStep),
-        ("image_encoder", FluxVaeEncoderStep),
-        ("input", FluxInputStep),
-        ("set_timesteps", FluxImg2ImgSetTimestepsStep),
-        ("prepare_latents", FluxImg2ImgPrepareLatentsStep),
-        ("denoise", FluxDenoiseStep),
-        ("decode", FluxDecodeStep),
+        ("text_encoder", FluxTextEncoderStep()),
+        ("vae_encoder", FluxVaeEncoderDynamicStep()),
+        ("input", FluxImg2ImgInputStep()),
+        ("prepare_latents", FluxPrepareLatentsStep()),
+        ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
+        ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
+        ("prepare_rope_inputs", FluxRoPEInputsStep()),
+        ("denoise", FluxDenoiseStep()),
+        ("decode", FluxDecodeStep()),
     ]
 )

-AUTO_BLOCKS = InsertableDict(
-    [
-        ("text_encoder", FluxTextEncoderStep),
-        ("image_encoder", FluxAutoVaeEncoderStep),
-        ("denoise", FluxCoreDenoiseStep),
-        ("decode", FluxAutoDecodeStep),
-    ]
-)
-
 ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
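A rough end-to-end usage sketch of the resulting preset, following the modular pipeline examples for Qwen. `from_blocks_dict`, `init_pipeline`, `load_components`, the repo id, and the `output="images"` call are assumptions based on the modular diffusers docs, not part of this commit:

import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks

# Build the auto (text2image + img2img) workflow from the preset dict above.
blocks = SequentialPipelineBlocks.from_blocks_dict(AUTO_BLOCKS)

# Assumed setup: resolve component specs from a repo and materialize them.
pipe = blocks.init_pipeline("black-forest-labs/FLUX.1-dev")
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Text-to-image run; passing `image` instead would trigger the img2img branch.
image = pipe(prompt="a tiny astronaut hatching from an egg on the moon", output="images")[0]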