Unverified Commit 203dc520 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[modular] add Modular flux for text-to-image (#11995)

* start flux.

* more

* up

* up

* up

* up

* get back the deleted files.

* up

* empathy
parent 56d43872
...@@ -364,6 +364,8 @@ except OptionalDependencyNotAvailable: ...@@ -364,6 +364,8 @@ except OptionalDependencyNotAvailable:
else: else:
_import_structure["modular_pipelines"].extend( _import_structure["modular_pipelines"].extend(
[ [
"FluxAutoBlocks",
"FluxModularPipeline",
"StableDiffusionXLAutoBlocks", "StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline", "StableDiffusionXLModularPipeline",
"WanAutoBlocks", "WanAutoBlocks",
...@@ -999,6 +1001,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -999,6 +1001,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403 from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else: else:
from .modular_pipelines import ( from .modular_pipelines import (
FluxAutoBlocks,
FluxModularPipeline,
StableDiffusionXLAutoBlocks, StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline, StableDiffusionXLModularPipeline,
WanAutoBlocks, WanAutoBlocks,
......
...@@ -107,6 +107,7 @@ class TransformerBlockRegistry: ...@@ -107,6 +107,7 @@ class TransformerBlockRegistry:
def _register_attention_processors_metadata(): def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0 from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor
from ..models.transformers.transformer_wan import WanAttnProcessor2_0 from ..models.transformers.transformer_wan import WanAttnProcessor2_0
# AttnProcessor2_0 # AttnProcessor2_0
...@@ -132,6 +133,11 @@ def _register_attention_processors_metadata(): ...@@ -132,6 +133,11 @@ def _register_attention_processors_metadata():
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0, skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
), ),
) )
# FluxAttnProcessor
AttentionProcessorRegistry.register(
model_class=FluxAttnProcessor,
metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
)
def _register_transformer_blocks_metadata(): def _register_transformer_blocks_metadata():
...@@ -271,4 +277,6 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, * ...@@ -271,4 +277,6 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on # fmt: on
...@@ -41,6 +41,7 @@ else: ...@@ -41,6 +41,7 @@ else:
] ]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"] _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
_import_structure["components_manager"] = ["ComponentsManager"] _import_structure["components_manager"] = ["ComponentsManager"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
...@@ -51,6 +52,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -51,6 +52,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_pt_objects import * # noqa F403 from ..utils.dummy_pt_objects import * # noqa F403
else: else:
from .components_manager import ComponentsManager from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxModularPipeline
from .modular_pipeline import ( from .modular_pipeline import (
AutoPipelineBlocks, AutoPipelineBlocks,
BlockState, BlockState,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = ["FluxTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"FluxAutoBeforeDenoiseStep",
"FluxAutoBlocks",
"FluxAutoBlocks",
"FluxAutoDecodeStep",
"FluxAutoDenoiseStep",
]
_import_structure["modular_pipeline"] = ["FluxModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .encoders import FluxTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
TEXT2IMAGE_BLOCKS,
FluxAutoBeforeDenoiseStep,
FluxAutoBlocks,
FluxAutoDecodeStep,
FluxAutoDenoiseStep,
)
from .modular_pipeline import FluxModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Union
import numpy as np
import torch
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
class FluxInputStep(PipelineBlock):
model_name = "flux"
@property
def description(self) -> str:
return (
"Input processing step that:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n"
"All input tensors are expected to have either batch_size=1 or match the batch_size\n"
"of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
"have a final batch_size of batch_size * num_images_per_prompt."
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
),
# TODO: support negative embeddings?
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
description="text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
description="pooled text embeddings used to guide the image generation",
),
# TODO: support negative embeddings?
]
def check_inputs(self, components, block_state):
if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
raise ValueError(
"`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
f" {block_state.pooled_prompt_embeds.shape}."
)
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
# TODO: consider adding negative embeddings?
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
class FluxSetTimestepsStep(PipelineBlock):
model_name = "flux"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for inference"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("guidance_scale", default=3.5),
InputParam("latents", type_hint=torch.Tensor),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
)
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
"num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
]
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
scheduler = components.scheduler
latents = block_state.latents
image_seq_len = latents.shape[1]
num_inference_steps = block_state.num_inference_steps
sigmas = block_state.sigmas
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
sigmas = None
block_state.sigmas = sigmas
mu = calculate_shift(
image_seq_len,
scheduler.config.get("base_image_seq_len", 256),
scheduler.config.get("max_image_seq_len", 4096),
scheduler.config.get("base_shift", 0.5),
scheduler.config.get("max_shift", 1.15),
)
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
)
if components.transformer.config.guidance_embeds:
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
block_state.guidance = guidance
self.set_block_state(state, block_state)
return components, state
class FluxPrepareLatentsStep(PipelineBlock):
model_name = "flux"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def description(self) -> str:
return "Prepare latents step that prepares the latents for the text-to-video generation process"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
),
OutputParam(
"latent_image_ids",
type_hint=torch.Tensor,
description="IDs computed from the image sequence needed for RoPE",
),
]
@staticmethod
def check_inputs(components, block_state):
if (block_state.height is not None and block_state.height % (components.vae_scale_factor * 2) != 0) or (
block_state.width is not None and block_state.width % (components.vae_scale_factor * 2) != 0
):
logger.warning(
f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}."
)
@staticmethod
def prepare_latents(
comp,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# Couldn't use the `prepare_latents` method directly from Flux because I decided to copy over
# the packing methods here. So, for example, `comp._pack_latents()` won't work if we were
# to go with the "# Copied from ..." approach. Or maybe there's a way?
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
block_state.latents, block_state.latent_image_ids = self.prepare_latents(
components,
block_state.batch_size * block_state.num_images_per_prompt,
block_state.num_channels_latents,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.latents,
)
self.set_block_state(state, block_state)
return components, state
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL
from ...utils import logging
from ...video_processor import VaeImageProcessor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
class FluxDecodeStep(PipelineBlock):
model_name = "flux"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
InputParam("height", default=1024),
InputParam("width", default=1024),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae = components.vae
if not block_state.output_type == "latent":
latents = block_state.latents
latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
else:
block_state.images = block_state.latents
self.set_block_state(state, block_state)
return components, state
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
import torch
from ...models import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
PipelineBlock,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxLoopDenoiser(PipelineBlock):
model_name = "flux"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", FluxTransformer2DModel)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `FluxDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [InputParam("joint_attention_kwargs")]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"guidance",
required=True,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Prompt embeddings",
),
InputParam(
"pooled_prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Pooled prompt embeddings",
),
InputParam(
"text_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from text sequence needed for RoPE",
),
InputParam(
"latent_image_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from image sequence needed for RoPE",
),
# TODO: guidance
]
@torch.no_grad()
def __call__(
self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
noise_pred = components.transformer(
hidden_states=block_state.latents,
timestep=t.flatten() / 1000,
guidance=block_state.guidance,
encoder_hidden_states=block_state.prompt_embeds,
pooled_projections=block_state.pooled_prompt_embeds,
joint_attention_kwargs=block_state.joint_attention_kwargs,
txt_ids=block_state.text_ids,
img_ids=block_state.latent_image_ids,
return_dict=False,
)[0]
block_state.noise_pred = noise_pred
return components, block_state
class FluxLoopAfterDenoiser(PipelineBlock):
model_name = "flux"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
@property
def description(self) -> str:
return (
"step within the denoising loop that update the latents. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `FluxDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return []
@property
def intermediate_inputs(self) -> List[str]:
return [InputParam("generator")]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
# Perform scheduler step using the predicted output
latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred,
t,
block_state.latents,
return_dict=False,
)[0]
if block_state.latents.dtype != latents_dtype:
block_state.latents = block_state.latents.to(latents_dtype)
return components, block_state
class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "flux"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoise the latents over `timesteps`. "
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
ComponentSpec("transformer", FluxTransformer2DModel),
]
@property
def loop_intermediate_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
# We set the index here to remove DtoH sync, helpful especially during compilation.
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
components.scheduler.set_begin_index(0)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
self.set_block_state(state, block_state)
return components, state
class FluxDenoiseStep(FluxDenoiseLoopWrapper):
block_classes = [FluxLoopDenoiser, FluxLoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `FluxLoopDenoiser`\n"
" - `FluxLoopAfterDenoiser`\n"
"This block supports text2image tasks."
)
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
from typing import List, Optional, Union
import regex as re
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
if is_ftfy_available():
import ftfy
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
text = whitespace_clean(basic_clean(text))
return text
class FluxTextEncoderStep(PipelineBlock):
model_name = "flux"
@property
def description(self) -> str:
return "Text Encoder step that generate text_embeddings to guide the video generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", CLIPTextModel),
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("text_encoder_2", T5EncoderModel),
ComponentSpec("tokenizer_2", T5TokenizerFast),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("joint_attention_kwargs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
description="text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
description="pooled text embeddings used to guide the image generation",
),
OutputParam(
"text_ids",
type_hint=torch.Tensor,
description="ids from the text sequence for RoPE",
),
]
@staticmethod
def check_inputs(block_state):
for prompt in [block_state.prompt, block_state.prompt_2]:
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
def _get_t5_prompt_embeds(
components,
prompt: Union[str, List[str]],
num_images_per_prompt: int,
max_sequence_length: int,
device: torch.device,
):
dtype = components.text_encoder_2.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2)
text_inputs = components.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = components.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = components.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
@staticmethod
def _get_clip_prompt_embeds(
components,
prompt: Union[str, List[str]],
num_images_per_prompt: int,
device: torch.device,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, components.tokenizer)
text_inputs = components.tokenizer(
prompt,
padding="max_length",
max_length=components.tokenizer.model_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
tokenizer_max_length = components.tokenizer.model_max_length
untruncated_ids = components.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = components.tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = components.text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
@staticmethod
def encode_prompt(
components,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in all text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or components._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
components._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if components.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(components.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# We only use the pooled prompt output from the CLIPTextModel
pooled_prompt_embeds = FluxTextEncoderStep._get_clip_prompt_embeds(
components,
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds(
components,
prompt=prompt_2,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if components.text_encoder is not None:
if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
dtype = components.text_encoder.dtype if components.text_encoder is not None else torch.bfloat16
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.device = components._execution_device
# Encode input prompt
block_state.text_encoder_lora_scale = (
block_state.joint_attention_kwargs.get("scale", None)
if block_state.joint_attention_kwargs is not None
else None
)
(block_state.prompt_embeds, block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt(
components,
prompt=block_state.prompt,
prompt_2=None,
prompt_embeds=None,
pooled_prompt_embeds=None,
device=block_state.device,
num_images_per_prompt=1, # hardcoded for now.
lora_scale=block_state.text_encoder_lora_scale,
)
# Add outputs
self.set_block_state(state, block_state)
return components, state
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep
from .decoders import FluxDecodeStep
from .denoise import FluxDenoiseStep
from .encoders import FluxTextEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# before_denoise: text2vid
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
FluxInputStep,
FluxPrepareLatentsStep,
FluxSetTimestepsStep,
]
block_names = ["input", "prepare_latents", "set_timesteps"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `FluxPrepareLatentsStep` is used to prepare the latents\n"
+ " - `FluxSetTimestepsStep` is used to set the timesteps\n"
)
# before_denoise: all task (text2vid,)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxBeforeDenoiseStep]
block_names = ["text2image"]
block_trigger_inputs = [None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2image.\n"
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
)
# denoise: text2image
class FluxAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxDenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2image tasks."
" - `FluxDenoiseStep` (denoise) for text2image tasks."
)
# decode: all task (text2img, img2img, inpainting)
class FluxAutoDecodeStep(AutoPipelineBlocks):
block_classes = [FluxDecodeStep]
block_names = ["non-inpaint"]
block_trigger_inputs = [None]
@property
def description(self):
return "Decode step that decode the denoised latents into videos outputs.\n - `FluxDecodeStep`"
# text2image
class FluxAutoBlocks(SequentialPipelineBlocks):
block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep]
block_names = ["text_encoder", "before_denoise", "denoise", "decoder"]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image using Flux.\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep),
("input", FluxInputStep),
("prepare_latents", FluxPrepareLatentsStep),
# Setting it after preparation of latents because we rely on `latents`
# to calculate `img_seq_len` for `shift`.
("set_timesteps", FluxSetTimestepsStep),
("denoise", FluxDenoiseStep),
("decode", FluxDecodeStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep),
("before_denoise", FluxAutoBeforeDenoiseStep),
("denoise", FluxAutoDenoiseStep),
("decode", FluxAutoDecodeStep),
]
)
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...loaders import FluxLoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin):
"""
A ModularPipeline for Flux.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
"""
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 16
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
...@@ -61,6 +61,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict( ...@@ -61,6 +61,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
[ [
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"), ("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
] ]
) )
...@@ -68,6 +69,7 @@ MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict( ...@@ -68,6 +69,7 @@ MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
[ [
("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"), ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
("WanModularPipeline", "WanAutoBlocks"), ("WanModularPipeline", "WanAutoBlocks"),
("FluxModularPipeline", "FluxAutoBlocks"),
] ]
) )
...@@ -1663,7 +1665,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): ...@@ -1663,7 +1665,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
if input_param.name: if input_param.name:
value = state.get_intermediate(input_param.name) value = state.get_intermediate(input_param.name)
if input_param.required and value is None: if input_param.required and value is None:
raise ValueError(f"Required intermediate input '{input_param.name}' is missing") raise ValueError(f"Required intermediate input '{input_param.name}' is missing.")
elif value is not None or (value is None and input_param.name not in data): elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value data[input_param.name] = value
elif input_param.kwargs_type: elif input_param.kwargs_type:
......
...@@ -11,12 +11,14 @@ from ...utils import BaseOutput ...@@ -11,12 +11,14 @@ from ...utils import BaseOutput
@dataclass @dataclass
class FluxPipelineOutput(BaseOutput): class FluxPipelineOutput(BaseOutput):
""" """
Output class for Stable Diffusion pipelines. Output class for Flux image generation pipelines.
Args: Args:
images (`List[PIL.Image.Image]` or `np.ndarray`) images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
passed to the decoder.
""" """
images: Union[List[PIL.Image.Image], np.ndarray] images: Union[List[PIL.Image.Image], np.ndarray]
......
...@@ -2,6 +2,36 @@ ...@@ -2,6 +2,36 @@
from ..utils import DummyObject, requires_backends from ..utils import DummyObject, requires_backends
class FluxAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLAutoBlocks(metaclass=DummyObject): class StableDiffusionXLAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"] _backends = ["torch", "transformers"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment