"examples/pytorch/argo/ogb_example.py" did not exist on "f5eb80d221fec8690e8cfb087256671545bb9a5a"
Unverified Commit 66394bf6 authored by Dhruv Nair, committed by GitHub

Chroma Follow Up (#11725)

parent 62cce304
@@ -353,6 +353,7 @@ else:
        "AuraFlowPipeline",
        "BlipDiffusionControlNetPipeline",
        "BlipDiffusionPipeline",
+       "ChromaImg2ImgPipeline",
        "ChromaPipeline",
        "CLIPImageProjection",
        "CogVideoXFunControlPipeline",
@@ -945,6 +946,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        AudioLDM2UNet2DConditionModel,
        AudioLDMPipeline,
        AuraFlowPipeline,
+       ChromaImg2ImgPipeline,
        ChromaPipeline,
        CLIPImageProjection,
        CogVideoXFunControlPipeline,
@@ -2543,7 +2543,9 @@ class FusedFluxAttnProcessor2_0:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

-       hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+       hidden_states = F.scaled_dot_product_attention(
+           query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+       )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
@@ -2776,7 +2778,9 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

-       hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+       hidden_states = F.scaled_dot_product_attention(
+           query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+       )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
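For context on the change above: passing `attn_mask` to `F.scaled_dot_product_attention` lets the processor ignore masked key/value positions (e.g. T5 padding tokens). The following standalone sketch is not part of this diff; shapes and values are made up purely to illustrate the PyTorch behaviour being relied on.

# Illustration of attn_mask in torch.nn.functional.scaled_dot_product_attention (assumed toy shapes).
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 6, 8
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Boolean mask broadcastable to (batch, heads, q_len, kv_len); True = attend, False = ignore.
attn_mask = torch.ones(batch, 1, q_len, kv_len, dtype=torch.bool)
attn_mask[..., -2:] = False  # e.g. mask out the last two key/value tokens (padding)

out_masked = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
out_unmasked = F.scaled_dot_product_attention(query, key, value)
print(out_masked.shape, torch.allclose(out_masked, out_unmasked))  # same shape, different values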
@@ -250,15 +250,21 @@ class ChromaSingleTransformerBlock(nn.Module):
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       attention_mask: Optional[torch.Tensor] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        joint_attention_kwargs = joint_attention_kwargs or {}

+       if attention_mask is not None:
+           attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
+           attention_mask=attention_mask,
            **joint_attention_kwargs,
        )
@@ -312,6 +318,7 @@ class ChromaTransformerBlock(nn.Module):
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       attention_mask: Optional[torch.Tensor] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        temb_img, temb_txt = temb[:, :6], temb[:, 6:]
@@ -321,11 +328,15 @@ class ChromaTransformerBlock(nn.Module):
            encoder_hidden_states, emb=temb_txt
        )
        joint_attention_kwargs = joint_attention_kwargs or {}

+       if attention_mask is not None:
+           attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

        # Attention.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
+           attention_mask=attention_mask,
            **joint_attention_kwargs,
        )
@@ -570,6 +581,7 @@ class ChromaTransformer2DModel(
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
+       attention_mask: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
@@ -659,11 +671,7 @@ class ChromaTransformer2DModel(
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                   block,
-                   hidden_states,
-                   encoder_hidden_states,
-                   temb,
-                   image_rotary_emb,
+                   block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
                )
            else:
@@ -672,6 +680,7 @@ class ChromaTransformer2DModel(
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
+                   attention_mask=attention_mask,
                    joint_attention_kwargs=joint_attention_kwargs,
                )
@@ -704,6 +713,7 @@ class ChromaTransformer2DModel(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
+                   attention_mask=attention_mask,
                    joint_attention_kwargs=joint_attention_kwargs,
                )
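The `attention_mask[:, None, None, :] * attention_mask[:, None, :, None]` pattern added above turns a per-token 1D padding mask into a 2D pairwise mask, so a query/key pair is kept only when both tokens are real. A small standalone sketch of the broadcasting (toy values, not code from this diff):

import torch

# 1D mask over a sequence of 5 tokens for a batch of 1; the last two tokens are padding.
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])

# Broadcast to (batch, 1, seq, seq): position (i, j) is 1 only if both token i and token j are kept.
pairwise = mask[:, None, None, :] * mask[:, None, :, None]
print(pairwise.shape)   # torch.Size([1, 1, 5, 5])
print(pairwise[0, 0])   # 3x3 block of ones in the top-left corner, zeros elsewhere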
@@ -148,7 +148,7 @@ else:
            "AudioLDM2UNet2DConditionModel",
        ]
        _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
-       _import_structure["chroma"] = ["ChromaPipeline"]
+       _import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
        _import_structure["cogvideo"] = [
            "CogVideoXPipeline",
            "CogVideoXImageToVideoPipeline",
@@ -537,7 +537,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        )
        from .aura_flow import AuraFlowPipeline
        from .blip_diffusion import BlipDiffusionPipeline
-       from .chroma import ChromaPipeline
+       from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
        from .cogvideo import (
            CogVideoXFunControlPipeline,
            CogVideoXImageToVideoPipeline,
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_chroma"] = ["ChromaPipeline"]
+   _import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
@@ -31,6 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_chroma import ChromaPipeline
+       from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
else:
    import sys
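With the registry entries above in place, the new pipeline should be importable from the top-level package like any other diffusers pipeline. A quick smoke check (assuming a diffusers build that includes this commit, with torch and transformers installed):

from diffusers import ChromaImg2ImgPipeline, ChromaPipeline

# Expected to resolve to the new lazily-imported module, pipeline_chroma_img2img.
print(ChromaImg2ImgPipeline.__module__)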
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -52,12 +52,21 @@ EXAMPLE_DOC_STRING = """
        >>> import torch
        >>> from diffusers import ChromaPipeline

-       >>> pipe = ChromaPipeline.from_single_file(
-       ...     "chroma-unlocked-v35-detail-calibrated.safetensors", torch_dtype=torch.bfloat16
+       >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+       >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
+       >>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
+       >>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
+       >>> pipe = ChromaImg2ImgPipeline.from_pretrained(
+       ...     "black-forest-labs/FLUX.1-schnell",
+       ...     transformer=transformer,
+       ...     text_encoder=text_encoder,
+       ...     tokenizer=tokenizer,
+       ...     torch_dtype=torch.bfloat16,
        ... )
-       >>> pipe.to("cuda")
+       >>> pipe.enable_model_cpu_offload()
        >>> prompt = "A cat holding a sign that says hello world"
-       >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
+       >>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
+       >>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
        >>> image.save("chroma.png")
    ```
"""
@@ -235,6 +244,7 @@ class ChromaPipeline(
        dtype = self.text_encoder.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+       attention_mask = attention_mask.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape
@@ -242,7 +252,10 @@ class ChromaPipeline(
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-       return prompt_embeds
+       attention_mask = attention_mask.repeat(1, num_images_per_prompt)
+       attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+       return prompt_embeds, attention_mask

    def encode_prompt(
        self,
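`_get_t5_prompt_embeds` now returns the tokenizer attention mask alongside the embeddings, and both are tiled per image. A stripped-down sketch of where such a mask comes from, using the tiny T5 tokenizer that the new fast tests rely on (the tokenization arguments here are illustrative, not copied from the pipeline):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")  # stand-in T5 tokenizer
enc = tokenizer(
    ["A cat holding a sign that says hello world"],
    padding="max_length",
    max_length=512,
    truncation=True,
    return_tensors="pt",
)
prompt_ids, attention_mask = enc.input_ids, enc.attention_mask
print(attention_mask.shape)       # (1, 512): 1 for real tokens, 0 for padding
print(int(attention_mask.sum()))  # number of non-padding tokens in the prompt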
@@ -250,8 +263,10 @@ class ChromaPipeline(
        negative_prompt: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
-       prompt_embeds: Optional[torch.FloatTensor] = None,
-       negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+       prompt_embeds: Optional[torch.Tensor] = None,
+       negative_prompt_embeds: Optional[torch.Tensor] = None,
+       prompt_attention_mask: Optional[torch.Tensor] = None,
+       negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        do_classifier_free_guidance: bool = True,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
@@ -268,7 +283,7 @@ class ChromaPipeline(
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
-           prompt_embeds (`torch.FloatTensor`, *optional*):
+           prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
@@ -293,7 +308,7 @@ class ChromaPipeline(
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
-           prompt_embeds = self._get_t5_prompt_embeds(
+           prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
@@ -323,12 +338,13 @@ class ChromaPipeline(
                    " the batch size of `prompt`."
                )

-           negative_prompt_embeds = self._get_t5_prompt_embeds(
+           negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )
            negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        if self.text_encoder is not None:
@@ -336,7 +352,14 @@ class ChromaPipeline(
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

-       return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
+       return (
+           prompt_embeds,
+           text_ids,
+           prompt_attention_mask,
+           negative_prompt_embeds,
+           negative_text_ids,
+           negative_prompt_attention_mask,
+       )

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
    def encode_image(self, image, device, num_images_per_prompt):
@@ -394,7 +417,9 @@ class ChromaPipeline(
        width,
        negative_prompt=None,
        prompt_embeds=None,
+       prompt_attention_mask=None,
        negative_prompt_embeds=None,
+       negative_prompt_attention_mask=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
@@ -428,6 +453,14 @@ class ChromaPipeline(
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

+       if prompt_embeds is not None and prompt_attention_mask is None:
+           raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask`")
+
+       if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+           raise ValueError(
+               "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask`"
+           )
+
        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -534,6 +567,25 @@ class ChromaPipeline(
        return latents, latent_image_ids

+   def _prepare_attention_mask(
+       self,
+       batch_size,
+       sequence_length,
+       dtype,
+       attention_mask=None,
+   ):
+       if attention_mask is None:
+           return attention_mask
+
+       # Extend the prompt attention mask to account for image tokens in the final sequence
+       attention_mask = torch.cat(
+           [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+           dim=1,
+       )
+       attention_mask = attention_mask.to(dtype)
+
+       return attention_mask
+
    @property
    def guidance_scale(self):
        return self._guidance_scale
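To see what the new `_prepare_attention_mask` produces, here is a standalone re-implementation of the same concatenation with toy shapes; names and values are illustrative only, not taken from the pipeline:

import torch

def prepare_attention_mask(batch_size, image_seq_len, dtype, text_attention_mask=None):
    # Mirrors the helper added above: append all-ones for the image tokens after the text tokens.
    if text_attention_mask is None:
        return None
    full_mask = torch.cat(
        [text_attention_mask, torch.ones(batch_size, image_seq_len, device=text_attention_mask.device)],
        dim=1,
    )
    return full_mask.to(dtype)

text_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # 3 real text tokens, 1 masked padding token
mask = prepare_attention_mask(1, image_seq_len=6, dtype=torch.float32, text_attention_mask=text_mask)
print(mask.shape)  # torch.Size([1, 10]) -> 4 text positions + 6 image positions
print(mask)        # image positions are always attended to; only text padding is masked out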
@@ -566,18 +618,20 @@ class ChromaPipeline(
        negative_prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
-       num_inference_steps: int = 28,
+       num_inference_steps: int = 35,
        sigmas: Optional[List[float]] = None,
-       guidance_scale: float = 3.5,
+       guidance_scale: float = 5.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-       latents: Optional[torch.FloatTensor] = None,
-       prompt_embeds: Optional[torch.FloatTensor] = None,
+       latents: Optional[torch.Tensor] = None,
+       prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-       negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+       negative_prompt_embeds: Optional[torch.Tensor] = None,
+       prompt_attention_mask: Optional[torch.Tensor] = None,
+       negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -618,11 +672,11 @@ class ChromaPipeline(
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
-           latents (`torch.FloatTensor`, *optional*):
+           latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
-           prompt_embeds (`torch.FloatTensor`, *optional*):
+           prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -636,10 +690,18 @@ class ChromaPipeline(
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
-           negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+           negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
+           prompt_attention_mask (`torch.Tensor`, *optional*):
+               Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
+               Chroma requires a single padding token to remain unmasked. Please refer to
+               https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
+           negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+               Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
+               prompt sequence. Chroma requires a single padding token to remain unmasked. Please refer to
+               https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
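The linked Chroma model card describes masking T5 padding tokens while leaving exactly one of them visible. A hedged sketch of how a caller could build such a `prompt_attention_mask` from tokenizer output (not code from this PR; the tokenizer is the tiny test one used further down, and `max_length` is arbitrary):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")  # stand-in T5 tokenizer
enc = tokenizer(
    "A cat holding a sign", padding="max_length", max_length=64, truncation=True, return_tensors="pt"
)

prompt_attention_mask = enc.attention_mask.clone()
seq_len = int(prompt_attention_mask.sum(dim=1).item())
if seq_len < prompt_attention_mask.shape[1]:
    # Keep exactly one padding token visible, as described in the Chroma model card.
    prompt_attention_mask[0, seq_len] = 1

print(int(prompt_attention_mask.sum()))  # number of real tokens + 1 unmasked padding token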
@@ -678,7 +740,9 @@ class ChromaPipeline(
            width,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
+           prompt_attention_mask=prompt_attention_mask,
            negative_prompt_embeds=negative_prompt_embeds,
+           negative_prompt_attention_mask=negative_prompt_attention_mask,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )
@@ -704,13 +768,17 @@ class ChromaPipeline(
        (
            prompt_embeds,
            text_ids,
+           prompt_attention_mask,
            negative_prompt_embeds,
            negative_text_ids,
+           negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
+           prompt_attention_mask=prompt_attention_mask,
+           negative_prompt_attention_mask=negative_prompt_attention_mask,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
@@ -730,6 +798,7 @@ class ChromaPipeline(
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
@@ -740,6 +809,20 @@ class ChromaPipeline(
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )

+       attention_mask = self._prepare_attention_mask(
+           batch_size=latents.shape[0],
+           sequence_length=image_seq_len,
+           dtype=latents.dtype,
+           attention_mask=prompt_attention_mask,
+       )
+       negative_attention_mask = self._prepare_attention_mask(
+           batch_size=latents.shape[0],
+           sequence_length=image_seq_len,
+           dtype=latents.dtype,
+           attention_mask=negative_prompt_attention_mask,
+       )
+
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
@@ -801,6 +884,7 @@ class ChromaPipeline(
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
+                   attention_mask=attention_mask,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
@@ -814,6 +898,7 @@ class ChromaPipeline(
                    encoder_hidden_states=negative_prompt_embeds,
                    txt_ids=negative_text_ids,
                    img_ids=latent_image_ids,
+                   attention_mask=negative_attention_mask,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
@@ -272,6 +272,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


+class ChromaImg2ImgPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
class ChromaPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
import random
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import floats_tensor, torch_device

from ..test_pipelines_common import (
    FluxIPAdapterTesterMixin,
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
)


class ChromaImg2ImgPipelineFastTests(
    unittest.TestCase,
    PipelineTesterMixin,
    FluxIPAdapterTesterMixin,
):
    pipeline_class = ChromaImg2ImgPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
    batch_params = frozenset(["prompt"])

    # there is no xformers processor for Flux
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)
        transformer = ChromaTransformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=32,
            axes_dims_rope=[4, 4, 8],
            approximator_hidden_dim=32,
            approximator_layers=1,
            approximator_num_channels=16,
        )

        torch.manual_seed(0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
            "image_encoder": None,
            "feature_extractor": None,
        }

    def get_dummy_inputs(self, device, seed=0):
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 8,
            "width": 8,
            "max_sequence_length": 48,
            "strength": 0.8,
            "output_type": "np",
        }
        return inputs

    def test_chroma_different_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        output_same_prompt = pipe(**inputs).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = "a different prompt"
        output_different_prompts = pipe(**inputs).images[0]

        max_diff = np.abs(output_same_prompt - output_different_prompts).max()

        # Outputs should be different here
        # For some reason, they don't show large differences
        assert max_diff > 1e-6

    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(pipe.transformer), (
            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        )
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
            "Fusion of QKV projections shouldn't affect the outputs."
        )
        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        )
        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
            "Original outputs should match when fused QKV projections are disabled."
        )

    def test_chroma_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (72, 57)]
        for height, width in height_width_pairs:
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            assert (output_height, output_width) == (expected_height, expected_width)