Unverified Commit dc6bd151 authored by josephrocca, committed by GitHub

Fix Chroma attention padding order and update docs to use `lodestones/Chroma1-HD` (#12508)



* [Fix] Move attention mask padding after T5 embedding

* [Fix] Move attention mask padding after T5 embedding

* Clean up whitespace in pipeline_chroma.py

Removed unnecessary blank lines for cleaner code.

* Fix

* Fix

* Update model to final Chroma1-HD checkpoint

* Update to Chroma1-HD

* Update model to Chroma1-HD

* Update model to Chroma1-HD

* Update Chroma model links to Chroma1-HD

* Add comment about padding/masking

* Fix checkpoint/repo references

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 500b9cf1
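
In outline, the padding fix swaps two steps in the pipelines' prompt encoding: the T5 text encoder now receives the tokenizer's unmodified attention mask, and the Chroma-specific mask, which additionally unmasks one padding token, is built afterwards for the transformer's forward pass only. Below is a condensed sketch of the corrected order, not the pipeline's literal method: `text_encoder`, `text_inputs`, and `encode_prompt_fixed` are stand-ins for the pipeline's internals, with variable names following the diff further down.

```python
import torch

def encode_prompt_fixed(text_encoder, text_inputs, batch_size, device):
    # The fix: T5 embedding uses the raw tokenizer mask, with no padding token added yet.
    tokenizer_mask = text_inputs.attention_mask.to(device)
    prompt_embeds = text_encoder(
        text_inputs.input_ids.to(device),
        output_hidden_states=False,
        attention_mask=tokenizer_mask,
    )[0]
    # Only afterwards is the Chroma-specific mask built: `<=` keeps the first
    # padding token visible for the transformer's forward pass.
    seq_lengths = tokenizer_mask.sum(dim=1)
    mask_indices = torch.arange(tokenizer_mask.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
    transformer_mask = mask_indices <= seq_lengths.unsqueeze(1)
    return prompt_embeds, transformer_mask
```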
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
 # ChromaTransformer2DModel
-A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
+A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
 ## ChromaTransformer2DModel
......
@@ -19,20 +19,21 @@ specific language governing permissions and limitations under the License.
 Chroma is a text to image generation model based on Flux.
-Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
+Original model checkpoints for Chroma can be found here:
+* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
+* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
+* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
 > [!TIP]
 > Chroma can use all the same optimizations as Flux.
 ## Inference
-The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
 ```python
 import torch
 from diffusers import ChromaPipeline
-pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
+pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
 pipe.enable_model_cpu_offload()
 prompt = [
@@ -63,10 +64,10 @@ Then run the following example
 import torch
 from diffusers import ChromaTransformer2DModel, ChromaPipeline
-model_id = "lodestones/Chroma"
+model_id = "lodestones/Chroma1-HD"
 dtype = torch.bfloat16
-transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
+transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
 pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
 pipe.enable_model_cpu_offload()
......
@@ -379,7 +379,7 @@ class ChromaTransformer2DModel(
     """
     The Transformer model introduced in Flux, modified for Chroma.
-    Reference: https://huggingface.co/lodestones/Chroma
+    Reference: https://huggingface.co/lodestones/Chroma1-HD
     Args:
         patch_size (`int`, defaults to `1`):
......
@@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """
         >>> import torch
         >>> from diffusers import ChromaPipeline
-        >>> model_id = "lodestones/Chroma"
-        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+        >>> model_id = "lodestones/Chroma1-HD"
+        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
         >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
         >>> pipe = ChromaPipeline.from_pretrained(
         ...     model_id,
@@ -158,7 +158,7 @@ class ChromaPipeline(
     r"""
     The Chroma pipeline for text-to-image generation.
-    Reference: https://huggingface.co/lodestones/Chroma/
+    Reference: https://huggingface.co/lodestones/Chroma1-HD/
     Args:
         transformer ([`ChromaTransformer2DModel`]):
@@ -233,20 +233,23 @@ class ChromaPipeline(
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        attention_mask = text_inputs.attention_mask.clone()
-        # Chroma requires the attention mask to include one padding token
-        seq_lengths = attention_mask.sum(dim=1)
-        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
+        tokenizer_mask = text_inputs.attention_mask
+        tokenizer_mask_device = tokenizer_mask.to(device)
+        # unlike FLUX, Chroma uses the attention mask when generating the T5 embedding
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+            text_input_ids.to(device),
+            output_hidden_states=False,
+            attention_mask=tokenizer_mask_device,
         )[0]
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(device=device)
+        # for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer
+        seq_lengths = tokenizer_mask_device.sum(dim=1)
+        mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
         _, seq_len, _ = prompt_embeds.shape
......
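
The `<=` comparison above is the subtle part: it leaves exactly one padding position unmasked per prompt. A hypothetical toy run with made-up masks (a batch of two prompts padded to length 6; values chosen purely for illustration) shows the effect:

```python
import torch

# Hypothetical tokenizer masks: prompt 0 has 3 real tokens, prompt 1 has 5.
tokenizer_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
batch_size, seq_len = tokenizer_mask.shape

seq_lengths = tokenizer_mask.sum(dim=1)  # tensor([3, 5])
mask_indices = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
# `<=` (not `<`) also keeps position `seq_lengths` itself, i.e. the first pad token.
attention_mask = mask_indices <= seq_lengths.unsqueeze(1)
print(attention_mask.long())
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1]])
```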
@@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """
         >>> import torch
         >>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
-        >>> model_id = "lodestones/Chroma"
-        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+        >>> model_id = "lodestones/Chroma1-HD"
+        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
         >>> pipe = ChromaImg2ImgPipeline.from_pretrained(
         ...     model_id,
         ...     transformer=transformer,
@@ -170,7 +170,7 @@ class ChromaImg2ImgPipeline(
     r"""
     The Chroma pipeline for image-to-image generation.
-    Reference: https://huggingface.co/lodestones/Chroma/
+    Reference: https://huggingface.co/lodestones/Chroma1-HD/
     Args:
         transformer ([`ChromaTransformer2DModel`]):
@@ -247,20 +247,21 @@ class ChromaImg2ImgPipeline(
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        attention_mask = text_inputs.attention_mask.clone()
-        # Chroma requires the attention mask to include one padding token
-        seq_lengths = attention_mask.sum(dim=1)
-        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+        tokenizer_mask = text_inputs.attention_mask
+        tokenizer_mask_device = tokenizer_mask.to(device)
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+            text_input_ids.to(device),
+            output_hidden_states=False,
+            attention_mask=tokenizer_mask_device,
         )[0]
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(dtype=dtype, device=device)
+        seq_lengths = tokenizer_mask_device.sum(dim=1)
+        mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
         _, seq_len, _ = prompt_embeds.shape
......