Unverified Commit 8adc6003 authored by Edna, committed by GitHub

Chroma Pipeline (#11698)



* working state from hameerabbasi and iddl

* working state from hameerabbasi and iddl (transformer)

* working state (normalization)

* working state (embeddings)

* add chroma loader

* add chroma to mappings

* add chroma to transformer init

* take out variant stuff

* get decently far in changing variant stuff

* add chroma init

* make chroma output class

* add chroma transformer to dummy tp

* add chroma to init

* add chroma to init

* fix single file

* update

* update

* add chroma to auto pipeline

* add chroma to pipeline init

* change to chroma transformer

* take out variant from blocks

* swap embedder location

* remove prompt_2

* work on swapping text encoders

* remove mask function

* dont modify mask (for now)

* wrap attn mask

* no attn mask (can't get it to work)

* remove pooled prompt embeds

* change to my own unpooled embedder

* fix load

* take pooled projections out of transformer

* ensure correct dtype for chroma embeddings

* update

* use dn6 attn mask + fix true_cfg_scale

* use chroma pipeline output

* use DN6 embeddings

* remove guidance

* remove guidance embed (pipeline)

* remove guidance from embeddings

* don't return length

* dont change dtype

* remove unused stuff, fix up docs

* add chroma autodoc

* add .md (oops)

* initial chroma docs

* undo don't change dtype

* undo arxiv change

unsure why that happened

* fix hf papers regression in more places

* Update docs/source/en/api/pipelines/chroma.md
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* do_cfg -> self.do_classifier_free_guidance

* Update docs/source/en/api/models/chroma_transformer.md
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update chroma.md

* Move chroma layers into transformer

* Remove pruned AdaLayerNorms

* Add chroma fast tests

* (untested) batch cond and uncond

* Add # Copied from for shift

* Update # Copied from statements

* update norm imports

* Revert cond + uncond batching

* Add transformer tests

* move chroma test (oops)

* chroma init

* fix chroma pipeline fast tests

* Update src/diffusers/models/transformers/transformer_chroma.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Move Approximator and Embeddings

* Fix auto pipeline + make style, quality

* make style

* Apply style fixes

* switch to new input ids

* fix # Copied from error

* remove # Copied from on protected members

* try to fix import

* fix import

* make fix-copies

* revert style fix

* update chroma transformer params

* update chroma transformer approximator init params

* update to pad tokens

* fix batch inference

* Make more pipeline tests work

* Make most transformer tests work

* fix docs

* make style, make quality

* skip batch tests

* fix test skipping

* fix test skipping again

* fix for tests

* Fix all pipeline test

* update

* push local changes, fix docs

* add encoder test, remove pooled dim

* default proj dim

* fix tests

* fix equal size list input

* update

* push local changes, fix docs

* add encoder test, remove pooled dim

* default proj dim

* fix tests

* fix equal size list input

* Revert "fix equal size list input"

This reverts commit 3fe4ad67d58d83715bc238f8654f5e90bfc5653c.

* update

* update

* update

* update

* update

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 9f91305f
...@@ -283,6 +283,8 @@
  title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
  title: AuraFlowTransformer2DModel
- local: api/models/chroma_transformer
  title: ChromaTransformer2DModel
- local: api/models/cogvideox_transformer3d
  title: CogVideoXTransformer3DModel
- local: api/models/cogview3plus_transformer2d
...@@ -405,6 +407,8 @@
  title: AutoPipeline
- local: api/pipelines/blip_diffusion
  title: BLIP-Diffusion
- local: api/pipelines/chroma
  title: Chroma
- local: api/pipelines/cogvideox
  title: CogVideoX
- local: api/pipelines/cogview3
......
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ChromaTransformer2DModel
A modified Flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma).
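A minimal loading sketch is shown below. It reuses the single-file checkpoint URL from the Chroma pipeline docs; any other Chroma single-file checkpoint should load the same way.

```python
import torch

from diffusers import ChromaTransformer2DModel

# Load the transformer directly from an original-format (single-file) checkpoint.
transformer = ChromaTransformer2DModel.from_single_file(
    "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors",
    torch_dtype=torch.bfloat16,
)
```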
## ChromaTransformer2DModel
[[autodoc]] ChromaTransformer2DModel
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Chroma
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
  <img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white">
</div>
Chroma is a text-to-image generation model based on Flux.
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
<Tip>
Chroma can use all the same optimizations as Flux.
</Tip>
## Inference (Single File)
The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when loading finetuned or quantized versions of the model that have been published by the community.
The following example demonstrates how to run Chroma from a single-file checkpoint:
```python
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
from transformers import T5EncoderModel, T5Tokenizer
bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)
text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2")
pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt,
guidance_scale=4.0,
output_type="pil",
num_inference_steps=26,
generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("image.png")
```
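As the Tip above notes, Chroma can reuse the Flux optimization recipes. Below is a minimal, hedged sketch of one of them, compiling the transformer with `torch.compile`, assuming the `pipe` object built in the example above; actual speed-ups depend on your GPU and PyTorch version.

```python
# Optional: compile the transformer, as is commonly done for Flux.
# The first call pays the compilation cost; subsequent calls run faster.
# When compiling, keeping the pipeline on the GPU (pipe.to("cuda")) is usually
# a better fit than enable_model_cpu_offload().
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
```

Other Flux recipes, such as quantizing the T5 text encoder or offloading, should carry over in the same way.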
## ChromaPipeline
[[autodoc]] ChromaPipeline
- all
- __call__
...@@ -159,6 +159,7 @@ else:
        "AutoencoderTiny",
        "AutoModel",
        "CacheMixin",
        "ChromaTransformer2DModel",
        "CogVideoXTransformer3DModel",
        "CogView3PlusTransformer2DModel",
        "CogView4Transformer2DModel",
...@@ -352,6 +353,7 @@ else:
        "AuraFlowPipeline",
        "BlipDiffusionControlNetPipeline",
        "BlipDiffusionPipeline",
        "ChromaPipeline",
        "CLIPImageProjection",
        "CogVideoXFunControlPipeline",
        "CogVideoXImageToVideoPipeline",
...@@ -770,6 +772,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        AutoencoderTiny,
        AutoModel,
        CacheMixin,
        ChromaTransformer2DModel,
        CogVideoXTransformer3DModel,
        CogView3PlusTransformer2DModel,
        CogView4Transformer2DModel,
...@@ -942,6 +945,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        AudioLDM2UNet2DConditionModel,
        AudioLDMPipeline,
        AuraFlowPipeline,
        ChromaPipeline,
        CLIPImageProjection,
        CogVideoXFunControlPipeline,
        CogVideoXImageToVideoPipeline,
......
...@@ -60,6 +60,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
    "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
    "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
    "WanVACETransformer3DModel": lambda model_cls, weights: weights,
    "ChromaTransformer2DModel": lambda model_cls, weights: weights,
}
......
...@@ -29,6 +29,7 @@ from .single_file_utils import (
    convert_animatediff_checkpoint_to_diffusers,
    convert_auraflow_transformer_checkpoint_to_diffusers,
    convert_autoencoder_dc_checkpoint_to_diffusers,
    convert_chroma_transformer_checkpoint_to_diffusers,
    convert_controlnet_checkpoint,
    convert_flux_transformer_checkpoint_to_diffusers,
    convert_hidream_transformer_to_diffusers,
...@@ -97,6 +98,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ChromaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "LTXVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
......
...@@ -3310,3 +3310,172 @@ def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
    return checkpoint
def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}
    keys = list(checkpoint.keys())

    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
    num_guidance_layers = (
        list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1  # noqa: C401
    )
    mlp_ratio = 4.0
    inner_dim = 3072

    # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
    # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

    # guidance
    converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop(
        "distilled_guidance_layer.in_proj.bias"
    )
    converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop(
        "distilled_guidance_layer.in_proj.weight"
    )
    converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop(
        "distilled_guidance_layer.out_proj.bias"
    )
    converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop(
        "distilled_guidance_layer.out_proj.weight"
    )
    for i in range(num_guidance_layers):
        block_prefix = f"distilled_guidance_layer.layers.{i}."
        converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.in_layer.bias"
        )
        converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.in_layer.weight"
        )
        converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.out_layer.bias"
        )
        converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.out_layer.weight"
        )
        converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
            f"distilled_guidance_layer.norms.{i}.scale"
        )

    # context_embedder
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
        context_q, context_k, context_v = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )
        # ff img_mlp
        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.2.bias"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.proj.bias"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.proj.bias"
        )

    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # Q, K, V, mlp
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")

    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

    return converted_state_dict
...@@ -74,6 +74,7 @@ if is_torch_available():
    _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
    _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
    _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
    _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
    _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
    _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
    _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
...@@ -151,6 +152,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .transformers import (
        AllegroTransformer3DModel,
        AuraFlowTransformer2DModel,
        ChromaTransformer2DModel,
        CogVideoXTransformer3DModel,
        CogView3PlusTransformer2DModel,
        CogView4Transformer2DModel,
......
...@@ -31,7 +31,7 @@ def get_timestep_embedding(
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
...@@ -1325,7 +1325,7 @@ class Timesteps(nn.Module):
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
......
...@@ -17,6 +17,7 @@ if is_torch_available():
    from .t5_film_transformer import T5FilmDecoder
    from .transformer_2d import Transformer2DModel
    from .transformer_allegro import AllegroTransformer3DModel
    from .transformer_chroma import ChromaTransformer2DModel
    from .transformer_cogview3plus import CogView3PlusTransformer2DModel
    from .transformer_cogview4 import CogView4Transformer2DModel
    from .transformer_cosmos import CosmosTransformer3DModel
......
This diff is collapsed.
...@@ -148,6 +148,7 @@ else:
        "AudioLDM2UNet2DConditionModel",
    ]
    _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
    _import_structure["chroma"] = ["ChromaPipeline"]
    _import_structure["cogvideo"] = [
        "CogVideoXPipeline",
        "CogVideoXImageToVideoPipeline",
...@@ -536,6 +537,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        )
        from .aura_flow import AuraFlowPipeline
        from .blip_diffusion import BlipDiffusionPipeline
        from .chroma import ChromaPipeline
        from .cogvideo import (
            CogVideoXFunControlPipeline,
            CogVideoXImageToVideoPipeline,
......
...@@ -21,6 +21,7 @@ from ..configuration_utils import ConfigMixin
from ..models.controlnets import ControlNetUnionModel
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .chroma import ChromaPipeline
from .cogview3 import CogView3PlusPipeline
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
from .controlnet import (
...@@ -143,6 +144,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
        ("flux-controlnet", FluxControlNetPipeline),
        ("lumina", LuminaPipeline),
        ("lumina2", Lumina2Pipeline),
        ("chroma", ChromaPipeline),
        ("cogview3", CogView3PlusPipeline),
        ("cogview4", CogView4Pipeline),
        ("cogview4-control", CogView4ControlPipeline),
......
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["ChromaPipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_chroma"] = ["ChromaPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_chroma import ChromaPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
This diff is collapsed.
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class ChromaPipelineOutput(BaseOutput):
    """
    Output class for Chroma pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height, width,
            num_channels)` representing the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
...@@ -325,6 +325,21 @@ class CacheMixin(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class ChromaTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class CogVideoXTransformer3DModel(metaclass=DummyObject):
    _backends = ["torch"]
......
...@@ -272,6 +272,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class ChromaPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class CLIPImageProjection(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import ChromaTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
def create_chroma_ip_adapter_state_dict(model):
    # "ip_adapter" (cross-attention weights)
    ip_cross_attn_state_dict = {}
    key_id = 0

    for name in model.attn_processors.keys():
        if name.startswith("single_transformer_blocks"):
            continue

        joint_attention_dim = model.config["joint_attention_dim"]
        hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
        sd = FluxIPAdapterJointAttnProcessor2_0(
            hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
        ).state_dict()
        ip_cross_attn_state_dict.update(
            {
                f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
                f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
                f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
                f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
            }
        )
        key_id += 1

    # "image_proj" (ImageProjection layer weights)
    image_projection = ImageProjection(
        cross_attention_dim=model.config["joint_attention_dim"],
        image_embed_dim=model.config["pooled_projection_dim"],
        num_image_text_embeds=4,
    )

    ip_image_projection_state_dict = {}
    sd = image_projection.state_dict()
    ip_image_projection_state_dict.update(
        {
            "proj.weight": sd["image_embeds.weight"],
            "proj.bias": sd["image_embeds.bias"],
            "norm.weight": sd["norm.weight"],
            "norm.bias": sd["norm.bias"],
        }
    )
    del sd

    ip_state_dict = {}
    ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
    return ip_state_dict
class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = ChromaTransformer2DModel
    main_input_name = "hidden_states"
    # We override the items here because the transformer under consideration is small.
    model_split_percents = [0.8, 0.7, 0.7]

    # Skip setting testing with default: AttnProcessor
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
        image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "img_ids": image_ids,
            "txt_ids": text_ids,
            "timestep": timestep,
        }

    @property
    def input_shape(self):
        return (16, 4)

    @property
    def output_shape(self):
        return (16, 4)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "patch_size": 1,
            "in_channels": 4,
            "num_layers": 1,
            "num_single_layers": 1,
            "attention_head_dim": 16,
            "num_attention_heads": 2,
            "joint_attention_dim": 32,
            "axes_dims_rope": [4, 4, 8],
            "approximator_num_channels": 8,
            "approximator_hidden_dim": 16,
            "approximator_layers": 1,
        }

        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_deprecated_inputs_img_txt_ids_3d(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output_1 = model(**inputs_dict).to_tuple()[0]

        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)

        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"

        inputs_dict["txt_ids"] = text_ids_3d
        inputs_dict["img_ids"] = image_ids_3d

        with torch.no_grad():
            output_2 = model(**inputs_dict).to_tuple()[0]

        self.assertEqual(output_1.shape, output_2.shape)
        self.assertTrue(
            torch.allclose(output_1, output_2, atol=1e-5),
            msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) does not match output with them as 2d inputs",
        )

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"ChromaTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class ChromaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = ChromaTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()


class ChromaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = ChromaTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
...@@ -57,7 +57,9 @@ def create_flux_ip_adapter_state_dict(model):
    image_projection = ImageProjection(
        cross_attention_dim=model.config["joint_attention_dim"],
        image_embed_dim=(
            model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
        ),
        num_image_text_embeds=4,
    )
......