Unverified Commit 6ca5a58e authored by Sayak Paul, committed by GitHub

[Community Pipeline] Batched implementation of Flux with CFG (#9513)

* batched implementation of flux cfg.

* style.

* readme

* remove comments.
parent b52684c3
@@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 | Example | Description | Code Example | Colab | Author |
 |:--------|:------------|:-------------|:------|-------:|
+|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
 |Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
 | HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithful and high-resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
 | Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/toshas/marigold) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |
@@ -82,6 +83,36 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
 ## Example usages

+### Flux with CFG
+
+Learn more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since the distilled Flux models don't use classical CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
+
+Example usage:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.bfloat16,
+    custom_pipeline="pipeline_flux_with_cfg"
+)
+pipeline.enable_model_cpu_offload()
+
+prompt = "a watercolor painting of a unicorn"
+negative_prompt = "pink"
+
+img = pipeline(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    true_cfg=1.5,
+    guidance_scale=3.5,
+    num_images_per_prompt=1,
+    generator=torch.manual_seed(0)
+).images[0]
+img.save("cfg_flux.png")
+```
+
 ### Differential Diffusion

 **Eran Levin, Ohad Fried**
...
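A note on the two guidance knobs in the example above: `guidance_scale` feeds the distilled guidance embedding that FLUX.1-dev was trained with, while `true_cfg` is the weight of the classical classifier-free guidance applied against `negative_prompt`. Passing `true_cfg > 1` together with a negative prompt is what activates the CFG path. Below is a minimal sketch of the combination rule the pipeline applies at each denoising step, with stand-in random tensors in place of the transformer's actual predictions:

```py
import torch

# Stand-in tensors for the transformer's packed-latent predictions
# (shapes are illustrative, not the pipeline's actual dimensions).
neg_noise_pred = torch.randn(1, 1024, 64)
pos_noise_pred = torch.randn(1, 1024, 64)
true_cfg = 1.5

# Classical CFG: start from the negative-prompt prediction and move
# `true_cfg` times the difference toward the positive one.
noise_pred = neg_noise_pred + true_cfg * (pos_noise_pred - neg_noise_pred)
```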
@@ -289,80 +289,104 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
         self,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]],
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
+        do_true_cfg: bool = False,
     ):
-        r"""
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
-                used in all text-encoders
-            device: (`torch.device`):
-                torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
-        """
         device = device or self._execution_device

-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
+        # Set LoRA scale if applicable
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
             self._lora_scale = lora_scale

-            # dynamically adjust the LoRA scale
             if self.text_encoder is not None and USE_PEFT_BACKEND:
                 scale_lora_layers(self.text_encoder, lora_scale)
             if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                 scale_lora_layers(self.text_encoder_2, lora_scale)

         prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if do_true_cfg and negative_prompt is not None:
+            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_batch_size = len(negative_prompt)
+
+            if negative_batch_size != batch_size:
+                raise ValueError(
+                    f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
+                )
+
+            # Concatenate prompts
+            prompts = prompt + negative_prompt
+            prompts_2 = (
+                prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
+            )
+        else:
+            prompts = prompt
+            prompts_2 = prompt_2

         if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
-            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+            if prompts_2 is None:
+                prompts_2 = prompts

-            # We only use the pooled prompt output from the CLIPTextModel
+            # Get pooled prompt embeddings from CLIPTextModel
             pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompt,
+                prompt=prompts,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
             )
             prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt_2,
+                prompt=prompts_2,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )

+            if do_true_cfg and negative_prompt is not None:
+                # Split embeddings back into positive and negative parts
+                total_batch_size = batch_size * num_images_per_prompt
+                positive_indices = slice(0, total_batch_size)
+                negative_indices = slice(total_batch_size, 2 * total_batch_size)
+
+                positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
+                negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
+
+                positive_prompt_embeds = prompt_embeds[positive_indices]
+                negative_prompt_embeds = prompt_embeds[negative_indices]
+
+                pooled_prompt_embeds = positive_pooled_prompt_embeds
+                prompt_embeds = positive_prompt_embeds
+
+        # Unscale LoRA layers
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)

         if self.text_encoder_2 is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder_2, lora_scale)

         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        if do_true_cfg and negative_prompt is not None:
+            return (
+                prompt_embeds,
+                pooled_prompt_embeds,
+                text_ids,
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            )
+        else:
+            return prompt_embeds, pooled_prompt_embeds, text_ids, None, None

     def check_inputs(
         self,
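The core idea of the `encode_prompt` changes above is to encode positive and negative prompts in a single text-encoder pass: concatenate them along the batch dimension, encode once, then slice the result apart. A standalone sketch of that pattern, with a hypothetical `toy_encoder` standing in for the pipeline's CLIP/T5 encoders:

```py
import torch

def toy_encoder(texts):
    # Hypothetical stand-in for _get_clip_prompt_embeds / _get_t5_prompt_embeds.
    return torch.randn(len(texts), 512)

prompt = ["a watercolor painting of a unicorn"]
negative_prompt = ["pink"]
batch_size = len(prompt)

# One encoder call covers both halves instead of two separate calls.
embeds = toy_encoder(prompt + negative_prompt)

# Slice back into positive and negative parts, mirroring the
# positive_indices / negative_indices logic in the diff above.
prompt_embeds = embeds[:batch_size]
negative_prompt_embeds = embeds[batch_size : 2 * batch_size]
```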
@@ -687,38 +711,33 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
             text_ids,
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
+            do_true_cfg=do_true_cfg,
         )

-        # perform "real" CFG as suggested for distilled Flux models in https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md
-        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         if do_true_cfg:
-            (
-                negative_prompt_embeds,
-                negative_pooled_prompt_embeds,
-                negative_text_ids,
-            ) = self.encode_prompt(
-                prompt=negative_prompt,
-                prompt_2=negative_prompt_2,
-                prompt_embeds=negative_prompt_embeds,
-                pooled_prompt_embeds=negative_pooled_prompt_embeds,
-                device=device,
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-                lora_scale=lora_scale,
-            )
+            # Concatenate embeddings
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
@@ -754,24 +773,26 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

-        # handle guidance
-        if self.transformer.config.guidance_embeds:
-            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-            guidance = guidance.expand(latents.shape[0])
-        else:
-            guidance = None
-
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

+                latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
+
+                # handle guidance
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+                    guidance = guidance.expand(latent_model_input.shape[0])
+                else:
+                    guidance = None
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
@@ -783,18 +804,7 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
                 )[0]

                 if do_true_cfg:
-                    neg_noise_pred = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
-                        img_ids=latent_image_ids,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
+                    neg_noise_pred, noise_pred = noise_pred.chunk(2)
                     noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
...
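Taken together, the `__call__` changes replace the two transformer forward passes per step (one for the negative prompt, one for the positive) with a single pass over a doubled batch, trading extra activation memory for fewer sequential calls. A minimal sketch of one such step, with a hypothetical `toy_transformer` in place of the Flux transformer:

```py
import torch

def toy_transformer(hidden_states):
    # Hypothetical stand-in for self.transformer(...)[0].
    return torch.randn_like(hidden_states)

latents = torch.randn(1, 1024, 64)
true_cfg = 1.5
do_true_cfg = True

# Double the batch: the first half pairs with the negative embeddings,
# the second half with the positive ones (matching the torch.cat order above).
latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
noise_pred = toy_transformer(latent_model_input)

if do_true_cfg:
    neg_noise_pred, noise_pred = noise_pred.chunk(2)
    noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
```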