Unverified commit 7ac6e286, authored by Aryan, committed by GitHub

Flux Fill, Canny, Depth, Redux (#9985)

* update

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent b5fd6f13
@@ -22,12 +22,20 @@ Flux can be quite expensive to run on consumer hardware devices. However, you ca

</Tip>

Flux comes in the following variants:

| model type | model id |
|:----------:|:--------:|
| Timestep-distilled | [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) |
| Guidance-distilled | [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev) |
| Fill Inpainting/Outpainting (Guidance-distilled) | [`black-forest-labs/FLUX.1-Fill-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) |
| Canny Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Canny-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
| Depth Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Depth-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
All checkpoints have different usage, which we detail below.
### Timestep-distilled

@@ -77,7 +85,132 @@ out = pipe(
out.save("image.png")
```
### Fill Inpainting/Outpainting
* The Flux Fill pipeline does not take a `strength` input, unlike regular inpainting pipelines.
* It supports both inpainting and outpainting.
```python
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png")
repo_id = "black-forest-labs/FLUX.1-Fill-dev"
pipe = FluxFillPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cuda")
image = pipe(
prompt="a white paper cup",
image=image,
mask_image=mask,
height=1632,
width=1232,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save(f"output.png")
```
### Canny Control
**Note:** `black-forest-labs/FLUX.1-Canny-dev` is _not_ a [`ControlNetModel`]. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the underlying model. Canny Control is an alternate architecture that achieves effectively the same results as a ControlNet by channel-wise concatenating the control condition with the input, so the transformer learns to follow the structure of the condition as closely as possible.
```python
# !pip install -U controlnet-aux
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16).to("cuda")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)
image = pipe(
prompt=prompt,
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=30.0,
).images[0]
image.save("output.png")
```
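The table above also lists a Canny LoRA checkpoint. A minimal sketch of using it, assuming the LoRA is loaded on top of the base guidance-distilled model via `load_lora_weights` (which requires the `peft` backend); the exact recommended parameters may differ:

```python
# !pip install -U controlnet-aux peft
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

# load the base model and apply the Canny control LoRA on top of it
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

# extract Canny edges to use as the structure condition
processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output.png")
```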
### Depth Control
**Note:** `black-forest-labs/FLUX.1-Depth-dev` is _not_ a [`ControlNetModel`]. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet by channel-wise concatenating the control condition with the input, so the transformer learns to follow the structure of the condition as closely as possible.
```python
# !pip install git+https://github.com/asomoza/image_gen_aux.git
import torch
from diffusers import FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16).to("cuda")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")
image = pipe(
prompt=prompt,
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=30,
guidance_scale=10.0,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")
```
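Similarly, the Depth LoRA checkpoint from the table can be applied on top of the base model; a sketch under the same assumptions as the Canny LoRA example above:

```python
# !pip install git+https://github.com/asomoza/image_gen_aux.git peft
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor

# load the base model and apply the Depth control LoRA on top of it
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

# estimate a depth map to use as the structure condition
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")
```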
### Redux
* The Flux Redux pipeline is an adapter for FLUX.1 base models. It can be used with both flux-dev and flux-schnell for image-to-image generation (a schnell sketch follows the example below).
* You can first use the `FluxPriorReduxPipeline` to get the `prompt_embeds` and `pooled_prompt_embeds`, and then feed them into the `FluxPipeline` for image-to-image generation.
* When using `FluxPriorReduxPipeline` with a base pipeline, you can set `text_encoder=None` and `text_encoder_2=None` in the base pipeline to save VRAM.
```python
import torch
from diffusers import FluxPriorReduxPipeline, FluxPipeline
from diffusers.utils import load_image
device = "cuda"
dtype = torch.bfloat16
repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
repo_base = "black-forest-labs/FLUX.1-dev"
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
pipe = FluxPipeline.from_pretrained(
repo_base,
text_encoder=None,
text_encoder_2=None,
torch_dtype=torch.bfloat16
).to(device)
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")
pipe_prior_output = pipe_prior_redux(image)
images = pipe(
guidance_scale=2.5,
num_inference_steps=50,
generator=torch.Generator("cpu").manual_seed(0),
**pipe_prior_output,
).images
images[0].save("flux-redux.png")
```
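Since the prior outputs are plain embeddings, they can also be fed to the timestep-distilled base model. A sketch of the same pattern with flux-schnell; the step count and guidance scale here are illustrative assumptions rather than official recommendations:

```python
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image

pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
# drop the text encoders from the base pipeline to save VRAM
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")
pipe_prior_output = pipe_prior_redux(image)

images = pipe(
    guidance_scale=0.0,     # schnell is timestep-distilled; guidance is typically unused
    num_inference_steps=4,  # illustrative; schnell targets few-step sampling
    generator=torch.Generator("cpu").manual_seed(0),
    **pipe_prior_output,
).images
images[0].save("flux-redux-schnell.png")
```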
## Running FP16 inference

Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.

FP16 inference code:
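The full FP16 snippet is unchanged by this diff and therefore not shown here. A minimal sketch of the pattern the paragraph describes (selectively casting so the denoiser runs in FP16 while the text encoders stay in a higher-precision dtype; details may differ from the documented code):

```python
import torch
from diffusers import FluxPipeline

# load everything in BF16, then cast only the transformer and VAE down to FP16;
# the text encoders stay in a higher-precision dtype to avoid activation clipping
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

pipe.transformer.to(torch.float16)
pipe.vae.to(torch.float16)

out = pipe(
    prompt="A mystic cat with a sign that says hello world!",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
out.save("image.png")
```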
@@ -188,3 +321,15 @@ image.save("flux-fp8-dev.png")
[[autodoc]] FluxControlNetImg2ImgPipeline
- all
- __call__
## FluxControlPipeline
[[autodoc]] FluxControlPipeline
- all
- __call__
## FluxControlImg2ImgPipeline
[[autodoc]] FluxControlImg2ImgPipeline
- all
- __call__
@@ -37,6 +37,8 @@ parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--in_channels", type=int, default=64)
parser.add_argument("--out_channels", type=int, default=None)
parser.add_argument("--vae", action="store_true") parser.add_argument("--vae", action="store_true")
parser.add_argument("--transformer", action="store_true") parser.add_argument("--transformer", action="store_true")
parser.add_argument("--output_path", type=str) parser.add_argument("--output_path", type=str)
@@ -279,10 +281,13 @@ def main(args):
num_single_layers = 38
inner_dim = 3072
mlp_ratio = 4.0
converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
)
transformer = FluxTransformer2DModel(
in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
)
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
print(
...
@@ -269,12 +269,16 @@ else:
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlNetImg2ImgPipeline", "FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline", "FluxControlNetInpaintPipeline",
"FluxControlNetPipeline", "FluxControlNetPipeline",
"FluxControlPipeline",
"FluxFillPipeline",
"FluxImg2ImgPipeline", "FluxImg2ImgPipeline",
"FluxInpaintPipeline", "FluxInpaintPipeline",
"FluxPipeline", "FluxPipeline",
"FluxPriorReduxPipeline",
"HunyuanDiTControlNetPipeline", "HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline", "HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline", "HunyuanDiTPipeline",
...@@ -321,6 +325,7 @@ else: ...@@ -321,6 +325,7 @@ else:
"PixArtAlphaPipeline", "PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline", "PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline", "PixArtSigmaPipeline",
"ReduxImageEncoder",
"SemanticStableDiffusionPipeline", "SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline", "ShapEImg2ImgPipeline",
"ShapEPipeline", "ShapEPipeline",
@@ -734,12 +739,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlImg2ImgPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxControlPipeline,
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
@@ -786,6 +795,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
ReduxImageEncoder,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
...
@@ -238,6 +238,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
self,
patch_size: int = 1,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
@@ -248,7 +249,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
axes_dims_rope: Tuple[int] = (16, 56, 56),
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
@@ -261,7 +262,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
@@ -449,6 +450,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
@@ -456,6 +458,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None

temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
...
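The diff above decouples the transformer's output channels from its input channels, which the Control checkpoints rely on (their inputs carry extra channel-wise concatenated control latents while the prediction still covers only the image latents). A tiny sketch mirroring the dummy configuration used in the new fast tests below; the sizes here are illustrative, not real checkpoint values:

```python
from diffusers import FluxTransformer2DModel

# tiny illustrative config (same shape as the fast tests)
transformer = FluxTransformer2DModel(
    patch_size=1,
    in_channels=8,   # latent channels plus concatenated control channels
    out_channels=4,  # the prediction only covers the latent channels
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=16,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=32,
    axes_dims_rope=(4, 4, 8),
)
print(transformer.config.in_channels, transformer.out_channels)  # 8 4
```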
@@ -127,12 +127,17 @@ else:
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlNetPipeline", "FluxControlNetPipeline",
"FluxControlNetImg2ImgPipeline", "FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline", "FluxControlNetInpaintPipeline",
"FluxImg2ImgPipeline", "FluxImg2ImgPipeline",
"FluxInpaintPipeline", "FluxInpaintPipeline",
"FluxPipeline", "FluxPipeline",
"FluxFillPipeline",
"FluxPriorReduxPipeline",
"ReduxImageEncoder",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -521,12 +526,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQDiffusionPipeline,
)
from .flux import (
FluxControlImg2ImgPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxControlPipeline,
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
...
@@ -12,7 +12,7 @@ from ...utils import (
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
@@ -22,12 +22,17 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modeling_flux"] = ["ReduxImageEncoder"]
_import_structure["pipeline_flux"] = ["FluxPipeline"] _import_structure["pipeline_flux"] = ["FluxPipeline"]
_import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
_import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
_import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -35,12 +40,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
from .modeling_flux import ReduxImageEncoder
from .pipeline_flux import FluxPipeline
from .pipeline_flux_control import FluxControlPipeline
from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
from .pipeline_flux_controlnet import FluxControlNetPipeline
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
from .pipeline_flux_fill import FluxFillPipeline
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
else:
import sys
...
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
from ...utils import BaseOutput
@dataclass
class ReduxImageEncoderOutput(BaseOutput):
image_embeds: Optional[torch.Tensor] = None
class ReduxImageEncoder(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
redux_dim: int = 1152,
txt_in_features: int = 4096,
) -> None:
super().__init__()
self.redux_up = nn.Linear(redux_dim, txt_in_features * 3)
self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)
def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput:
projected_x = self.redux_down(nn.functional.silu(self.redux_up(x)))
return ReduxImageEncoderOutput(image_embeds=projected_x)
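For intuition, `ReduxImageEncoder` is just a two-layer MLP that projects SigLIP image features up to three times the T5 width and back down to the T5 hidden size (4096), so the result can be concatenated with T5 prompt embeddings. A quick shape check; the token count 729 is a hypothetical value for illustration:

```python
import torch
from diffusers import ReduxImageEncoder

encoder = ReduxImageEncoder()  # defaults: redux_dim=1152, txt_in_features=4096
x = torch.randn(1, 729, 1152)  # hypothetical batch of SigLIP patch features
out = encoder(x)
print(out.image_embeds.shape)  # torch.Size([1, 729, 4096])
```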
This diff is collapsed.
This diff is collapsed.
@@ -750,6 +750,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
device = self._execution_device
dtype = self.transformer.dtype
# 3. Prepare text embeddings
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
...
This diff is collapsed.
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
from PIL import Image
from transformers import (
CLIPTextModel,
CLIPTokenizer,
SiglipImageProcessor,
SiglipVisionModel,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import PipelineImageInput
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ..pipeline_utils import DiffusionPipeline
from .modeling_flux import ReduxImageEncoder
from .pipeline_output import FluxPriorReduxPipelineOutput
if is_torch_xla_available():
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import FluxPriorReduxPipeline, FluxPipeline
>>> from diffusers.utils import load_image
>>> device = "cuda"
>>> dtype = torch.bfloat16
>>> repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
>>> repo_base = "black-forest-labs/FLUX.1-dev"
>>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
>>> pipe = FluxPipeline.from_pretrained(
... repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16
... ).to(device)
>>> image = load_image(
... "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
... )
>>> pipe_prior_output = pipe_prior_redux(image)
>>> images = pipe(
... guidance_scale=2.5,
... num_inference_steps=50,
... generator=torch.Generator("cpu").manual_seed(0),
... **pipe_prior_output,
... ).images
>>> images[0].save("flux-redux.png")
```
"""
class FluxPriorReduxPipeline(DiffusionPipeline):
r"""
The Flux Redux pipeline for image-to-image generation.
Reference: https://blackforestlabs.ai/flux-1-tools/
Args:
image_encoder ([`SiglipVisionModel`]):
SIGLIP vision model to encode the input image.
feature_extractor ([`SiglipImageProcessor`]):
Image processor for preprocessing images for the SIGLIP model.
image_embedder ([`ReduxImageEncoder`]):
Redux image encoder to process the SIGLIP embeddings.
text_encoder ([`CLIPTextModel`], *optional*):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`], *optional*):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`, *optional*):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`, *optional*):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "image_encoder->image_embedder"
_optional_components = [
"text_encoder",
"tokenizer",
"text_encoder_2",
"tokenizer_2",
]
_callback_tensor_inputs = []
def __init__(
self,
image_encoder: SiglipVisionModel,
feature_extractor: SiglipImageProcessor,
image_embedder: ReduxImageEncoder,
text_encoder: CLIPTextModel = None,
tokenizer: CLIPTokenizer = None,
text_encoder_2: T5EncoderModel = None,
tokenizer_2: T5TokenizerFast = None,
):
super().__init__()
self.register_modules(
image_encoder=image_encoder,
feature_extractor=feature_extractor,
image_embedder=image_embedder,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
image = self.feature_extractor.preprocess(
images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True
)
image = image.to(device=device, dtype=dtype)
image_enc_hidden_states = self.image_encoder(**image).last_hidden_state
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
return image_enc_hidden_states
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in all text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# We only use the pooled prompt output from the CLIPTextModel
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt_2,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PipelineImageInput,
return_dict: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`:
[`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the elements are the prompt embeddings and the pooled prompt embeddings.
"""
# 1. Define call parameters
if image is not None and isinstance(image, Image.Image):
batch_size = 1
elif image is not None and isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
device = self._execution_device
# 2. Prepare image embeddings
image_latents = self.encode_image(image, device, 1)
image_embeds = self.image_embedder(image_latents).image_embeds
image_embeds = image_embeds.to(device=device)
# 3. Prepare (dummy) text embeddings
if hasattr(self, "text_encoder") and self.text_encoder is not None:
(
prompt_embeds,
pooled_prompt_embeds,
_,
) = self.encode_prompt(
prompt=[""] * batch_size,
prompt_2=None,
prompt_embeds=None,
pooled_prompt_embeds=None,
device=device,
num_images_per_prompt=1,
max_sequence_length=512,
lora_scale=None,
)
else:
# max_sequence_length is 512, t5 encoder hidden size is 4096
prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
# pooled_prompt_embeds is 768, clip text encoder hidden size
pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
# Concatenate the (dummy) text embeddings with the image embeddings along the sequence dimension
prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (prompt_embeds, pooled_prompt_embeds)
return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds)
@@ -3,6 +3,7 @@ from typing import List, Union
import numpy as np
import PIL.Image
import torch
from ...utils import BaseOutput
@@ -19,3 +20,18 @@ class FluxPipelineOutput(BaseOutput):
"""
images: Union[List[PIL.Image.Image], np.ndarray]
@dataclass
class FluxPriorReduxPipelineOutput(BaseOutput):
"""
Output class for Flux Prior Redux pipelines.
Args:
prompt_embeds (`torch.Tensor`):
The text embeddings, conditioned on the input image, to pass to a base Flux pipeline.
pooled_prompt_embeds (`torch.Tensor`):
The pooled text embeddings to pass to a base Flux pipeline.
"""
prompt_embeds: torch.Tensor
pooled_prompt_embeds: torch.Tensor
@@ -377,6 +377,21 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class FluxControlImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -422,6 +437,36 @@ class FluxControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class FluxControlPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxFillPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -467,6 +512,21 @@ class FluxPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class FluxPriorReduxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1157,6 +1217,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ReduxImageEncoder(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import torch_device
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxControlPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
# there is no xformers processor for Flux
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=8,
out_channels=4,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
control_image = Image.new("RGB", (16, 16), 0)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"control_image": control_image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_sequence_length": 48,
"output_type": "np",
}
return inputs
def test_flux_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reason, they don't show large differences
assert max_diff > 1e-6
def test_flux_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
prompt,
prompt_2=None,
device=torch_device,
max_sequence_length=inputs["max_sequence_length"],
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(
pipe.transformer
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxControlImg2ImgPipeline,
FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class FluxControlImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxControlImg2ImgPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=8,
out_channels=4,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
image = Image.new("RGB", (16, 16), 0)
control_image = Image.new("RGB", (16, 16), 0)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"control_image": control_image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_sequence_length": 48,
"strength": 0.8,
"output_type": "np",
}
return inputs
def test_flux_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reason, they don't show large differences
assert max_diff > 1e-6
def test_flux_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
prompt,
prompt_2=None,
device=torch_device,
max_sequence_length=inputs["max_sequence_length"],
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
import random
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxFillPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxFillPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=20,
out_channels=8,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=2,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
mask_image = torch.ones((1, 1, 32, 32)).to(device)
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"mask_image": mask_image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 32,
"width": 32,
"max_sequence_length": 48,
"output_type": "np",
}
return inputs
def test_flux_fill_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reason, they don't show large differences
assert max_diff > 1e-6
def test_flux_fill_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
prompt,
prompt_2=None,
device=torch_device,
max_sequence_length=inputs["max_sequence_length"],
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
import gc
import unittest
import numpy as np
import pytest
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
slow,
torch_device,
)
@slow
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class FluxReduxSlowTests(unittest.TestCase):
pipeline_class = FluxPriorReduxPipeline
repo_id = "YiYiXu/yiyi-redux" # update to "black-forest-labs/FLUX.1-Redux-dev" once PR is merged
base_pipeline_class = FluxPipeline
base_repo_id = "black-forest-labs/FLUX.1-schnell"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0):
init_image = load_image(
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
)
return {"image": init_image}
def get_base_pipeline_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
return {
"num_inference_steps": 2,
"guidance_scale": 2.0,
"output_type": "np",
"generator": generator,
}
def test_flux_redux_inference(self):
pipe_redux = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
pipe_base = self.base_pipeline_class.from_pretrained(
self.base_repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
)
pipe_redux.to(torch_device)
pipe_base.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
base_pipeline_inputs = self.get_base_pipeline_inputs(torch_device)
redux_pipeline_output = pipe_redux(**inputs)
image = pipe_base(**base_pipeline_inputs, **redux_pipeline_output).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
0.30078125,
0.37890625,
0.46875,
0.28125,
0.36914062,
0.47851562,
0.28515625,
0.375,
0.4765625,
0.28125,
0.375,
0.48046875,
0.27929688,
0.37695312,
0.47851562,
0.27734375,
0.38085938,
0.4765625,
0.2734375,
0.38085938,
0.47265625,
0.27539062,
0.37890625,
0.47265625,
0.27734375,
0.37695312,
0.47070312,
0.27929688,
0.37890625,
0.47460938,
],
dtype=np.float32,
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4